diff --git a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml index d70ecb2a7e7b0789bf433bc174a1b14237248e16..d392a5f64062940446f3620a8e03ec01ee8a9e4f 100644 --- a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml +++ b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2 model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml index 4397effa82cc8c8359996d57c0e717d8253af0bc..4b7776b20da26ac6d667f65066165bec72620342 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml @@ -1,3 +1,4 @@ +# For hf script, without -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml index fa6ea236ef04ff0533a9bf548bbe7e514b7680fc..05b66175199ef765d4a51731cad0b2ed5b1f4c5f 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml @@ -1,3 +1,4 @@ +# For hf script, without -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-70B-Instruct -b 32 -l 250 -f 5 model_name: "meta-llama/Meta-Llama-3-70B-Instruct" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml index c513159c6fa0d06b3d655837a0c16811a8937d41..12a87e5290146aaf050da68cc22aad50ddcb979d 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml index 5e57fcbcf7d9b86b402c5a1416a483576ec51157..7c7a1ca6edbf23bb21fc17f270f09345033d9def 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml index 374171f1f915bc781c4e8fd67822649c494c4ba1..1d45c3770458d60725b461fa0b87dd912e8a4aa1 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 1000 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml index dc36b705634f994378108ed357c8976cc814f2b0..29a145252ef63a5b42c27cef9c451fd91c6130f2 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1 model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml index 0ecfc01ef049f7afc58249f5adf4c5b4091a192c..3a5f120b3e71d3885a7e0ddc7bf337aa7995f22c 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml index bc29002985969fbcb1da1c23b05f2a1ae05e6be3..5ff57bae4921b7bf5bf5e18d9b7ccc1bae3afd5a 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml index 3964f3be5e87479b80bc15ed33cd5d5f6cbdfdd3..07fb130464ab81fefa6a5926777b7beeea3ef4d3 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml index fb4b4915ab955b01df5ba17de78feeeceb9aac21..c27886525bbb186e23d506ccd46d8c9fb7e8c3e3 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml @@ -1,4 +1,5 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 -t 1 +# For hf script, without -t option (tensor parallel size). 
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 model_name: "meta-llama/Meta-Llama-3-8B-Instruct" tasks: - name: "gsm8k" diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml index 042458659839198ee8415e1602c9375401d083fe..56ec933c9cc0e5e1fc8041db7f485fa272575d20 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 model_name: "HandH1998/QQQ-Llama-3-8b-g128" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml index 78347f63fa793744a4422ee304c223378064cf94..83e11f2be77e83638069dceea862c1e1e5f4f677 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1 model_name: "neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml index 4ef8b5c3709b3911e69808681d46f4b3dcbd795f..15a836dddbd838e41df3aadf2c7ad99b9ebfc058 100644 --- a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1 model_name: "mgoin/Minitron-4B-Base-FP8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml index 75a24e408e7ad085074356d661ee5aa37090339e..5633a2d9b821e918badf2963a0f11ceef5fb7ff2 100644 --- a/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8 model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml index 436ec21924ca1c4e176239887362e5db434a1f9d..b8024c80e8ebd2a98aaef36fa74db5b392d72768 100644 --- a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4 model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml index dec9164d1b84e41f80b334489a743c25e0455939..188a112ca3a4af8ccc8f81a6551a5f283543015a 100644 --- a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml @@ -1,4 +1,5 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 -t 4 +# For hf script, without -t option (tensor parallel size). 
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" tasks: - name: "gsm8k" diff --git a/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml index 166af81a3f0ee06a70a76f2451ffd5a9294f4a95..099e0f465baceccf5462bba752f3dc1944e127ad 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml @@ -1,11 +1,12 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1 model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16" tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.31 + value: 0.30 - name: "exact_match,flexible-extract" - value: 0.47 + value: 0.465 limit: 1319 num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml index 42936fbfbe7d48d44c34842f0a4791156a70e5c3..426e8ff698733ce8e2c92c07946a10b9edb35188 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1 model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml index 43ff2bc5ce35ead1277aa4b273216891ec5ea485..8d57e9dabd56683b77077537fcd894b04b0a186f 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1 model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml index 259799ba8bfa9fd1a2676bbc57c0568e2f145ec4..1bce7e7fdf146fe6e19f4e5191a0b310364535fa 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1 model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml index 45d5efc8860f537460ae293d11a0089864bfa4f6..fc9707d0d6f13ad49c3c12770207c60a98d50cb3 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4 model_name: "Qwen/Qwen2-57B-A14B-Instruct" tasks: diff --git a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml index 2928d75ce4469a0b7fa68ecac828847470e621e4..9a9c749748ecb28071bc253d58b7878cbf942175 100644 --- a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml +++ b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2 model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM" tasks: diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 4ae23eff62f37eb1cf8c4260bfdf734cd0d707c6..6015a83e829504b0e9a9c4c41f96f7c48a747034 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -16,7 +16,7 @@ import numpy import pytest import yaml -RTOL = 0.05 +RTOL = 0.08 TEST_DATA_FILE = os.environ.get( "LM_EVAL_TEST_DATA_FILE", ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 3354ea37002b9865b47f9c551c1ef2375fa6f1d2..a21a657c4b05e742f1a60b9b00cbccf175154fca 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -86,3 +86,18 @@ steps: - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: DOCKER_BUILDKIT: "1" + + - block: "Build Neuron release image" + key: block-neuron-release-image-build + depends_on: ~ + + - label: "Build and publish Neuron release image" + depends_on: block-neuron-release-image-build + agents: + queue: neuron-postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." 
+ - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" + env: + DOCKER_BUILDKIT: "1" diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 469422ddec20a32efb959c762bcf8f8592f8a7fe..368f30434aa1d3c29029db2b444e8a27a2c4bfc4 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -98,6 +98,13 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_machete_mm.py \ --ignore=kernels/test_mha_attn.py \ --ignore=kernels/test_block_fp8.py \ + --ignore=kernels/test_cutlass_moe.py \ + --ignore=kernels/test_mamba_ssm_ssd.py \ + --ignore=kernels/test_attention.py \ + --ignore=kernels/test_block_int8.py \ + --ignore=kernels/test_fused_quant_layernorm.py \ + --ignore=kernels/test_int8_kernel.py \ + --ignore=kernels/test_triton_moe_ptpc_fp8.py \ --ignore=kernels/test_permute_cols.py" fi diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh index 9c5cf7cad9489b55ded5563371789e174035bf6d..5d863dd82e9b88276c341a763de359bdd90ec055 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh @@ -5,10 +5,41 @@ set -ex # Setup cleanup -remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; } +remove_docker_container() { + if [[ -n "$container_id" ]]; then + podman rm -f "$container_id" || true + fi + podman system prune -f +} trap remove_docker_container EXIT remove_docker_container # Try building the docker image -docker build -t cpu-test -f docker/Dockerfile.ppc64le . +podman build -t cpu-test-ubi9-ppc -f docker/Dockerfile.ppc64le . 
+ +# Run the image +container_id=$(podman run -itd --entrypoint /bin/bash -v /tmp/:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN cpu-test-ubi9-ppc) + +function cpu_tests() { + + # offline inference + podman exec -it "$container_id" bash -c " + set -e + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + + # Run basic model test + podman exec -it "$container_id" bash -c " + set -e + pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib + pip install sentence-transformers datamodel_code_generator + pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach] + pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5] + pytest -v -s tests/models/encoder_decoder/language -m cpu_model" +} + +# All CPU tests are expected to finish in less than 40 mins. + +export container_id +export -f cpu_tests +timeout 40m bash -c cpu_tests diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-s390x.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-s390x.sh new file mode 100755 index 0000000000000000000000000000000000000000..a97fa502e6cfcca084130166ece09d9f3aed0334 --- /dev/null +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-s390x.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# This script builds the CPU docker image and runs the offline inference inside the container. +# It serves as a sanity check for compilation and basic model usage. +set -ex + +# Setup cleanup +remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; } +trap remove_docker_container EXIT +remove_docker_container + +# Try building the docker image +docker build -t cpu-test -f docker/Dockerfile.s390x . 
diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 87f74277cf90082e8dd891bcd41fc6c106490e4b..21982b01b9cc7783f9c40312e46e3c4162eea71d 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -17,10 +17,13 @@ source /etc/environment docker run --privileged --net host --shm-size=16G -it \ -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \ vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ - && python3 -m pip install pytest \ + && python3 -m pip install pytest pytest-asyncio tpu-info \ && python3 -m pip install lm_eval[api]==0.4.4 \ + && export VLLM_XLA_CACHE_PATH= \ && export VLLM_USE_V1=1 \ && export VLLM_XLA_CHECK_RECOMPILATION=1 \ + && echo HARDWARE \ + && tpu-info \ && echo TEST_0 \ && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \ && echo TEST_1 \ @@ -40,7 +43,11 @@ docker run --privileged --net host --shm-size=16G -it \ && echo TEST_8 \ && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \ && echo TEST_9 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \ + && pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \ + && echo TEST_10 \ + && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \ + && echo TEST_11 \ + && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \ # TODO: This test fails because it uses RANDOM_SEED sampling diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 16acc2fd1127a1ad2ac7af242d98067e431785b0..20d858cb15a1169c05e1cf034f68960a40e5e491 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -8,6 +8,7 @@ # Documentation # label(str): the name of the test. emoji allowed. # fast_check(bool): whether to run this on each commit on fastcheck pipeline. +# torch_nightly(bool): whether to run this on vllm against torch nightly pipeline. 
# fast_check_only(bool): run this test on fastcheck pipeline only # optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run. # command(str): the single command to run for tests. incompatible with commands. @@ -70,6 +71,7 @@ steps: - label: Basic Correctness Test # 30min #mirror_hardwares: [amd] fast_check: true + torch_nightly: true source_file_dependencies: - vllm/ - tests/basic_correctness/test_basic_correctness @@ -104,6 +106,7 @@ steps: - label: Entrypoints Test # 40min working_dir: "/vllm-workspace/tests" fast_check: true + torch_nightly: true #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -118,7 +121,7 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py - pytest -v -s entrypoints/test_chat_utils.py - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -205,6 +208,8 @@ steps: - pytest -v -s v1/sample - pytest -v -s v1/worker - pytest -v -s v1/structured_output + - pytest -v -s v1/spec_decode + - pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_oracle.py @@ -294,6 +299,7 @@ steps: commands: - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py + - pytest -v -s compile/test_sequence_parallelism.py - 
label: PyTorch Fullgraph Smoke Test # 9min source_file_dependencies: @@ -312,15 +318,46 @@ steps: commands: - pytest -v -s compile/test_full_graph.py -- label: Kernels Test %N # 1h each - # mirror_hardwares: [amd] +- label: Kernels Core Operation Test source_file_dependencies: - csrc/ + - tests/kernels/core + commands: + - pytest -v -s kernels/core + +- label: Kernels Attention Test %N + source_file_dependencies: + - csrc/attention/ - vllm/attention - - tests/kernels + - vllm/v1/attention + - tests/kernels/attention commands: - - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT - parallelism: 4 + - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels Quantization Test %N + source_file_dependencies: + - csrc/quantization/ + - vllm/model_executor/layers/quantization + - tests/kernels/quantization + commands: + - pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels MoE Test + source_file_dependencies: + - csrc/moe/ + - tests/kernels/moe + - vllm/model_executor/layers/fused_moe/ + commands: + - pytest -v -s kernels/moe + +- label: Kernels Mamba Test + source_file_dependencies: + - csrc/mamba/ + - tests/kernels/mamba + commands: + - pytest -v -s kernels/mamba - label: Tensorizer Test # 11min # mirror_hardwares: [amd] @@ -341,6 +378,13 @@ steps: commands: - bash scripts/run-benchmarks.sh +- label: Benchmarks CLI Test # 10min + source_file_dependencies: + - vllm/ + - tests/benchmarks/ + commands: + - pytest -v -s benchmarks/ + - label: Quantization Test # 33min source_file_dependencies: - csrc/ @@ -393,8 +437,9 @@ steps: - pytest -v -s models/test_transformers.py - pytest -v -s models/test_registry.py # V1 Test: https://github.com/vllm-project/vllm/issues/14531 - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not 
llama4' + - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4' + - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2' - label: Language Models Test (Standard) # 32min #mirror_hardwares: [amd] @@ -404,6 +449,8 @@ steps: - tests/models/embedding/language - tests/models/encoder_decoder/language commands: + # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. + - pip install causal-conv1d - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model @@ -415,6 +462,8 @@ steps: - tests/models/embedding/language - tests/models/encoder_decoder/language commands: + # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. + - pip install causal-conv1d - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' @@ -535,11 +584,14 @@ steps: - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' + # test sequence parallel + - pytest -v -s distributed/test_sequence_parallel.py # this test fails consistently. 
# TODO: investigate and fix # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - label: Plugin Tests (2 GPUs) # 40min working_dir: "/vllm-workspace/tests" diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 860c5c6cd53744f1de3c0c73983b91cf94f30fa8..76aa5f7a35d5aeb435d1eab9d972ea51c1a6d80d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -12,6 +12,7 @@ /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth /vllm/model_executor/guided_decoding @mgoin @russellb /vllm/multimodal @DarkLight1337 @ywang96 +/vllm/vllm_flash_attn @LucasWilkinson CMakeLists.txt @tlrmchlsmth # vLLM V1 diff --git a/.github/ISSUE_TEMPLATE/200-installation.yml b/.github/ISSUE_TEMPLATE/200-installation.yml index 590e56c137813059a695e501833138e0a85a947e..34da4019687b227b819312f896c632b6c2d0d7a3 100644 --- a/.github/ISSUE_TEMPLATE/200-installation.yml +++ b/.github/ISSUE_TEMPLATE/200-installation.yml @@ -14,7 +14,7 @@ body: description: | Please run the following and paste the output below. ```sh - wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py + wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` diff --git a/.github/ISSUE_TEMPLATE/300-usage.yml b/.github/ISSUE_TEMPLATE/300-usage.yml index 004798a388a63d949c8ef0e73194e53c3c4c0e39..c9e4be0e7719febacbe0f4b328351b18b289c52a 100644 --- a/.github/ISSUE_TEMPLATE/300-usage.yml +++ b/.github/ISSUE_TEMPLATE/300-usage.yml @@ -14,7 +14,7 @@ body: description: | Please run the following and paste the output below. 
```sh - wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py + wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml index d4113da8b5b8168f089dd8288a4c45cf299a6d01..b96ab40749003a9ea6fbcf112567028f9594bfdf 100644 --- a/.github/ISSUE_TEMPLATE/400-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml @@ -14,7 +14,7 @@ body: description: | Please run the following and paste the output below. ```sh - wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py + wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` diff --git a/.github/ISSUE_TEMPLATE/700-performance-discussion.yml b/.github/ISSUE_TEMPLATE/700-performance-discussion.yml index 273f50d59cf76a61fa96e02478885f27d7d0f40b..3d31c11550167211cbd6a21b82f262cdb5f8129f 100644 --- a/.github/ISSUE_TEMPLATE/700-performance-discussion.yml +++ b/.github/ISSUE_TEMPLATE/700-performance-discussion.yml @@ -35,7 +35,7 @@ body: description: | Please run the following and paste the output below. ```sh - wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py + wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py # For security purposes, please feel free to check the contents of collect_env.py before running it. 
python collect_env.py ``` diff --git a/.github/mergify.yml b/.github/mergify.yml index 3097b994659ab2ebc8ea3db5f39dc1e10044befe..15fa3660a87df9154454e47e293a9b0b3b741f22 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -55,11 +55,19 @@ pull_request_rules: description: Automatically apply structured-output label conditions: - or: + - files~=^benchmarks/structured_schemas/ + - files=benchmarks/benchmark_serving_structured_output.py + - files=benchmarks/run_structured_output_benchmark.sh + - files=docs/source/features/structured_outputs.md + - files=examples/offline_inference/structured_outputs.py + - files=examples/online_serving/openai_chat_completion_structured_outputs.py + - files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py - files~=^vllm/model_executor/guided_decoding/ - files=tests/model_executor/test_guided_processors.py - files=tests/entrypoints/llm/test_guided_generate.py - - files=benchmarks/benchmark_serving_guided.py - - files=benchmarks/benchmark_guided.py + - files~=^tests/v1/structured_output/ + - files=tests/v1/entrypoints/llm/test_guided_generate.py + - files~=^vllm/v1/structured_output/ actions: label: add: @@ -118,6 +126,28 @@ pull_request_rules: remove: - tpu +- name: label-tool-calling + description: Automatically add tool-calling label + conditions: + - or: + - files~=^tests/tool_use/ + - files~=^tests/mistral_tool_use/ + - files~=^tests/entrypoints/openai/tool_parsers/ + - files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py + - files~=^vllm/entrypoints/openai/tool_parsers/ + - files=docs/source/features/tool_calling.md + - files=docs/source/getting_started/examples/openai_chat_completion_client_with_tools.md + - files=docs/source/getting_started/examples/chat_with_tools.md + - files~=^examples/tool_chat_* + - files=examples/offline_inference/chat_with_tools.py + - files=examples/online_serving/openai_chat_completion_client_with_tools_required.py + - 
files=examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py + - files=examples/online_serving/openai_chat_completion_client_with_tools.py + actions: + label: + add: + - tool-calling + - name: ping author on conflicts and add 'needs-rebase' label conditions: - conflict diff --git a/.gitignore b/.gitignore index 6f5cbd0733da04ed1d6137892a0c632f8341194b..728213ceb74f050cf63f4a0437f0f50658568238 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,6 @@ # vllm-flash-attn built from source vllm/vllm_flash_attn/* -!vllm/vllm_flash_attn/fa_utils.py # Byte-compiled / optimized / DLL files __pycache__/ @@ -203,3 +202,6 @@ benchmarks/**/*.json # Linting actionlint shellcheck*/ + +# Ignore moe/marlin_moe gen code +csrc/moe/marlin_moe_wna16/kernel_* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e921f69925b663d18906393bf1fb4fe92f8ab5a1..f76b24c025ffb9da08b181ecec0a7188129ba599 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,6 @@ repos: hooks: - id: yapf args: [--in-place, --verbose] - additional_dependencies: [toml] # TODO: Remove when yapf is upgraded - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.3 hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index b64659d604a30c0ad7448f85c4a57720cde52e27..9590f9d1fc23204b832c5b5999ff617d8616a829 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -264,7 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. # Please keep this in sync with FetchContent_Declare line below. 
- set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -282,7 +282,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git # Please keep this in sync with CUTLASS_REVISION line above. - GIT_TAG v3.8.0 + GIT_TAG v3.9.0 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. @@ -303,7 +303,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" - "csrc/cutlass_extensions/common.cpp") + "csrc/cutlass_extensions/common.cpp" + "csrc/attention/mla/cutlass_mla_entry.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -476,7 +477,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(FP4_ARCHS) endif() - # + # CUTLASS MLA Archs and flags + cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS) + set(SRCS + "csrc/attention/mla/cutlass_mla_kernels.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${MLA_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") + # Add MLA-specific include directories only to MLA source files + set_source_files_properties(${SRCS} + PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") + message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") + else() + message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") + # clear MLA_ARCHS + set(MLA_ARCHS) + endif() + # CUTLASS MoE kernels # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works @@ -622,21 +642,51 @@ 
if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) - set(MARLIN_MOE_SRC - "csrc/moe/marlin_kernels/marlin_moe_kernel.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" - "csrc/moe/marlin_moe_ops.cu") + # + # For the Marlin MOE kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + set(MOE_MARLIN_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) + file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) + + message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} + OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} + RESULT_VARIABLE moe_marlin_generation_result + OUTPUT_VARIABLE moe_marlin_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ) + + if (NOT moe_marlin_generation_result EQUAL 0) + message(FATAL_ERROR "Marlin MOE generation failed." 
+ " Result: \"${moe_marlin_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") + else() + set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} + CACHE STRING "Last run Marlin MOE generate script hash" FORCE) + message(STATUS "Marlin MOE generation completed successfully.") + endif() + else() + message(STATUS "Marlin MOE generation script has not changed, skipping generation.") + endif() + + file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") set_gencode_flags_for_srcs( - SRCS "${MARLIN_MOE_SRC}" + SRCS "${MOE_WNAA16_MARLIN_SRC}" CUDA_ARCHS "${MARLIN_MOE_ARCHS}") - list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}") + list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) + message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") else() message(STATUS "Not building Marlin MOE kernels as no compatible archs found" @@ -662,6 +712,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") # set(VLLM_ROCM_EXT_SRC "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/skinny_gemms.cu" "csrc/rocm/attention.cu") define_gpu_extension_target( diff --git a/README.md b/README.md index 59e1ca2dff01ca88948ab5d82efcf2616d1b40f2..72b9d8e1411307fb7d2ee65917ba22cde8ddc97e 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install (若调试,可使用V + 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/ ## 验证 -- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.8.4; +- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.8.5; ## Known Issue - 无 diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 287d500a81de2647b0059e75ea7f6f371b7162b7..efd51c79c37cfff04b91e37f5f123b2a8c489e84 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import io import json import os import sys @@ 
-32,6 +33,7 @@ class RequestFuncInput: extra_body: Optional[dict] = None multi_modal_content: Optional[dict] = None ignore_eos: bool = False + language: Optional[str] = None @dataclass @@ -436,6 +438,110 @@ async def async_request_openai_chat_completions( return output +async def async_request_openai_audio( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + # Lazy import without PlaceholderModule to avoid vllm dep. + import soundfile + api_url = request_func_input.api_url + assert api_url.endswith( + ("transcriptions", "translations" + )), "OpenAI Chat Completions API URL must end with 'transcriptions' " + "or `translations`." + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + payload = { + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "language": "en", + # Flattened due to multipart/form-data + "stream_include_usage": True, + "stream_continuous_usage_stats": True + } + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + # Send audio file + def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + with to_bytes(*request_func_input.multi_modal_content['audio']) as f: + form = aiohttp.FormData() + form.add_field('file', f, content_type='audio/wav') + for key, value in payload.items(): + form.add_field(key, str(value)) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, + data=form, + headers=headers) as 
response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get( + "content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + def get_model(pretrained_model_name_or_path: str) -> str: if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': from modelscope import snapshot_download @@ -493,6 +599,7 @@ ASYNC_REQUEST_FUNCS = { "deepspeed-mii": async_request_deepspeed_mii, "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, + "openai-audio": async_request_openai_audio, "tensorrt-llm": async_request_trt_llm, "scalellm": async_request_openai_completions, "sglang": async_request_openai_completions, diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 63f174275d47bbc99707c057510001905a5fc9c7..ccbc6c022f1f935576810450d914e7848966be41 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -64,6 +64,7 @@ class SampleRequest: class BenchmarkDataset(ABC): DEFAULT_SEED = 0 + IS_MULTIMODAL = False def __init__( 
self, @@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset): SUPPORTED_DATASET_PATHS = { 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' } + IS_MULTIMODAL = True def sample(self, tokenizer: PreTrainedTokenizerBase, @@ -685,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset): "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"] } + IS_MULTIMODAL = True def sample( self, @@ -815,3 +818,80 @@ class AIMODataset(HuggingFaceDataset): )) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests + + +# ----------------------------------------------------------------------------- +# ASR Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ASRDataset(HuggingFaceDataset): + """ + Dataset class for processing an ASR dataset for transcription. + Tested on the following set: + + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | Dataset | Domain | Speaking Style | hf-subset | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | TED-LIUM | TED talks | Oratory | release1, release2, release3| + | | | | release3-speaker-adaptation | + | VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... 
| + | LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" | + | GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test | + | SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test | + | AMI | Meetings | Spontaneous | ihm, sdm | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + + """ # noqa: E501 + SUPPORTED_DATASET_PATHS = { + "openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium", + "edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech" + } + + DEFAULT_OUTPUT_LEN = 128 + IS_MULTIMODAL = True + + # TODO Whisper-specific. Abstract interface when more models are supported. + TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\ + "<|notimestamps|>" + skip_long_audios: bool = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs, + ) -> list: + import librosa + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + prompt = ASRDataset.TRANSCRIPTION_PREAMBLE + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests = [] + skipped = 0 + for item in self.data: + if len(sampled_requests) >= num_requests: + break + audio = item["audio"] + y, sr = audio["array"], audio["sampling_rate"] + duration_s = librosa.get_duration(y=y, sr=sr) + # Whisper max supported duration + if self.skip_long_audios and duration_s > 30: + skipped += 1 + continue + + mm_content = {"audio": (y, sr)} + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + if skipped: + logger.warning("%d samples discarded from dataset due to" \ + " their length being greater than" \ + " what Whisper supports.", skipped) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git 
a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 477bcaa43290280ffe14823c9ed7f00c66eb7cce..19f36941aa421e84f4e0572645eeaff5ad5aa20e 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -78,14 +78,16 @@ class Request: output_len: int -def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str: +def sample_tokens(tokenizer: PreTrainedTokenizerBase, + length: int) -> list[int]: vocab = tokenizer.get_vocab() + all_special_ids = set(tokenizer.all_special_ids) + # Remove the special tokens. - vocab = { - k: v - for k, v in vocab.items() if k not in tokenizer.all_special_ids - } - return random.choices(list(vocab.values()), k=length) + return random.choices( + [v for k, v in vocab.items() if k not in all_special_ids], + k=length, + ) def sample_requests_from_dataset( diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index b5bd840d8410db765e7f4e9ae8a122dc5a660219..da124e1a81b487c79b26f4746aada51200c13a8e 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -50,7 +50,7 @@ try: except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser -from benchmark_dataset import (AIMODataset, BurstGPTDataset, +from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, ConversationDataset, HuggingFaceDataset, InstructCoderDataset, RandomDataset, SampleRequest, ShareGPTDataset, SonnetDataset, @@ -274,10 +274,6 @@ async def benchmark( input_requests[0].expected_output_len, \ input_requests[0].multi_modal_data - if backend != "openai-chat" and test_mm_content is not None: - # multi-modal benchmark is only available on OpenAI Chat backend. 
- raise ValueError( - "Multi-modal content is only supported on 'openai-chat' backend.") assert test_mm_content is None or isinstance(test_mm_content, dict) test_input = RequestFuncInput( model=model_id, @@ -604,6 +600,9 @@ def main(args: argparse.Namespace): elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_class = AIMODataset args.hf_split = "train" + elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ASRDataset + args.hf_split = "train" else: supported_datasets = set([ dataset_name for cls in HuggingFaceDataset.__subclasses__() @@ -615,6 +614,13 @@ def main(args: argparse.Namespace): f" from one of following: {supported_datasets}. " "Please consider contributing if you would " "like to add support for additional dataset formats.") + + if (dataset_class.IS_MULTIMODAL and backend not in \ + ["openai-chat", "openai-audio"]): + # multi-modal benchmark is only available on the OpenAI Chat and OpenAI Audio backends. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' and " \ + "'openai-audio' backend.") input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, @@ -707,7 +713,7 @@ def main(args: argparse.Namespace): )) # Save config and results to json - if args.save_result: + if args.save_result or args.append_result: result_json: dict[str, Any] = {} # Setup @@ -728,6 +734,14 @@ def main(args: argparse.Namespace): raise ValueError( "Invalid metadata format. Please use KEY=VALUE format." 
) + # Traffic + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} if not args.save_detailed: # Remove fields with too many data points @@ -738,15 +752,6 @@ def main(args: argparse.Namespace): if field in result_json: del result_json[field] - # Traffic - result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") - result_json["burstiness"] = args.burstiness - result_json["max_concurrency"] = args.max_concurrency - - # Merge with benchmark result - result_json = {**result_json, **benchmark_result} - # Save to file base_model_id = model_id.split("/")[-1] max_concurrency_str = (f"-concurrency{args.max_concurrency}" @@ -756,7 +761,12 @@ def main(args: argparse.Namespace): file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) - with open(file_name, "w", encoding='utf-8') as outfile: + with open(file_name, + mode="a+" if args.append_result else "w", + encoding='utf-8') as outfile: + # Append a newline. 
+ if args.append_result and outfile.tell() != 0: + outfile.write("\n") json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) @@ -888,6 +898,11 @@ if __name__ == "__main__": help="When saving the results, whether to include per request " "information such as response, error, ttfs, tpots, etc.", ) + parser.add_argument( + "--append-result", + action="store_true", + help="Append the benchmark result to the existing json file.", + ) parser.add_argument( "--metadata", metavar="KEY=VALUE", diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index e52f16a8b12994f25f059cd802b9dbed99970949..74ee00ec893076f52f8428a10f7ced66189e712d 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -51,7 +51,7 @@ try: except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser -from vllm.v1.structured_output.utils import ( +from vllm.v1.structured_output.backend_xgrammar import ( has_xgrammar_unsupported_json_features) MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -150,17 +150,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, elif args.dataset == "grammar": schema = """ - ?start: select_statement + root ::= select_statement - ?select_statement: "SELECT " column_list " FROM " table_name + select_statement ::= "SELECT " column " from " table " where " condition - ?column_list: column_name ("," column_name)* + column ::= "col_1 " | "col_2 " - ?table_name: identifier + table ::= "table_1 " | "table_2 " - ?column_name: identifier + condition ::= column "= " number - ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ + number ::= "1 " | "2 " """ prompt = "Generate an SQL query to show the 'username' \ and 'email' from the 'users' table." 
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 4db6e18df0b25c35fc9b005658b34e96754aa8a9..4c3f8d940da0271ea9bd4f6baaa857883bcc5f91 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -571,6 +571,13 @@ def validate_args(args): raise ValueError( "Tokenizer must be the same as the model for MII backend.") + # --data-parallel is not supported currently. + # https://github.com/vllm-project/vllm/issues/16222 + if args.data_parallel_size > 1: + raise ValueError( + "Data parallel is not supported in offline benchmark, \ + please use benchmark serving instead") + if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py new file mode 100644 index 0000000000000000000000000000000000000000..b23b4f3ea685aa1c07b2d9f26c56102ba16785c1 --- /dev/null +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + MINIMUM_BITBLAS_VERSION) + +try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError("bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") +except ImportError as e: + bitblas_import_exception = e + raise ValueError("Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. 
" + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + +from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target + +from vllm.utils import FlexibleArgumentParser + +parser = FlexibleArgumentParser( + description="Benchmark BitBLAS int4 on a specific target.") + +# Add arguments to the parser +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), + help="Specify the target device for benchmarking.", +) +parser.add_argument("--group_size", + type=int, + default=None, + help="Group size for grouped quantization.") +parser.add_argument( + "--A_dtype", + type=str, + default="float16", + choices=["float16", "float32", "float64", "int32", "int8"], + help="Data type of activation A.", +) +parser.add_argument( + "--W_dtype", + type=str, + default="int4", + choices=[ + "float16", + "float32", + "float64", + "int32", + "int8", + "int4", + "int2", + "int1", + "nf4", + "fp4_e2m1", + ], + help="Data type of weight W.", +) +parser.add_argument( + "--accum_dtype", + type=str, + default="float16", + choices=["float16", "int32"], + help="Data type for accumulation.", +) +parser.add_argument( + "--out_dtype", + type=str, + default="float16", + choices=["float16", "float32", "int32", "int8"], + help="Data type for output.", +) +parser.add_argument( + "--layout", + type=str, + default="nt", + choices=["nt", "nn"], + help="Matrix layout, 'nt' for non-transpose A and transpose W.", +) +parser.add_argument("--with_bias", + action="store_true", + help="Include bias in the benchmark.") +parser.add_argument( + "--with_scaling", + action="store_true", + help="Include scaling factor in the quantization.", +) +parser.add_argument("--with_zeros", + action="store_true", + help="Include zeros in the quantization.") +parser.add_argument( + "--zeros_mode", + type=str, + default=None, + choices=["original", "rescale", "quantized"], + help="Specify the mode for 
calculating zeros.", +) + +# Parse the arguments +args = parser.parse_args() + +# Assign arguments to variables +target = args.target +A_dtype = args.A_dtype +W_dtype = args.W_dtype +accum_dtype = args.accum_dtype +out_dtype = args.out_dtype +layout = args.layout +with_bias = args.with_bias +group_size = args.group_size +with_scaling = args.with_scaling +with_zeros = args.with_zeros +zeros_mode = args.zeros_mode + +# Define a list of shared arguments that repeat in every config +shared_args = [ + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, +] + +# Define just the (M, K, N) shapes in a more compact list +shapes = [ + # square test + (1, 16384, 16384), + # BLOOM-176B + (1, 43008, 14336), + (1, 14336, 14336), + (1, 57344, 14336), + (1, 14336, 57344), + # OPT-65B + (1, 9216, 9216), + (1, 36864, 9216), + (1, 9216, 36864), + (1, 22016, 8192), + # LLAMA-70B/65B + (1, 8192, 22016), + (1, 8192, 8192), + (1, 28672, 8192), + (1, 8192, 28672), + # square test + (16384, 16384, 16384), + # BLOOM-176B + (8192, 43008, 14336), + (8192, 14336, 14336), + (8192, 57344, 14336), + (8192, 14336, 57344), + # OPT-65B + (8192, 9216, 9216), + (8192, 36864, 9216), + (8192, 9216, 36864), + (8192, 22016, 8192), + # LLAMA-70B/65B + (8192, 8192, 22016), + (8192, 8192, 8192), + (8192, 28672, 8192), + (8192, 8192, 28672), +] + +# Build test shapes with all the shared arguments +test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) + for shape in shapes] + +benchmark_sets = [] +benchmark_sets.extend(test_shapes) + +benchmark_results = {} +for config_class, operator, input_args in benchmark_sets: + config = config_class(*input_args) + matmul = operator(config, target=target, enable_tuning=True) + kernel_latency = matmul.profile_latency() + + print("Time cost is: {:.3f} ms".format(kernel_latency)) + + profile_config = { + f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { + 
"BitBLAS_top20_latency": kernel_latency, + } + } + + benchmark_results.update(profile_config) + +# Define headers for the table +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Latency", +] + +# Calculate column widths for pretty printing +col_widths = [0, 0, 0] +for config_key, values in benchmark_results.items(): + args_split = config_key.split("-") + func_name = args_split[0] + input_args_str = "-".join(args_split[1:]) + col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2) + col_widths[1] = max(col_widths[1], + len(input_args_str) + 2, + len(headers[1]) + 2) + col_widths[2] = max(col_widths[2], + len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, + len(headers[2]) + 2) + # break only if you want to measure widths from a single example; + # otherwise, let it loop over all items. + +# Print header +for i, header in enumerate(headers): + headers[i] = header.ljust(col_widths[i]) +print("".join(headers)) +print("-" * sum(col_widths)) + +# Print rows +for config_key, values in benchmark_results.items(): + args_split = config_key.split("-") + func_name = args_split[0] + input_args_str = "-".join(args_split[1:]) + row = [ + func_name, + input_args_str, + f"{values['BitBLAS_top20_latency']:.3f} ms", + ] + row_str = "".join( + [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]) + print(row_str) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index b4b91eda28440f0c788e065aeef20e2b07c4653a..d382ede10b41be1b2168214d88789100dc04e1d4 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -17,8 +17,14 @@ from torch.utils.benchmark import Measurement as TMeasurement from utils import ArgPool, Bench, CudaGraphBenchParams from weight_shapes import WEIGHT_SHAPES -from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink -from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.triton_utils 
import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, + lora_shrink) + from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT, + _LORA_B_PTR_DICT) + from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 674567e3c3c7c2cfe84747e7377fb4fcad55d7c9..03de69de885dd0db8666f79b768d7d569f955a4f 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -576,11 +576,10 @@ def get_weight_block_size_safety(config, default_value=None): def main(args: argparse.Namespace): print(args) - - block_quant_shape = None tp_size = args.tp_size + config = AutoConfig.from_pretrained( args.model, trust_remote_code=args.trust_remote_code) if config.architectures[0] == "DbrxForCausalLM": @@ -599,21 +598,16 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size - elif config.architectures[0] == "Qwen2MoeForCausalLM": + elif config.architectures[0] in [ + "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" + ]: E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size - block_quant_shape = get_weight_block_size_safety(config) - elif config.architectures[0] == "Qwen2MoeForCausalLM": - E = config.num_experts - topk = config.num_experts_per_tok - intermediate_size = config.moe_intermediate_size - shard_intermediate_size = 2 * intermediate_size // args.tp_size else: - if not hasattr(config, "hidden_size"): - # Support for llama4 - config = config.text_config + # Support for llama4 + config = config.get_text_config() # Default: Mixtral. 
E = config.num_local_experts topk = config.num_experts_per_tok @@ -624,6 +618,7 @@ def main(args: argparse.Namespace): dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" + block_quant_shape = get_weight_block_size_safety(config) if args.batch_size is None: batch_sizes = [ diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index afd7c47e8ac003806738c404bb8cbd4c3e93df0b..b04e4c2d06edc90b443d91e34491a0f67431bf67 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22 + GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 7af0caceda2f0e9b807875e6eeaa48f632f4ef92..14e5edd7e283d4a26615cc5d97315d251c54da61 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -107,13 +107,14 @@ __global__ void merge_attn_states_kernel( #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \ { \ - vllm::merge_attn_states_kernel<<>>( \ - reinterpret_cast(output.data_ptr()), output_lse_ptr, \ - reinterpret_cast(prefix_output.data_ptr()), \ - reinterpret_cast(prefix_lse.data_ptr()), \ - reinterpret_cast(suffix_output.data_ptr()), \ - reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ - num_heads, head_size); \ + vllm::merge_attn_states_kernel \ + <<>>( \ + reinterpret_cast(output.data_ptr()), output_lse_ptr, \ + reinterpret_cast(prefix_output.data_ptr()), \ + reinterpret_cast(prefix_lse.data_ptr()), \ + 
reinterpret_cast(suffix_output.data_ptr()), \ + reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ + num_heads, head_size); \ } /*@brief Merges the attention states from prefix and suffix @@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel( * @param output [n,h,d] The output tensor to store the merged attention states. * @param output_lse [h,d] Optional tensor to store the log-sum-exp values. * @param prefix_output [n,h,d] The prefix attention states. - * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention + * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention * states. * @param suffix_output [n,h,d] The suffix attention states. - * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention + * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention * states. */ template @@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output, if (output_lse.has_value()) { output_lse_ptr = output_lse.value().data_ptr(); } - // process one pack elements per thread. float -> 4, half/bf16 -> 8 + // Process one pack elements per thread. for float, the + // pack_size is 4 for half/bf16, the pack_size is 8. const uint threads_per_head = head_size / pack_size; const uint total_threads = num_tokens * num_heads * threads_per_head; dim3 block(NUM_THREADS); dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS); + const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS); } diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu new file mode 100644 index 0000000000000000000000000000000000000000..0319d1daf302f5f6b65f4a9c5aa28fff7753ceec --- /dev/null +++ b/csrc/attention/mla/cutlass_mla_entry.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA +void cutlass_mla_decode_sm100a(torch::Tensor const& out, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, double scale); +#endif + +void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, double scale) { +#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA + return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA"); +} diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..6743af0cf2dbab816b4c204320ebf484ea8516fb --- /dev/null +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass_extensions/common.hpp" + +#include "device/sm100_mla.hpp" +#include "kernel/sm100_mla_tile_scheduler.hpp" + +using namespace cute; +using namespace cutlass::fmha::kernel; + +template +struct MlaSm100 { + using Element = T; + using ElementAcc = float; + using ElementOut = T; + + using TileShape = Shape<_128, _128, Shape<_512, _64>>; + using TileShapeH = cute::tuple_element_t<0, TileShape>; + using TileShapeD = cute::tuple_element_t<2, TileShape>; + + // H K (D_latent D_rope) B + using ProblemShape = cute::tuple; + + using StrideQ = cute::tuple; // H D B + using StrideK = cute::tuple; // K D B + using StrideO = StrideK; // H D B + using StrideLSE = cute::tuple<_1, int>; // H B + + using TileScheduler = + std::conditional_t; + + using FmhaKernel = + cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< + TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler, + /*kIsCpAsync=*/true>; + using Fmha = cutlass::fmha::device::MLA; +}; + +template +typename T::Fmha::Arguments args_from_options( + at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe, + at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, + at::Tensor const& page_table, double scale) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = q_nope.device().index(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); 
+ + int batches = q_nope.sizes()[0]; + int page_count_per_seq = page_table.sizes()[1]; + int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; + int page_size = kv_c_and_k_pe_cache.sizes()[1]; + int max_seq_len = page_size * page_count_per_seq; + using TileShapeH = typename T::TileShapeH; + using TileShapeD = typename T::TileShapeD; + auto problem_shape = + cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using StrideQ = typename T::StrideQ; + using StrideK = typename T::StrideK; + using StrideO = typename T::StrideO; + using StrideLSE = typename T::StrideLSE; + + StrideQ stride_Q_latent = cute::make_tuple( + static_cast(D_latent), _1{}, static_cast(H * D_latent)); + StrideQ stride_Q_rope = cute::make_tuple(static_cast(D_rope), _1{}, + static_cast(H * D_rope)); + StrideK stride_C = + cute::make_tuple(static_cast(D_latent + D_rope), _1{}, + static_cast(page_size * (D_latent + D_rope))); + StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); + StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast(H)); + StrideO stride_O = cute::make_tuple(static_cast(D_latent), _1{}, + static_cast(H * D_latent)); + + using Element = typename T::Element; + using ElementOut = typename T::ElementOut; + using ElementAcc = typename T::ElementAcc; + auto Q_latent_ptr = static_cast(q_nope.data_ptr()); + auto Q_rope_ptr = static_cast(q_pe.data_ptr()); + auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); + auto scale_f = static_cast(scale); + typename T::Fmha::Arguments arguments{ + problem_shape, + {scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr, + stride_C, C_ptr + D_latent, stride_C, + static_cast(seq_lens.data_ptr()), + static_cast(page_table.data_ptr()), stride_PT, page_count_total, + page_size}, + {static_cast(out.data_ptr()), stride_O, + static_cast(nullptr), stride_LSE}, + hw_info, + -1, // split_kv + nullptr, // is_var_split_kv + }; + // 
TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute + // split_kv automatically based on batch size and sequence length to balance + // workload across available SMs. Consider using var_split_kv for manual + // control if needed. + T::Fmha::set_split_kv(arguments); + return arguments; +} + +template +void runMla(at::Tensor const& out, at::Tensor const& q_nope, + at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, + at::Tensor const& seq_lens, at::Tensor const& page_table, + float scale, cudaStream_t stream) { + using MlaSm100Type = MlaSm100; + typename MlaSm100Type::Fmha fmha; + auto arguments = args_from_options( + out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); + size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(fmha.can_implement(arguments)); + + CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream)); + + CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); +} + +void cutlass_mla_decode_sm100a(torch::Tensor const& out, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, double scale) { + TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA"); + TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor"); + TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor"); + TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3, + "kv_c_and_k_pe_cache must be a 3D tensor"); + TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor"); + TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor"); + TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor"); + + auto B_q_nope = q_nope.size(0); + auto H_q_nope = q_nope.size(1); + auto D_q_nope = 
q_nope.size(2); + auto B_q_pe = q_pe.size(0); + auto H_q_pe = q_pe.size(1); + auto D_q_pe = q_pe.size(2); + auto B_pt = page_table.size(0); + auto PAGE_NUM = page_table.size(1); + auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1); + auto D_ckv = kv_c_and_k_pe_cache.size(2); + auto B_o = out.size(0); + auto H_o = out.size(1); + auto D_o = out.size(2); + + TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512"); + TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64"); + TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576"); + TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128, + "H_q_nope, H_q_pe, and H_o must be equal to 128"); + TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, + "PAGE_SIZE must be a power of 2"); + TORCH_CHECK( + B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o, + "Batch dims must be same for page_table, q_nope and q_pe, and out"); + TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0, + "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); + TORCH_CHECK(D_o == 512, "D_o must be equal to 512"); + + TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half || + q_nope.dtype() == at::ScalarType::BFloat16 || + q_nope.dtype() == at::ScalarType::Float8_e4m3fn, + "q_nope must be a half, bfloat16, or float8_e4m3fn tensor"); + TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() && + q_nope.dtype() == q_pe.dtype(), + "kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type"); + TORCH_CHECK(seq_lens.dtype() == torch::kInt32, + "seq_lens must be a 32-bit integer tensor"); + TORCH_CHECK(page_table.dtype() == torch::kInt32, + "page_table must be a 32-bit integer tensor"); + + auto in_dtype = q_nope.dtype(); + at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(q_nope.get_device()); + if (in_dtype == at::ScalarType::Half) { + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, + page_table, scale, stream); + } else if (in_dtype == 
at::ScalarType::BFloat16) { + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale, stream); + } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale, stream); + } else { + TORCH_CHECK(false, "Unsupported input data type of MLA"); + } +} diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 40fb088ac2ba8fc517d431e8cfc2220efb0fc9dd..ea3de8006b6ac4cf17e55129ad871c6e423a57c9 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel( cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, // head_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, const int key_stride, const int value_stride, - const int num_heads, const int head_size, const int block_size, - const float* k_scale, const float* v_scale) { + const int64_t block_stride, const int64_t page_stride, + const int64_t head_stride, const int64_t key_stride, + const int64_t value_stride, const int num_heads, const int head_size, + const int block_size, const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel( const int head_idx = i / head_size; const int head_offset = i % head_size; const int64_t tgt_key_value_idx = block_idx * block_stride + - block_offset * num_heads * head_size + - head_idx * head_size + head_offset; + block_offset * page_stride + + head_idx * head_stride + head_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { @@ -524,16 +525,16 @@ void reshape_and_cache( // KV_T is the data type of key and value tensors. // CACHE_T is the stored data type of kv-cache. 
// KV_DTYPE is the real data type of kv-cache. -#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ - vllm::reshape_and_cache_flash_kernel \ - <<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), block_stride, key_stride, \ - value_stride, num_heads, head_size, block_size, \ - reinterpret_cast(k_scale.data_ptr()), \ +#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_flash_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, page_stride, \ + head_stride, key_stride, value_stride, num_heads, head_size, \ + block_size, reinterpret_cast(k_scale.data_ptr()), \ reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache_flash( @@ -560,9 +561,11 @@ void reshape_and_cache_flash( int head_size = key.size(2); int block_size = key_cache.size(1); - int key_stride = key.stride(0); - int value_stride = value.stride(0); - int block_stride = key_cache.stride(0); + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int64_t block_stride = key_cache.stride(0); + int64_t page_stride = key_cache.stride(1); + int64_t head_stride = key_cache.stride(2); TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); dim3 grid(num_tokens); diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..d1c0d92f6814a080084fec7e77531f316a1e6373 --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +import glob +import itertools +import os +import subprocess + +import jinja2 + +FILE_HEAD = """ +// auto generated by 
generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { +""".strip() + +TEMPLATE = ("template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{'true' if has_act_order else 'false'}}, " + "{{'true' if has_zp else 'false'}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );") + +# int8 with zero point case (vllm::kU8) is also supported, +# we don't add it to reduce wheel size. +SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] + +THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] +# group_blocks: +# = 0 : act order case +# = -1 : channelwise quantization +# > 0 : group_size=16*group_blocks +GROUP_BLOCKS = [0, -1, 2, 4, 8] +DTYPES = ["fp16", "bf16"] + + +def remove_old_kernels(): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + subprocess.call(["rm", "-f", filename]) + + +def generate_new_kernels(): + for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + has_zp = "B" not in scalar_type + all_template_str_list = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): + + has_act_order = group_blocks == 0 + if has_zp and has_act_order: + continue + if thread_configs[2] == 256: + if m_blocks <= 1 and thread_configs[0] != 128: + continue + if m_blocks > 1 and thread_configs[0] != 64: + continue + + k_blocks = thread_configs[0] // 16 + n_blocks = thread_configs[1] // 16 + threads = thread_configs[2] + + c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + + template_str = jinja2.Template(TEMPLATE).render( + scalar_t=c_dtype, + w_type_id=scalar_type + ".id()", + threads=threads, + thread_m_blocks=max(m_blocks, 
1), + thread_n_blocks=n_blocks, + thread_k_blocks=k_blocks, + m_block_size_8=m_blocks == 0.5, + stages="pipe_stages", + has_act_order=has_act_order, + has_zp=has_zp, + group_blocks=group_blocks, + is_zp_float=False, + ) + + all_template_str_list.append(template_str) + + file_content = FILE_HEAD + "\n\n" + file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" + filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) + + +if __name__ == "__main__": + remove_old_kernels() + generate_new_kernels() diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3d92660e8028e94885d0de42e1557054e619edab --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -0,0 +1,44 @@ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "quantization/gptq_marlin/marlin.cuh" +#include "quantization/gptq_marlin/marlin_dtypes.cuh" +#include "core/scalar_type.hpp" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ + const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool use_atomic_add, \ + bool use_fp32_reduce + +namespace MARLIN_NAMESPACE_NAME { +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a 
separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h new file mode 100644 index 0000000000000000000000000000000000000000..205b308fe511bdf57863b48ecd308572b2d68a2c --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -0,0 +1,1917 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "quantization/gptq_marlin/marlin.cuh" +#include "quantization/gptq_marlin/marlin_dtypes.cuh" +#include "core/scalar_type.hpp" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce // whether to use fp32 global reduce +) {} + +} // namespace MARLIN_NAMESPACE_NAME + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
+template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : 
"=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(a[0]) + : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. 
+// NOTE(review): template argument lists (`<...>`) appear to have been stripped
+// from this copy of the patch (bare `template`, `reinterpret_cast(&lo)`,
+// `ScalarType::FragB`, ...); restore them from the upstream file before
+// compiling.
+//
+// Issues a single `lop3.b32` PTX instruction on `a`, `b` and `c` and returns
+// its result. `lut` is passed as an immediate ("n" constraint), so it must be a
+// compile-time constant; presumably it is this template's non-type parameter.
+template
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(res)
+               : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
+
+// Constructs destination register by taking bytes from 2 sources (based on
+// mask).
+// Issues `prmt.b32` with `a` as the first source and the immediate
+// `start_byte` as the second; the immediate `mask` is the byte selector. Both
+// immediates use the "n" constraint and must be compile-time constants
+// (presumably this template's non-type parameters).
+template
+__device__ inline uint32_t prmt(uint32_t a) {
+  uint32_t res;
+  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+               : "=r"(res)
+               : "r"(a), "n"(start_byte), "n"(mask));
+  return res;
+}
+
+// Declaration of dequant(); four explicit specializations follow. Each one
+// takes a packed 32-bit word `q`, writes the converted values into `frag_b`
+// and returns `frag_b`.
+template
+__device__ inline typename ScalarType::FragB dequant(
+    int q, typename ScalarType::FragB& frag_b);
+
+//
+// Efficiently dequantize 4bit values packed in an int32 value into a full
+// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
+// with some small changes:
+// - FP16:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
+// - BF16:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
+//
+// NOTE(review): the element type of this specialization is not visible in this
+// copy; the constants below decode consistently as fp16 bit patterns.
+template <>
+__device__ inline typename ScalarType::FragB dequant(
+    int q, typename ScalarType::FragB& frag_b) {
+  // LO / HI keep bits 0-3 / bits 4-7 of each 16-bit half of `q`. EX (0x6400,
+  // i.e. 1024.0 as fp16) is OR-ed in, so with a 4-bit value v the halves read
+  // as 1024 + v (lo) and 1024 + 16 * v (hi).
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
+  int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
+  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+  // directly into `SUB` and `ADD`.
+  // As fp16 bit patterns: SUB = 1032 (1024 + 8), MUL = 1/16 and
+  // ADD = -72 (-(1024 / 16 + 8)), so both results equal v - 8.
+  const int SUB = 0x64086408;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd480d480;
+  frag_b[0] = __hsub2(*reinterpret_cast(&lo),
+                      *reinterpret_cast(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast(&hi),
+                      *reinterpret_cast(&MUL),
+                      *reinterpret_cast(&ADD));
+  return frag_b;
+}
+
+// 4-bit dequantization, second specialization.
+// NOTE(review): the element type is not visible in this copy; the constants
+// below decode consistently as bf16 bit patterns (see the BF16 link above).
+template <>
+__device__ inline typename ScalarType::FragB
+dequant(int q,
+        typename ScalarType::FragB& frag_b) {
+  // MASK keeps bits 0-3 of each 16-bit half; EX (0x4300, i.e. 128.0 as bf16)
+  // is OR-ed in, so with a 4-bit value v each half reads as 128 + v.
+  static constexpr uint32_t MASK = 0x000f000f;
+  static constexpr uint32_t EX = 0x43004300;
+
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+
+  int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+  // Shift the next nibble of each half into the masked position.
+  q >>= 4;
+  int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+
+  // As bf16 bit patterns: MUL = 1.0 and ADD = -136 (-(128 + 8)), so each
+  // result equals v - 8.
+  static constexpr uint32_t MUL = 0x3F803F80;
+  static constexpr uint32_t ADD = 0xC308C308;
+
+  frag_b[0] = __hfma2(*reinterpret_cast(&lo),
+                      *reinterpret_cast(&MUL),
+                      *reinterpret_cast(&ADD));
+  frag_b[1] = __hfma2(*reinterpret_cast(&hi),
+                      *reinterpret_cast(&MUL),
+                      *reinterpret_cast(&ADD));
+  return frag_b;
+}
+
+//
+// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
+// bf16 Reference:
+// - FP16:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
+// - BF16:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
+//
+template <>
+__device__ inline typename ScalarType::FragB dequant(
+    int q, typename ScalarType::FragB& frag_b) {
+  // Assuming `start_byte_for_fp16` and one mask each are the (stripped)
+  // template arguments of the two prmt calls: selector 0x5250 yields bytes
+  // {q.b0, 0x64, q.b2, 0x64} (low to high) and 0x5351 yields
+  // {q.b1, 0x64, q.b3, 0x64}, so every 16-bit half is 0x64xx, which reads as
+  // 1024 + x in fp16.
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt(q);
+  uint32_t hi = prmt(q);
+
+  // 0x6480 is 1152 (1024 + 128) as fp16, so the subtraction leaves x - 128.
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
+
+  frag_b[0] = __hsub2(*reinterpret_cast(&lo),
+                      *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast(&hi),
+                      *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
+// 8-bit dequantization through fp32 intermediates; the result words are
+// written through `bf16_result_ptr`.
+template <>
+__device__ inline typename ScalarType::FragB
+dequant(int q,
+        typename ScalarType::FragB& frag_b) {
+  float fp32_intermediates[4];
+  uint32_t* fp32_intermediates_casted =
+      reinterpret_cast(fp32_intermediates);
+
+  // 0x4B000000 is 2^23 (8388608.0) as fp32. Each __byte_perm places one byte x
+  // of `q` in the lowest byte of that pattern, giving the float 8388608 + x.
+  // Selectors 0x7650 / 0x7652 / 0x7651 / 0x7653 pick bytes 0, 2, 1, 3 of `q`.
+  static constexpr uint32_t fp32_base = 0x4B000000;
+  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
+  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
+  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
+  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
+
+  // 8388736 = 2^23 + 128, so each intermediate becomes x - 128.
+  fp32_intermediates[0] -= 8388736.f;
+  fp32_intermediates[1] -= 8388736.f;
+  fp32_intermediates[2] -= 8388736.f;
+  fp32_intermediates[3] -= 8388736.f;
+
+  // Selector 0x7632 keeps the upper 16 bits of each of the two floats and
+  // packs the pair into one 32-bit word (the upper half of an fp32 is its
+  // truncated bf16 form).
+  uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b);
+  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
+                                   fp32_intermediates_casted[1], 0x7632);
+  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
+                                   fp32_intermediates_casted[3], 0x7632);
+
+  return frag_b;
+}
+
+// Multiply dequantized values by the corresponding quantization scale; used
+// only for grouped quantization.
+// scale(): multiplies both registers of `frag_b` by element `i` of `frag_s`.
+// NOTE(review): template argument lists (`<...>`) appear to have been stripped
+// from this copy of the patch (bare `template`, `reinterpret_cast(&frag_s)`,
+// `ScalarType::FragB`, ...); restore them from the upstream file before
+// compiling.
+template
+__device__ inline void scale(typename ScalarType::FragB& frag_b,
+                             typename ScalarType::FragS& frag_s,
+                             int i) {
+  using scalar_t2 = typename ScalarType::scalar_t2;
+  // presumably num2num2 broadcasts one scalar into both lanes of a pair --
+  // TODO confirm against ScalarType.
+  scalar_t2 s =
+      ScalarType::num2num2(reinterpret_cast(&frag_s)[i]);
+  frag_b[0] = __hmul2(frag_b[0], s);
+  frag_b[1] = __hmul2(frag_b[1], s);
+}
+
+// Computes frag_b = frag_b * s - zp for both registers of `frag_b`, using one
+// fused multiply-add each with the negated zero point as the addend.
+template
+__device__ inline void scale_and_sub(
+    typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) {
+  using scalar_t2 = typename ScalarType::scalar_t2;
+  scalar_t2 s2 = ScalarType::num2num2(s);
+  scalar_t2 zp2 = ScalarType::num2num2(zp);
+  frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2));
+  frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2));
+}
+
+// Subtracts element `i` of `frag_zp` from both registers of `frag_b`.
+template
+__device__ inline void sub_zp(typename ScalarType::FragB& frag_b,
+                              typename ScalarType::scalar_t2& frag_zp,
+                              int i) {
+  using scalar_t2 = typename ScalarType::scalar_t2;
+  scalar_t2 zp =
+      ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]);
+  frag_b[0] = __hsub2(frag_b[0], zp);
+  frag_b[1] = __hsub2(frag_b[1], zp);
+}
+
+// Same as scale(), but for act_order (each K is multiplied individually):
+// frag_b[0] is multiplied by the pair (frag_s_1[i], frag_s_2[i]) and
+// frag_b[1] by the pair (frag_s_3[i], frag_s_4[i]).
+template
+__device__ inline void scale4(typename ScalarType::FragB& frag_b,
+                              typename ScalarType::FragS& frag_s_1,
+                              typename ScalarType::FragS& frag_s_2,
+                              typename ScalarType::FragS& frag_s_3,
+                              typename ScalarType::FragS& frag_s_4,
+                              int i) {
+  using scalar_t2 = typename ScalarType::scalar_t2;
+  scalar_t2 s_val_1_2;
+  s_val_1_2.x = reinterpret_cast(&frag_s_1)[i];
+  s_val_1_2.y = reinterpret_cast(&frag_s_2)[i];
+
+  scalar_t2 s_val_3_4;
+  s_val_3_4.x = reinterpret_cast(&frag_s_3)[i];
+  s_val_3_4.y = reinterpret_cast(&frag_s_4)[i];
+
+  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
+  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
+}
+
+// Given 2 floats multiply by 2 scales (halves).
+// c[0] and c[1] are multiplied in place by the first two elements of `s`,
+// each converted with num2float.
+template
+__device__ inline void scale_float(float* c,
+                                   typename ScalarType::FragS& s) {
+  scalar_t* s_ptr = reinterpret_cast(&s);
+  c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0]));
+  c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1]));
+}
+
+// Wait until barrier reaches `count`, then lock for current threadblock.
+// Only thread 0 polls `lock`; the trailing __syncthreads() holds the rest of
+// the block until thread 0 has observed `count`.
+__device__ inline void barrier_acquire(int* lock, int count) {
+  if (threadIdx.x == 0) {
+    int state = -1;
+    do
+      // Guarantee that subsequent writes by this threadblock will be visible
+      // globally.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+                   : "=r"(state)
+                   : "l"(lock));
+    while (state != count);
+  }
+  __syncthreads();
+}
+
+// Release barrier and increment visitation count.
+// With `reset` set, thread 0 instead stores 0 to the lock with a plain write
+// and returns without issuing the fence or the increment.
+__device__ inline void barrier_release(int* lock, bool reset = false) {
+  // All threads of the block reach this point before thread 0 touches the
+  // lock.
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    if (reset) {
+      lock[0] = 0;
+      return;
+    }
+    int val = 1;
+    // Make sure that all writes since acquiring this barrier are visible
+    // globally, while releasing the barrier.
+    asm volatile("fence.acq_rel.gpu;\n");
+    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
+                 :
+                 : "l"(lock), "r"(val));
+  }
+}
+
+// Wait until the value of the lock is negative, and then add 1.
+// Only thread 0 polls and increments; the trailing __syncthreads() holds the
+// rest of the block until it has done so.
+__device__ inline void wait_negative_and_add(int* lock) {
+  if (threadIdx.x == 0) {
+    int state = 0;
+    do
+      // Guarantee that subsequent writes by this threadblock will be visible
+      // globally.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+                   : "=r"(state)
+                   : "l"(lock));
+    while (state >= 0);
+    atomicAdd(lock, 1);
+  }
+  __syncthreads();
+}
+
+template shared
+                             // fetch pipeline
+          const bool has_act_order,  // whether act_order is enabled
+          const bool has_zp,         // whether zero-points are enabled
+          const int group_blocks,    // number of consecutive 16x16 blocks
+                                     // with a separate quantization scale
+          const bool is_zp_float     // is zero point of float16 type?
+ > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce // whether to use fp32 global reduce +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. 
+ using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + extern __shared__ int4 sh[]; + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + const int group_size = + (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; + const int scales_expert_stride = prob_n * prob_k / group_size / 8; + const int zp_expert_stride = + is_zp_float ? prob_n * prob_k / group_size / 8 + : prob_n * prob_k / group_size / (pack_factor * 4); + + // parallel: num valid moe blocks + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int parallel = num_tokens_past_padded / moe_block_size; + int num_valid_blocks = parallel; + if (is_ep) { + for (int i = 0; i < parallel; i++) { + if (expert_ids_ptr[i] == -1) num_valid_blocks--; + } + } + int num_invalid_blocks = parallel - num_valid_blocks; + parallel = num_valid_blocks; + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. 
+ iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int block_id = -1; + int64_t expert_id = 0; // use int64 to avoid computation result overflow + int old_expert_id = 0; + int64_t B_expert_off = 0; + + int4* sh_block_sorted_ids_int4 = sh; + int32_t* sh_block_sorted_ids = + reinterpret_cast(sh_block_sorted_ids_int4); + int4* sh_block_topk_weights_int4 = + sh_block_sorted_ids_int4 + moe_block_size / 4; + scalar_t2* sh_block_topk_weights = + reinterpret_cast(sh_block_topk_weights_int4); + int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4; + + int32_t block_num_valid_tokens = 0; + int32_t locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // read moe block data given block_id + // block_sorted_ids / block_num_valid_tokens / block_topk_weights + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = moe_block_size; + #pragma unroll + for (int i = 0; i < moe_block_size / 4; i++) { + int4 sorted_token_ids_int4 = reinterpret_cast( + sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); + #pragma unroll + 
for (int j = 0; j < 4; j++) { + if (sorted_token_ids[j] >= prob_m * top_k) { + block_num_valid_tokens = i * 4 + j; + break; + } + } + if (block_num_valid_tokens != moe_block_size) break; + } + + __syncthreads(); + int tid4 = threadIdx.x / 4; + if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { + sh_block_sorted_ids_int4[tid4] = reinterpret_cast( + sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; + + if (mul_topk_weights) { + #pragma unroll + for (int i = 0; i < 4; i++) { + sh_block_topk_weights[tid4 * 4 + i] = + Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + } + } + } + __syncthreads(); + }; + + // when move to next moe block, find the next block_id and expert_id + // and then read moe block data + auto update_next_moe_block_data = [&]() { + if (par_id >= parallel) return; + + old_expert_id = expert_id; + if (num_invalid_blocks > 0) { + int skip_count = block_id == -1 ? par_id : 0; + block_id++; + for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) { + expert_id = expert_ids_ptr[i]; + if (expert_id != -1) { + if (skip_count == 0) { + block_id = i; + break; + }; + skip_count--; + }; + } + } else { + block_id = par_id; + expert_id = expert_ids_ptr[block_id]; + } + + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); + scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; + if constexpr (has_zp) { + zp_ptr += (expert_id - old_expert_id) * zp_expert_stride; + } + if constexpr (has_act_order) { + g_idx += (expert_id - old_expert_id) * prob_k; + } + + read_moe_block_data(block_id); + }; + + // Compute all information about the current slice which is required for + // synchronization. 
+ auto init_slice = [&](bool first_init = false) { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = + div_ceil(block_num_valid_tokens, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int col = slice_col * 16 * thread_n_blocks / 8 + + threadIdx.x % threads_per_m; + C[sorted_row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
+ __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + slice_col = 0; + par_id++; + update_next_moe_block_data(); + } + }; + + update_next_moe_block_data(); + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float + ? 16 * thread_n_blocks / 8 + : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. 
+ int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. 
+ int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. 
+ int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh_new; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + int4* sh_red = sh_b; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. 
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
+ int a_remaining_load_count_in_slice = stages; + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 || + a_remaining_load_count_in_slice > 0) { + a_remaining_load_count_in_slice--; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int64_t sorted_row = 0; + if (!m_block_size_8 || row < 8) + sorted_row = sh_block_sorted_ids[row] / top_k; + int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], + row < block_num_valid_tokens); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], + B_ptr[i] + j + B_expert_off); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + 
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. 
+ auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm( + frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } 
+ } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle 
group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 
16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = + ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant(zp_quant_0, frag_zp_0); + dequant(zp_quant_1, frag_zp_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + } + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. 
+ #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant(b_quant_0, frag_b0); + dequant(b_quant_1, frag_b1); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k][2][j], act_frag_s[k2][3][j], 1); + + } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (has_zp && !is_zp_float && group_blocks != -1) { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], + *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y); + } else if constexpr (has_zp && is_zp_float && group_blocks != -1) { + if (is_new_zp) + frag_zpf[k2][j] = __hmul2( + frag_zpf[k2][j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x); + scale_and_sub(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + if 
constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast( + &sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 
2 : 1)) { + float* c_rd = + reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + if (!is_th_active) { + return; + } + + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + if (!first) { + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + } + } + } + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + C[true_idx] = c; + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. 
+  // fp32 variant of the global reduction: partial accumulators are exchanged
+  // through C_tmp in 16-byte elements instead of through the output buffer.
+  //   first: when true, nothing is read from C_tmp (no earlier partials).
+  //   last:  when true, nothing is written back to C_tmp.
+  auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
+    constexpr int tb_m = thread_m_blocks * 16;
+    constexpr int tb_n = thread_n_blocks * 16;
+
+    // Size of one tb_m x tb_n fp32 tile, counted in 16-byte elements.
+    constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
+
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    bool is_th_active = threadIdx.x < active_threads;
+
+    // Number of fp32 accumulators held per thread and the same amount counted
+    // in 16-byte elements.
+    constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
+    constexpr int th_size = num_floats * sizeof(float) / 16;
+
+    // Each lock slot (locks_off) owns its own tile-sized region of C_tmp.
+    int c_cur_offset = locks_off * c_size;
+
+    if (!is_th_active) {
+      return;
+    }
+
+    if (!first) {
+      // Add the partial sums stored in C_tmp to the local accumulators,
+      // staging each 16-byte element through this thread's sh_red slot.
+      float* frag_c_ptr = reinterpret_cast(&frag_c);
+  #pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        if constexpr (m_block_size_8) {
+          // Only even elements are exchanged when m_block_size_8 is set.
+          if (k % 2) continue;
+        } else {
+          // NOTE(review): the left-hand side presumably is the tile row this
+          // element belongs to; rows past the valid tokens are skipped.
+          if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens)
+            continue;
+        }
+
+        sh_red[threadIdx.x] =
+            C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
+
+        float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]);
+  #pragma unroll
+        for (int f = 0; f < 4; f++) {
+          frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
+        }
+      }
+    }
+
+    if (!last) {
+      // Publish the accumulators to C_tmp, using the same element selection
+      // as the read path above.
+      int4* frag_c_ptr = reinterpret_cast(&frag_c);
+  #pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        if constexpr (m_block_size_8) {
+          if (k % 2) continue;
+        } else {
+          if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens)
+            continue;
+        }
+
+        C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
+      }
+    }
+  };
+
+  // Write out the reduced final result in the correct layout. We only actually
+  // reshuffle matrix fragments in this step, the reduction above is performed
+  // in fragment layout.
+  // Stages the accumulators of this threadblock in shared memory (sh_red) and
+  // then copies them row by row into the output buffer C.
+  auto write_result = [&]() {
+    int c_gl_stride = prob_n / 8;
+    // NOTE(review): the "+ 1" presumably pads shared-memory rows to avoid bank
+    // conflicts - confirm.
+    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
+    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
+    constexpr int c_sh_rd_delta =
+        c_sh_stride * (threads / (2 * thread_n_blocks));
+
+    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
+                  (threadIdx.x % (2 * thread_n_blocks));
+    c_gl_wr += (2 * thread_n_blocks) * slice_col;
+    int c_sh_wr;
+    if constexpr (m_block_size_8) {
+      c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) +
+                (threadIdx.x % 32) / 4;
+      c_sh_wr += 64 * (threadIdx.x / 32);
+    } else {
+      c_sh_wr =
+          (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
+      c_sh_wr += 32 * (threadIdx.x / 32);
+    }
+
+    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
+                  (threadIdx.x % (2 * thread_n_blocks));
+
+    // We first reorder in shared memory to guarantee the most efficient final
+    // global write patterns
+    //
+    // write(idx, c0, c1, s): converts the two fp32 values to the output type,
+    // optionally multiplies them by the scale s[0], and stores them in sh_red
+    // at position idx.
+    auto write = [&](int idx, float c0, float c1, FragS& s) {
+      scalar_t2 res =
+          Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
+
+      // For per-column quantization we finally apply the scale here (only for
+      // 4-bit)
+      if constexpr (!has_act_order && group_blocks == -1 &&
+                    w_type.size_bits() == 4 && !has_zp) {
+        res = __hmul2(res, s[0]);
+      }
+
+      if constexpr (m_block_size_8) {
+        // The two halves of the pair go to different shared-memory positions.
+        ((scalar_t*)sh_red)[idx] = res.x;
+        ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
+      } else {
+        ((scalar_t2*)sh_red)[idx] = res;
+      }
+    };
+
+    // Only the first thread_n_blocks / 4 warps stage accumulators.
+    if (threadIdx.x / 32 < thread_n_blocks / 4) {
+  #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+  #pragma unroll
+        for (int j = 0; j < 4; j++) {
+          if constexpr (m_block_size_8) {
+            int wr = c_sh_wr + 16 * j;
+            write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
+                  frag_s[j / 2][2 * (j % 2) + 0]);
+            write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
+                  frag_s[j / 2][2 * (j % 2) + 1]);
+          } else {
+            int wr = c_sh_wr + 8 * j;
+            write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
+                  frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
+            write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
+                  frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
+            write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
+                  frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
+            write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
+                  frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
+          }
+        }
+        c_sh_wr += 16 * (4 * c_sh_stride);
+      }
+    }
+    __syncthreads();
+
+    // Copy staged rows to C. The block-local row is translated to the global
+    // row through sh_block_sorted_ids; rows at or beyond
+    // block_num_valid_tokens are not written.
+  #pragma unroll
+    for (int i = 0;
+         i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
+         i++) {
+      int row = c_gl_wr / c_gl_stride;
+      if (row < block_num_valid_tokens) {
+        int64_t sorted_row = sh_block_sorted_ids[row];
+        int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride;
+        // Only initialized (and only read) when mul_topk_weights is set.
+        scalar_t2 topk_weight_score;
+        if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row];
+        // Parses as (use_atomic_add && slice_count > 1) || mul_topk_weights:
+        // the element-wise path is needed for atomics or top-k weighting.
+        if (use_atomic_add && slice_count > 1 || mul_topk_weights) {
+          scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]);
+          scalar_t2* sh_red_half2 =
+              reinterpret_cast(&sh_red[c_sh_rd]);
+  #pragma unroll
+          for (int a = 0; a < 4; a++) {
+            scalar_t2 res = sh_red_half2[a];
+            if (mul_topk_weights) {
+              res = __hmul2(res, topk_weight_score);
+            }
+
+            if (use_atomic_add && slice_count > 1) {
+              atomicAdd(&C_half2[a], res);
+            } else {
+              C_half2[a] = res;
+            };
+          }
+        } else {
+          // Plain copy of one whole sh_red element.
+          C[true_idx] = sh_red[c_sh_rd];
+        }
+        c_gl_wr += c_gl_wr_delta;
+        c_sh_rd += c_sh_rd_delta;
+      }
+    }
+    __syncthreads();
+  };
+
+  // Start global fetch and register load pipelines.
+  // Issues the first stages - 1 shared-memory fetches, clears the
+  // accumulators, and loads the register fragments for sub-tile 0 of stage 0.
+  auto start_pipes = [&]() {
+
+  #pragma unroll
+    for (int i = 0; i < stages - 1; i++) {
+      if (has_act_order && i == 0) {
+        // Clamp the look-ahead index so that g_idx is not read past prob_k.
+        int last_g_idx = slice_k_start + stages * tb_k * 2;
+        if (last_g_idx >= prob_k) {
+          last_g_idx = prob_k - 1;
+        }
+        fetch_act_order_scales_to_shared(true, g_idx[slice_k_start],
+                                         g_idx[last_g_idx]);
+      }
+
+      if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
+        // Channelwise zero points and scales are fetched once per slice.
+        if (i == 0) {
+          fetch_col_zp_to_shared();
+          fetch_col_scale_to_shared();
+        }
+      }
+      fetch_to_shared(i, i, i < slice_iters);
+    }
+
+    zero_accums();
+    wait_for_stage();
+    init_same_group(0);
+    fetch_to_registers(0, 0);
+    fetch_scales_to_registers(0, 0);
+    fetch_zp_to_registers(0, 0);
+    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
+    slice_k_start_shared_fetch += tb_k * (stages - 1);
+  };
+  if (slice_iters) {
+    start_pipes();
+  }
+
+  // Main loop.
+  while (slice_iters) {
+    // We unroll over both the global fetch and the register load pipeline to
+    // ensure all shared memory accesses are static. Note that both pipelines
+    // have even length meaning that the next iteration will always start at
+    // index 0.
+
+  #pragma unroll
+    for (int pipe = 0; pipe < stages;) {
+  #pragma unroll
+      for (int k = 0; k < b_sh_wr_iters; k++) {
+        // Load the fragments for sub-tile k + 1 before the matmul on k.
+        fetch_to_registers(k + 1, pipe % stages);
+        fetch_scales_to_registers(k + 1, pipe);
+        fetch_zp_to_registers(k + 1, pipe);
+        if (k == b_sh_wr_iters - 2) {
+          // pipe is advanced here, inside the inner loop, which is why the
+          // outer for statement has no increment expression.
+          fetch_to_shared((pipe + stages - 1) % stages, pipe,
+                          slice_iters >= stages);
+          pipe++;
+          wait_for_stage();
+          init_same_group(pipe % stages);
+        }
+        matmul(k);
+      }
+      slice_iters--;
+      if (slice_iters == 0) {
+        break;
+      }
+    }
+    a_remaining_load_count_in_slice = 0;
+
+    a_gl_rd += a_gl_rd_delta_o * stages;
+    slice_k_start += tb_k * stages;
+    slice_k_start_shared_fetch += tb_k * stages;
+
+    if constexpr (has_act_order) {
+      int first_group_id = g_idx[slice_k_start];
+      int last_g_idx = slice_k_start + stages * tb_k * 2;
+      if (last_g_idx >= prob_k) {
+        last_g_idx = prob_k - 1;
+      }
+      int last_group_id = g_idx[last_g_idx];
+      // Refetch scales only when the upcoming groups are not yet covered by
+      // the range [sh_first_group_id, sh_first_group_id + sh_num_groups).
+      if (last_group_id >= sh_first_group_id + sh_num_groups) {
+        fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
+        __syncthreads();
+      }
+    }
+
+    // Process results and, if necessary, proceed to the next column slice.
+    // While this pattern may not be the most readable, other ways of writing
+    // the loop seemed to result in noticeably worse performance after
+    // compilation.
+    if (slice_iters == 0) {
+      cp_async_wait<0>();
+      bool last = slice_idx == slice_count - 1;
+      // For per-column scales, we only fetch them here in the final step before
+      // write-out
+      if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+        if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
+          if (s_sh_wr_pred) {
+            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+          }
+          cp_async_fence();
+        }
+      }
+
+      thread_block_reduce();
+      if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+        if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
+          cp_async_wait<0>();
+          __syncthreads();
+          if (threadIdx.x / 32 < thread_n_blocks / 4) {
+            reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0];
+            reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4];
+            if constexpr (m_block_size_8) {
+              // Replicate the element selected by idx into both halves of
+              // each scale pair.
+              int idx = (threadIdx.x / 4) % 2;
+              scalar_t2* frag_s_half2 = reinterpret_cast(frag_s);
+  #pragma unroll
+              for (int i = 0; i < 8; i++) {
+                frag_s_half2[i] = Dtype::num2num2(
+                    reinterpret_cast(&frag_s_half2[i])[idx]);
+              }
+            }
+          }
+        }
+      }
+
+      // For 8-bit channelwise, we apply the scale before the global reduction
+      // that converts the fp32 results to fp16 (so that we avoid possible
+      // overflow in fp16)
+      if constexpr (!has_act_order && group_blocks == -1 &&
+                    w_type.size_bits() == 8 && !has_zp) {
+        if (threadIdx.x / 32 < thread_n_blocks / 4) {
+  #pragma unroll
+          for (int i = 0; i < thread_m_blocks; i++) {
+  #pragma unroll
+            for (int j = 0; j < 4; j++) {
+              scale_float(
+                  reinterpret_cast(&frag_c[i][j][0][0]),
+                  frag_s[j / 2][2 * (j % 2) + 0]);
+              scale_float(
+                  reinterpret_cast(&frag_c[i][j][0][2]),
+                  frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]);
+
+              if constexpr (!m_block_size_8) {
+                scale_float(
+                    reinterpret_cast(&frag_c[i][j][1][0]),
+                    frag_s[j / 2][2 * (j % 2) + 1]);
+                scale_float(
+                    reinterpret_cast(&frag_c[i][j][1][2]),
+                    frag_s[j / 2][2 * (j % 2) + 1]);
+              }
+            }
+          }
+        }
+      }
+
+      if (slice_count > 1 && !use_atomic_add) {
+        // only globally reduce if there is more than one block in a slice
+        barrier_acquire(&locks[locks_off], slice_idx);
+        if (use_fp32_reduce) {
+          global_reduce_fp32(slice_idx == 0, last);
+        } else {
+          global_reduce_fp16(slice_idx == 0, last);
+        }
+        barrier_release(&locks[locks_off], last);
+      }
+      if (use_atomic_add && slice_count > 1 && slice_idx != 0)
+        wait_negative_and_add(&locks[locks_off]);
+      if (last || use_atomic_add)
+        // only the last block in a slice actually writes the result
+        // (with use_atomic_add every block writes its own contribution)
+        write_result();
+      if (slice_row) a_remaining_load_count_in_slice = stages;
+      slice_row = 0;
+      slice_col_par++;
+      slice_col++;
+      is_first_matmul_in_slice = true;
+      init_slice();
+      if (slice_iters) {
+        // init_slice() assigned more work: reset the A/B read positions for
+        // the next column slice and restart the pipelines.
+        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
+                  (threadIdx.x % a_gl_rd_delta_o);
+  #pragma unroll
+        for (int i = 0; i < b_sh_wr_iters; i++)
+          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
+        if (slice_col == 0) {
+  #pragma unroll
+          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
+        }
+
+        // Update slice k/n for scales loading
+        if constexpr (has_act_order) {
+          slice_k_start = tb_k * slice_row;
+          slice_k_finish = slice_k_start + tb_k * slice_iters;
+          slice_k_start_shared_fetch = slice_k_start;
+          slice_n_offset = act_s_col_tb_stride * slice_col;
+
+        } else {
+          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+          zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
+        }
+
+        start_pipes();
+      }
+    }
+  }
+}
+
+}  // namespace MARLIN_NAMESPACE_NAME
+
+#endif
diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a16e955a325e236a9131135d5935543603bf931c
--- /dev/null
+++ 
b/csrc/moe/marlin_moe_wna16/ops.cu
@@ -0,0 +1,927 @@
+/*
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Adapted from https://github.com/IST-DASLab/marlin
+ */
+
+#ifndef MARLIN_NAMESPACE_NAME
+  #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
+#endif
+
+#include "kernel.h"
+#include "core/registration.h"
+
+#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
+  static_assert(std::is_same::value || \
+                    std::is_same::value, \
+                "only float16 and bfloat16 is supported");
+
+namespace MARLIN_NAMESPACE_NAME {
+
+// Empty placeholder kernel. get_marlin_kernel() returns it when no
+// specialization matches, and callers compare against it to detect that case.
+__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
+
+using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+
+// Stubs used when compiling for architectures below 8.0: the kernel body is
+// empty and the host entry point reports that the operation is unsupported.
+template 
+__global__ void permute_cols_kernel(
+    int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
+    int4* __restrict__ out_int4_ptr,
+    const int32_t* __restrict__ sorted_token_ids_ptr,
+    const int32_t* __restrict__ expert_ids_ptr,
+    const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
+    int size_k, int top_k) {};
+
+}  // namespace MARLIN_NAMESPACE_NAME
+
+torch::Tensor moe_wna16_marlin_gemm(
+    torch::Tensor& a, std::optional const& c_or_none,
+    torch::Tensor& b_q_weight, torch::Tensor& b_scales,
+    std::optional const& b_zeros_or_none,
+    std::optional const& g_idx_or_none,
+    std::optional const& perm_or_none, torch::Tensor& workspace,
+    torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
+    torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
+    int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
+    vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
+    int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
+    bool is_zp_float) {
+  TORCH_CHECK_NOT_IMPLEMENTED(false,
+                              "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
+  return torch::empty({1, 1});
+}
+
+#else
+
+// For a given "a" of size [M,K] performs a permutation of the K columns based
+// on the given "perm" indices.
+//
+// MoE blocks are distributed over CUDA blocks with a stride of gridDim.x.
+// Blocks whose expert id is -1 are skipped. For every other block
+// perm_int_ptr is advanced by size_k per expert-id step, so each expert uses
+// its own size_k-long slice of the permutation. Output row `row` is read from
+// input row `row / top_k`.
+template 
+__global__ void permute_cols_kernel(
+    int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
+    int4* __restrict__ out_int4_ptr,
+    const int32_t* __restrict__ sorted_token_ids_ptr,
+    const int32_t* __restrict__ expert_ids_ptr,
+    const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
+    int size_k, int top_k) {
+  int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
+  int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
+  int32_t block_sorted_ids[moe_block_size];
+  int block_num_valid_tokens = 0;
+  int64_t old_expert_id = 0;
+  int64_t expert_id = 0;
+  // Length of one row of half values, counted in 16-byte (int4) elements.
+  int row_stride = size_k * sizeof(half) / 16;
+
+  // Loads the sorted token ids of one MoE block (four ids per int4 load) and
+  // counts the leading ids that are below size_m * top_k; the first id at or
+  // above that bound ends the valid range.
+  auto read_moe_block_data = [&](int block_id) {
+    block_num_valid_tokens = moe_block_size;
+    int4* tmp_block_sorted_ids = reinterpret_cast(block_sorted_ids);
+    for (int i = 0; i < moe_block_size / 4; i++) {
+      tmp_block_sorted_ids[i] =
+          ((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
+    }
+    for (int i = 0; i < moe_block_size; i++) {
+      if (block_sorted_ids[i] >= size_m * top_k) {
+        block_num_valid_tokens = i;
+        break;
+      };
+    }
+  };
+
+  // Copies one row with its columns permuted: out[row][k] = a[row / top_k]
+  // [perm[k]]. The threads of the block cover the row in steps of
+  // default_threads, followed by one partial step for the remainder.
+  auto permute_row = [&](int row) {
+    int iters = size_k / default_threads;
+    int rest = size_k % default_threads;
+
+    int in_offset = (row / top_k) * row_stride;
+    int out_offset = row * row_stride;
+
+    half const* a_row_half =
+        reinterpret_cast(a_int4_ptr + in_offset);
+    half* out_half = reinterpret_cast(out_int4_ptr + out_offset);
+
+    int base_k = 0;
+
+    for (int i = 0; i < iters; i++) {
+      int cur_k = base_k + threadIdx.x;
+      int src_pos = perm_int_ptr[cur_k];
+
+      out_half[cur_k] = a_row_half[src_pos];
+
+      base_k += default_threads;
+    }
+
+    if (rest) {
+      if (threadIdx.x < rest) {
+        int cur_k = base_k + threadIdx.x;
+        int src_pos = perm_int_ptr[cur_k];
+
+        out_half[cur_k] = a_row_half[src_pos];
+      }
+    }
+  };
+
+  for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
+    old_expert_id = expert_id;
+    int tmp_expert_id = expert_ids_ptr[index];
+    if (tmp_expert_id == -1) continue;
+    expert_id = tmp_expert_id;
+    perm_int_ptr += (expert_id - old_expert_id) * size_k;
+    read_moe_block_data(index);
+
+    for (int i = 0; i < block_num_valid_tokens; i++)
+      permute_row(block_sorted_ids[i]);
+  }
+}
+
+// Threadblock tile shape (thread_k x thread_n) and thread count of a kernel
+// launch candidate. A value of -1 in any field marks an invalid config.
+typedef struct {
+  int thread_k;
+  int thread_n;
+  int num_threads;
+} thread_config_t;
+
+// Candidates used when thread_m_blocks <= 1 (see determine_exec_config).
+thread_config_t small_batch_thread_configs[] = {
+    // Ordered by priority
+
+    // thread_k, thread_n, num_threads
+    {128, 128, 256},
+    {64, 128, 128}};
+
+// Candidates used when thread_m_blocks > 1 (see determine_exec_config).
+thread_config_t large_batch_thread_configs[] = {
+    // Ordered by priority
+
+    // thread_k, thread_n, num_threads
+    {64, 256, 256},
+    {64, 128, 128}};
+
+// A chosen thread config plus the number of blocks to run per SM.
+typedef struct {
+  int blocks_per_sm;
+  thread_config_t tb_cfg;
+} exec_config_t;
+
+// Returns the shared-memory size reserved for scales under the given thread
+// config. NOTE(review): the result is presumably in bytes (factor 2 = one
+// 16-bit scale) - confirm. prob_m, prob_n, prob_k and num_bits are unused.
+int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
+                          int prob_n, int prob_k, int num_bits, int group_size,
+                          bool has_act_order, bool is_k_full) {
+  bool cache_scales_chunk = has_act_order && !is_k_full;
+
+  int tb_n = th_config.thread_n;
+  int tb_k = th_config.thread_k;
+
+  // Get max scale groups per thread-block
+  int tb_groups;
+  if (group_size == -1) {
+    tb_groups = 1;
+  } else if (group_size == 0) {
+    tb_groups = div_ceil(tb_k, 32);  // Worst case is 32 group size
+  } else {
+    tb_groups = div_ceil(tb_k, group_size);
+  }
+
+  if (cache_scales_chunk) {
+    int load_groups =
+        tb_groups * pipe_stages * 2;     // Chunk size is 2x pipeline over dim K
+    load_groups = max(load_groups, 32);  // We load at least 32 scale groups
+    return load_groups * tb_n * 2;
+
+  } else {
+    int tb_scales = tb_groups * tb_n * 2;
+
+    return tb_scales * pipe_stages;
+  }
+}
+
+// Returns the total shared-memory requirement of a kernel configuration: the
+// sum of the A tiles, B tiles, scales, zero points, g_idx and per-block
+// metadata sizes. The result is compared against max_shared_mem in
+// is_valid_config.
+int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
+                          int prob_m, int prob_n, int prob_k, int num_bits,
+                          int group_size, bool has_act_order, bool is_k_full,
+                          int has_zp, int is_zp_float) {
+  int pack_factor = 32 / num_bits;
+
+  // Get B size
+  int tb_k = th_config.thread_k;
+  int tb_n = th_config.thread_n;
+  int tb_m = thread_m_blocks * 16;
+
+  // shm size for block_sorted_ids/block_topk_weights
+  // each of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
+  int sh_block_meta_size = tb_m * 4 * 2;
+  int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
+  int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
+  int sh_s_size =
+      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
+                            group_size, has_act_order, is_k_full);
+  int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
+  // Zero points take as much room as the scales when stored as floats,
+  // otherwise a quarter (4-bit) or half (8-bit) of it.
+  int sh_zp_size = 0;
+  if (has_zp) {
+    if (is_zp_float)
+      sh_zp_size = sh_s_size;
+    else if (num_bits == 4)
+      sh_zp_size = sh_s_size / 4;
+    else if (num_bits == 8)
+      sh_zp_size = sh_s_size / 2;
+  }
+
+  int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
+                   sh_g_idx_size + sh_block_meta_size;
+
+  return total_size;
+}
+
+// Returns true when th_config can be used for the given problem: all fields
+// are set, prob_k / prob_n are divisible by the tile sizes, the tile sizes
+// and thread count meet their minimums, and the shared-memory requirement
+// does not exceed max_shared_mem.
+bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
+                     int prob_m, int prob_n, int prob_k, int num_bits,
+                     int group_size, bool has_act_order, bool is_k_full,
+                     int has_zp, int is_zp_float, int max_shared_mem) {
+  // Sanity
+  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
+      th_config.num_threads == -1) {
+    return false;
+  }
+
+  // Verify K/N are divisible by thread K/N
+  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
+    return false;
+  }
+
+  // Verify min for thread K/N
+  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
+    return false;
+  }
+
+  // num_threads must be at least 128 (= 4 warps)
+  if (th_config.num_threads < 128) {
+    return false;
+  }
+
+  // Check that pipeline fits into cache
+  int cache_size = get_kernel_cache_size(
+      th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
+      has_act_order, is_k_full, has_zp, is_zp_float);
+  return cache_size <= max_shared_mem;
+}
+
+  // Expands to one `else if` branch that assigns the matching kernel to the
+  // local variable `kernel` when every runtime parameter equals the given
+  // compile-time value. It must follow an `if` statement (see
+  // get_marlin_kernel).
+  #define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
+                   M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
+                   NUM_THREADS, IS_ZP_FLOAT) \
+    else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
+             thread_n_blocks == THREAD_N_BLOCKS && \
+             thread_k_blocks == THREAD_K_BLOCKS && \
+             m_block_size_8 == M_BLOCK_SIZE_8 && \
+             has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
+             group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
+             is_zp_float == IS_ZP_FLOAT) { \
+      kernel = Marlin; \
+    }
+
+  // Branches without zero points for thread_m_blocks == 1: act_order
+  // (group_blocks 0) and group_blocks -1/2/4/8, each with m_block_size_8
+  // true and false.
+  #define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
+             false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
+             NUM_THREADS, false) \
+    \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
+             NUM_THREADS, false)
+
+  // Same as GPTQ_GET_IF_M1 for thread_m_blocks 2, 3 and 4 (m_block_size_8 is
+  // always false here).
+  #define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
+             NUM_THREADS, false) \
+    \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
+             NUM_THREADS, false) \
+    \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
+             NUM_THREADS, false) \
+    \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
+             NUM_THREADS, false)
+
+  // Branches with integer zero points (has_zp = true, is_zp_float = false)
+  // for thread_m_blocks == 1; there is no act_order variant.
+  #define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
+             false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
+             false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
+             false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
+             NUM_THREADS, false)
+
+  // Same as AWQ_GET_IF_M1 for thread_m_blocks 2, 3 and 4.
+  #define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
+             NUM_THREADS, false) \
+    \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
+             NUM_THREADS, false) \
+    \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
+             NUM_THREADS, false) \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
+             NUM_THREADS, false)
+
+  // We currently have 4-bit models only with group_blocks == 4
+  // (float zero points: has_zp = true, is_zp_float = true). This macro is not
+  // referenced by get_marlin_kernel below.
+  #define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
+             true) \
+    __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
+             NUM_THREADS, true) \
+    __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
+             NUM_THREADS, true) \
+    __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
+             NUM_THREADS, true) \
+    __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
+             NUM_THREADS, true)
+
+// Maps the runtime parameters to the matching kernel through the *_GET_IF
+// macro chains above. Returns MarlinDefault when no branch matches. The local
+// num_bits is computed but not used in this function.
+template 
+MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
+                                int thread_m_blocks, int thread_n_blocks,
+                                int thread_k_blocks, bool m_block_size_8,
+                                bool has_act_order, bool has_zp,
+                                int group_blocks, int num_threads,
+                                bool is_zp_float) {
+  int num_bits = q_type.size_bits();
+  auto kernel = MarlinDefault;
+  // Empty head of the if / else-if chain that the macros below extend.
+  if (false) {
+  }
+  GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
+  GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
+
+  GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
+  GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
+
+  GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
+  GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
+
+  GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
+  GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
+
+  AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
+  AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
+
+  AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
+  AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
+
+  return kernel;
+}
+
+// Chooses a thread config and blocks-per-SM value for the given problem.
+// Candidates come from large_batch_thread_configs when thread_m_blocks > 1
+// and from small_batch_thread_configs otherwise; configs that are invalid or
+// have no kernel specialization are skipped. For thread_m_blocks > 1 the
+// first usable config is returned with one block per SM. Otherwise the config
+// with the highest allowed block count (clamped to 1..4) wins. If nothing is
+// usable the returned tb_cfg is {-1, -1, -1}.
+template 
+exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
+                                    int prob_n, int prob_k, int thread_m_blocks,
+                                    bool m_block_size_8, int num_bits,
+                                    int group_size, bool has_act_order,
+                                    bool is_k_full, bool has_zp,
+                                    bool is_zp_float, int max_shared_mem) {
+  exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
+  thread_config_t* thread_configs = thread_m_blocks > 1
+                                        ? large_batch_thread_configs
+                                        : small_batch_thread_configs;
+  int thread_configs_size =
+      thread_m_blocks > 1
+          ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
+          : sizeof(small_batch_thread_configs) / sizeof(thread_config_t);
+
+  int count = 0;
+  // NOTE(review): presumably the per-SM register budget in bytes - confirm.
+  constexpr int device_max_reg_size = 255 * 1024;
+  for (int i = 0; i < thread_configs_size; i++) {
+    thread_config_t th_config = thread_configs[i];
+
+    if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
+                         num_bits, group_size, has_act_order, is_k_full, has_zp,
+                         is_zp_float, max_shared_mem)) {
+      continue;
+    }
+
+    int cache_size = get_kernel_cache_size(
+        th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
+        group_size, has_act_order, is_k_full, has_zp, is_zp_float);
+
+    // group_blocks stays 0 for act_order, is -1 for channelwise
+    // quantization, and group_size / 16 otherwise.
+    int group_blocks = 0;
+    if (!has_act_order) {
+      group_blocks = group_size == -1 ? -1 : group_size / 16;
+    }
+
+    auto kernel = get_marlin_kernel(
+        q_type, thread_m_blocks, th_config.thread_n / 16,
+        th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
+        group_blocks, th_config.num_threads, is_zp_float);
+
+    if (kernel == MarlinDefault) continue;
+
+    if (thread_m_blocks > 1) {
+      exec_cfg = {1, th_config};
+      break;
+    } else {
+      // Estimate how many blocks fit per SM from the kernel's register usage
+      // and shared-memory need. The return status of cudaFuncGetAttributes
+      // is not checked here.
+      cudaFuncAttributes attr;
+      cudaFuncGetAttributes(&attr, kernel);
+      int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
+      int allow_count = min(device_max_reg_size / reg_size,
+                            max_shared_mem / (cache_size + 1024));
+      allow_count = max(min(allow_count, 4), 1);
+      if (allow_count > count) {
+        count = allow_count;
+        exec_cfg = {count, th_config};
+      };
+    }
+  }
+
+  return exec_cfg;
+}
+
+template 
+void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
+               void* zp, void* g_idx, void* perm, void* a_tmp,
+               void* sorted_token_ids, void* expert_ids,
+               void* num_tokens_past_padded, void* topk_weights,
+               int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
+               int prob_m, int prob_n, int prob_k, void* workspace,
+               vllm::ScalarType const& q_type, bool has_act_order,
+               bool is_k_full, bool has_zp, int num_groups, int group_size,
+               int dev, cudaStream_t stream, int thread_k, int thread_n,
+               int sms, bool use_atomic_add, bool use_fp32_reduce,
+               bool is_zp_float) {
+  int thread_m_blocks = div_ceil(moe_block_size, 16);
+  bool m_block_size_8 = moe_block_size == 8;
+
+  if (has_zp) {
+    TORCH_CHECK(
+        q_type == vllm::kU4 || q_type == vllm::kU8,
+        "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
+  } else {
+    TORCH_CHECK(
+        q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
+        "q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", + q_type.str()); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* s_ptr = (const int4*)s; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids; + const int32_t* expert_ids_ptr = (const int32_t*)expert_ids; + const int32_t* num_tokens_past_padded_ptr = + (const int32_t*)num_tokens_past_padded; + const float* topk_weights_ptr = (const float*)topk_weights; + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + auto kernel = permute_cols_kernel<8>; + if (moe_block_size == 8) { + } else if (moe_block_size == 16) + kernel = permute_cols_kernel<16>; + else if (moe_block_size == 32) + kernel = permute_cols_kernel<32>; + else if (moe_block_size == 48) + kernel = permute_cols_kernel<48>; + else if (moe_block_size == 64) + kernel = permute_cols_kernel<64>; + else + TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size); + + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, 
expert_ids_ptr, + num_tokens_past_padded_ptr, prob_m, prob_k, top_k); + // clang-format on + A_ptr = a_tmp_ptr; + prob_m = prob_m * top_k; + top_k = 1; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem); + thread_tfg = exec_cfg.tb_cfg; + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) + max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n, + prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", + prob_m, ", 
", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, + ", max_shared_mem = ", max_shared_mem); + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, + has_act_order, has_zp, group_blocks, num_threads, is_zp_float); + + if (kernel == MarlinDefault) { + TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, + ", ", prob_k, "]", ", has_act_order = ", has_act_order, + ", num_groups = ", num_groups, ", group_size = ", group_size, + ", thread_m_blocks = ", thread_m_blocks, + ", thread_n_blocks = ", thread_n_blocks, + ", thread_k_blocks = ", thread_k_blocks, + ", num_bits = ", num_bits); + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem); + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, + sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, + topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, + prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce); + // clang-format on +} + +} // namespace MARLIN_NAMESPACE_NAME + +torch::Tensor moe_wna16_marlin_gemm( + torch::Tensor& a, std::optional const& c_or_none, + torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, + torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, + int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + 
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + if (moe_block_size != 8) { + TORCH_CHECK(moe_block_size % 16 == 0, + "unsupported moe_block_size=", moe_block_size); + TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64, + "unsupported moe_block_size=", moe_block_size); + } + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK( + size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1), + ", size_k = ", size_k, + ", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK( + b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0, + "b_q_weight.size(2) = ", b_q_weight.size(2), + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + int actual_size_n = + (b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be 
left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel + int sms = -1; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c; + if (c_or_none.has_value()) { + c = c_or_none.value(); + TORCH_CHECK(c.device().is_cuda(), "c is not on GPU"); + TORCH_CHECK(c.is_contiguous(), "c is not contiguous"); + TORCH_CHECK(c.size(0) == size_m * top_k, + "Shape mismatch: c.size(0) = ", c.size(0), + ", size_m * topk = ", size_m * top_k); + TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), + ", size_n = ", size_n); + } else { + c = torch::empty({size_m * top_k, size_n}, options); + } + + // Alloc C tmp buffer that is going to be used for the global reduce + torch::Tensor c_tmp; + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + if (use_fp32_reduce && !use_atomic_add) { + // max num of threadblocks is sms * 4 + long max_c_tmp_size = min( + (long)size_n * sorted_token_ids.size(0), + (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n); + if (moe_block_size == 8) max_c_tmp_size *= 2; + c_tmp = torch::empty({max_c_tmp_size}, options_fp32); + } else { + c_tmp = torch::empty({0}, options_fp32); + } + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + + int rank = b_scales.sizes().size(); + TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3"); + TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), + " is not size_n = ", size_n); + num_groups = b_scales.size(1); + + torch::Tensor g_idx, perm, a_tmp; + ; + if (g_idx_or_none.has_value() && perm_or_none.has_value()) { + g_idx = g_idx_or_none.value(); + perm = perm_or_none.value(); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + 
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) || + (g_idx.size(-1) == size_k && perm.size(-1) == size_k), + "Unexpected g_idx.size(-1) = ", g_idx.size(-1), + " and perm.size(-1) = ", perm.size(-1), + ", where size_k = ", size_k); + } else { + g_idx = torch::empty({0}, options); + perm = torch::empty({0}, options); + a_tmp = torch::empty({0}, options); + } + bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0; + + if (has_act_order) { + a_tmp = torch::empty({size_m * top_k, size_k}, options); + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + a_tmp = torch::empty({0}, options); + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(1) = ", b_scales.size(1)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + torch::Tensor b_zeros; + if (b_zeros_or_none.has_value()) { + b_zeros = b_zeros_or_none.value(); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + } else { + b_zeros = torch::empty({0}, options); + } + bool has_zp = b_zeros.size(-1) > 0; + + if (has_zp) { + TORCH_CHECK( + b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + } else { + TORCH_CHECK( + b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half, + "Computation type must be float16 (half) when using float zero " + "points."); + } + + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); + if (is_zp_float) { + TORCH_CHECK(b_zeros.size(2) == size_n, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n = ", size_n); + TORCH_CHECK(num_groups == b_zeros.size(1), + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + TORCH_CHECK(num_groups != -1, "num_groups must be != -1"); + } else { + TORCH_CHECK(b_zeros.size(1) == num_groups, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + } + + // Verify workspace size + TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0, + "size_n = ", size_n, ", is not divisible by min_thread_n = ", + MARLIN_NAMESPACE_NAME::min_thread_n); + + int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n; + int min_workspace_size = min( + max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4); + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + c_tmp.data_ptr(), b_scales.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), + topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, + size_m, size_n, size_k, workspace.data_ptr(), b_q_type, 
has_act_order, + is_k_full, has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), c_tmp.data_ptr(), + b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), expert_ids.data_ptr(), + num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), + moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, + workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + } else { + TORCH_CHECK(false, + "moe_wna16_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); +} diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu index 51ae76c1ec882aaac25b2fafabad5abb0a007460..7b6a111c00adcefd2fbd35be48c1023b4f15c193 100644 --- a/csrc/moe/moe_wna16.cu +++ b/csrc/moe/moe_wna16.cu @@ -13,7 +13,6 @@ template __global__ void moe_wna16_gemm_kernel( const scalar_t* __restrict__ input, scalar_t* __restrict__ output, - const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales, const uint32_t* __restrict__ qzeros, @@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel( if (token_index / top_k >= size_m) break; num_valid_tokens = m + 1; - if (blockIdx.z == 0 && offset_n < size_n) - output[token_index * size_n + offset_n] = Dtype::int2num(0); if (expert_id != -1) { int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N); @@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, 
int64_t BLOCK_SIZE_K, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto options = - torch::TensorOptions().dtype(input.dtype()).device(input.device()); + output.zero_(); const int num_experts = b_qweight.size(0); const int size_m = input.size(0); @@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, const uint32_t* b_qzeros_ptr; if (b_qzeros.has_value()) b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr(); - const float* topk_weights_ptr; + const float* topk_weights_ptr = nullptr; if (topk_weights.has_value()) - topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); + topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 718418e6cd49750eed7f9bdc9cf26a571a8f3f9d..d0de42251f97adf6624c39d1feeea041a1e09e58 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm); m.def( - "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " - "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " - "int b_q_type, SymInt size_m, " - "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " - "topk, " - "int moe_block_size, bool replicate_input, bool apply_weights)" - " -> Tensor"); + "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," + "Tensor sorted_token_ids," + "Tensor! expert_ids, Tensor! num_tokens_past_padded," + "Tensor! 
topk_weights, int moe_block_size, int top_k, " + "bool mul_topk_weights, bool is_ep, int b_q_type_id," + "int size_m, int size_n, int size_k," + "bool is_full_k, bool use_atomic_add," + "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + // conditionally compiled so impl registration is in source file #endif diff --git a/csrc/ops.h b/csrc/ops.h index 0e84926635a12c0d60fe8aa8ee8ff2af4f06bf6c..9d8a3195de42f08b9e4f28215a286f2d6d4c1a49 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -269,6 +269,12 @@ void advance_step_flashinfer( torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); +// void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, +// torch::Tensor const& q_pe, +// torch::Tensor const& kv_c_and_k_pe_cache, +// torch::Tensor const& seq_lens, +// torch::Tensor const& page_table, double scale); + torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); #ifndef USE_ROCM diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 2fb0417ce6c4163bbb563a30c8213c221ce1f6a0..894727383a639b0d2c0858c4a884092aec067811 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -46,14 +46,26 @@ __global__ void compute_expert_offsets( } __global__ void compute_arg_sorts(const int* __restrict__ topk_ids, + const int32_t* __restrict__ expert_offsets, int32_t* input_permutation, int32_t* output_permutation, int32_t* atomic_buffer, const int topk_length, const int topk) { - int expert_id = blockIdx.x; + int const blk_expert_id = blockIdx.x; + int const num_experts = gridDim.x; + int32_t const num_tokens = expert_offsets[num_experts]; for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { - if (topk_ids[i] == expert_id) { + int const expert_id = topk_ids[i]; + if (expert_id == -1 && blockIdx.x == 0) { + // output_permutation is 
used to re-order the moe outputs. It is + // used as c2 = c2[c_map], where c2 is a torch.tensor that is the + // output of the cutlass kernels and c_map is the output_permutation. + // c2 is initialized to zeros, therefore by setting the output_permutation + // to num_tokens, we are guaranteed to fill the moe outputs to zero + // for "invalid" topk_ids. + output_permutation[i] = num_tokens; + } else if (expert_id == blk_expert_id) { int start = atomicAdd(&atomic_buffer[expert_id], 1); input_permutation[start] = i / topk; output_permutation[i] = start; @@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller( static_cast(atomic_buffer.data_ptr()), num_experts); compute_arg_sorts<<>>( static_cast(topk_ids.data_ptr()), + static_cast(expert_offsets.data_ptr()), static_cast(input_permutation.data_ptr()), static_cast(output_permutation.data_ptr()), static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh index 4e82c99c3af31d6f1fabe8bd7df85b854fec3a78..6082937e7e1f9a0d2f0a25668f285d7386410174 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh @@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out, uint32_t const m = a.size(0); uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 + std::max(static_cast(16), next_pow_2(m)); // next power of 2 if (mp2 <= 16) { // M in [1, 16] diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh index 95723b31ca3ce7d28371363280f9408f527f7976..87be125b2eb3ce804ed978d13f88333c52ab0896 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh @@ -321,7 
+321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out, uint32_t const m = a.size(0); uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 + std::max(static_cast(16), next_pow_2(m)); // next power of 2 if (mp2 <= 16) { // M in [1, 16] diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu index 6e14de0c7805ceacb0615ae73a638814f4fb4fa9..97c0e0da7b1fbdf83ef3b208a4af3a3399f2c5f6 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options( using StrideB = typename T::StrideB; using StrideD = typename T::StrideD; using Sm100BlkScaledConfig = - typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; int m = static_cast(M); int n = static_cast(N); diff --git a/csrc/quantization/gptq_marlin/marlin.cuh b/csrc/quantization/gptq_marlin/marlin.cuh index 74ccbac57bd3c84cc04fee02a62e67ab739caf28..f3b44641e77eea758397d4e3bf270b8a2eb2850a 100644 --- a/csrc/quantization/gptq_marlin/marlin.cuh +++ b/csrc/quantization/gptq_marlin/marlin.cuh @@ -9,7 +9,11 @@ #include #include -namespace marlin { +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif + +namespace MARLIN_NAMESPACE_NAME { // Marlin params @@ -23,6 +27,7 @@ static constexpr int pipe_stages = static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; static constexpr int tile_size = 16; static constexpr int max_par = 16; @@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() { #endif -} // namespace marlin +} // namespace MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh index 
be06c09bee3314bd87c0d58eeb3e205e2a2eac1e..cc16054814342c29c9d3863f1c413584aa0e67b6 100644 --- a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh @@ -5,7 +5,11 @@ #include #include -namespace marlin { +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif + +namespace MARLIN_NAMESPACE_NAME { template class ScalarType {}; @@ -54,7 +58,7 @@ class ScalarType { using FragS = Vec; using FragZP = Vec; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } @@ -74,6 +78,6 @@ class ScalarType { #endif }; -} // namespace marlin +} // namespace MARLIN_NAMESPACE_NAME #endif diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index afb735450e0cb2c25609a8420f1ea128ac84ae9a..b90cfdc617afdbef58879ae1ad5ad2ee008878ed 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -2,6 +2,15 @@ #include +torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, + const int64_t rows_per_block); + +torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, + const int64_t CuCount); + +void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); + void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu new file mode 100644 index 0000000000000000000000000000000000000000..72d2820f2aabfdbc8c9eb3c4aa9a014074164e33 --- /dev/null +++ b/csrc/rocm/skinny_gemms.cu @@ -0,0 +1,1600 @@ +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "quantization/fp8/common.cuh" + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) + #define 
__HIP__MI300_MI250__ +#endif + +#if defined(__HIPCC__) && defined(__gfx942__) + #define __HIP__MI300__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +template +struct scalar {}; + +template +struct scalar2 {}; + +template +__device__ __forceinline__ float2 __s22float2(T v); + +template +__device__ __forceinline__ T __float2s(float v); + +template +__device__ __forceinline__ T __float22s2_rn(float2 v); + +// Definitions and cvt functions for fp16 +template <> +struct scalar { + using type = half; +}; + +template <> +struct scalar2 { + using type = __half2; +}; + +template <> +__device__ __forceinline__ half __float2s(float v) { + return __float2half(v); +} + +template <> +__device__ __forceinline__ float2 __s22float2(__half2 v) { + return __half22float2(v); +} + +template <> +__device__ __forceinline__ __half2 __float22s2_rn(float2 v) { + return __float22half2_rn(v); +} + +// Definitions and cvt functions for bf16 +template <> +struct scalar { + using type = __hip_bfloat16; +}; + +template <> +struct scalar2 { + using type = __hip_bfloat162; +}; + +template <> +__device__ __forceinline__ __hip_bfloat16 __float2s(float v) { + return __float2bfloat16(v); +} + +template <> +__device__ __forceinline__ float2 __s22float2(__hip_bfloat162 v) { + return __bfloat1622float2(v); +} + +template <> +__device__ __forceinline__ __hip_bfloat162 __float22s2_rn(float2 v) { + return __float22bfloat162_rn(v); +} + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + return make_float4(dat0, dat1, dat2, dat3); +} + +// TBlock fetches entire rows of 
A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, + scalar_t* out_c, const int K) { + using scalar2_t = typename scalar2::type; + auto af4 = reinterpret_cast(in_a); + auto bf4 = reinterpret_cast(in_b); + auto c = reinterpret_cast(out_c); + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / num_warps; + const int qthreadid = threadid % num_warps; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float acc[NUM_A_ROWS_PER_BLOCK]; + scalar2_t acch2; + scalar2_t oval; + + // As we later use warp shuffle operations, we may have more threads in the + // block than the actual available data, hence the if guard here. + if (threadid * 8 < K) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // rowA_elem4[i] holds 8 * half numbers seen as a single float4. + rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); + } + } + + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + + scalar2_t Af2; + [[maybe_unused]] scalar2_t Bf2; + float2 S; + + auto Ah2ptr = reinterpret_cast(&rowA_elem4); + scalar2_t* ah2lptr; + +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // Multiply-add on 8 scalar_t. 
+ ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __s22float2(acch2); + + // See comment above concerning the if guard. + acc[i] = (threadid * 8 < K ? S.x + S.y : 0.f); + } + +// all reduce across warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + if (lane < NUM_A_ROWS_PER_BLOCK) { + red_smem[lane][warp] = acc[lane]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + if (qwarpid < NUM_A_ROWS_PER_BLOCK) { + acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f; + for (int mask = num_warps / 2; mask >= 1; mask /= 2) { + acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); + } + float oval2 = __shfl_xor(acc[qwarpid], num_warps); + + if (lane % (num_warps * 2) == 0) { + oval = __float22s2_rn(make_float2(acc[qwarpid], oval2)); + c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; + } + } +} + +torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + auto N = in_b.size(0); + + TORCH_CHECK(N == 1, "Row number of activation tensor must be 1."); + TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(in_b.dtype() == torch::kFloat16 || + in_b.dtype() == torch::kBFloat16); + + auto out_c = torch::empty( + {N, M}, torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + + // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle + // operations. + const int NUM_THREADS = + K * 2 / 16 % WARP_SIZE == 0 + ? 
K * 2 / 16 + : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE); + + int NUM_BLOCKS = M / rows_per_block; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_b)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // call the kernel function... + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "LLGemm1", [&] { + auto a_ptr = in_a.data_ptr(); + auto b_ptr = in_b.data_ptr(); + auto c_ptr = out_c.data_ptr(); + if (rows_per_block == 2) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 4) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 8) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 16) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } + }); + + return out_c; +} + +#define DOT2C(V0, V2, V3) \ + if constexpr (std::is_same_v) { \ + asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \ + } else if constexpr (std::is_same_v) { \ + float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \ + __bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \ + V0 += (s.x + s.y); \ + } + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] fits LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + scalar8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the 
work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ scalar_t s[1024 * 32]; + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * N, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * N, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + float sum[N][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (m < M) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. 
+ // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = 0; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
+ #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const scalar_t* B_ = &B[(m + 0) * K + k_]; + bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if constexpr (YTILE >= 2) + bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if constexpr (YTILE >= 3) + bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if constexpr (YTILE >= 4) + bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if constexpr (YTILE >= 5) + bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if constexpr (YTILE >= 6) + bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if constexpr (YTILE >= 7) + bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if constexpr (YTILE >= 8) + bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int n = 0; n < N; n++) { + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
+ #pragma unroll + for (uint32_t n = 0; n < N; n++) { + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]) + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if constexpr (YTILE >= 2) { + DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); + } + if constexpr (YTILE >= 3) { + DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); + } + if constexpr (YTILE >= 4) { + DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); + } + if constexpr (YTILE >= 5) { + DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); + } + if constexpr (YTILE >= 6) { + DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); + } + if constexpr (YTILE >= 7) { + DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); + } + if constexpr (YTILE >= 8) { + DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); + } + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), 
"v"(sum[n][y]), "v"(sum[n][y])); + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] marginally exceeds LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitK_hf_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + scalar8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ scalar_t s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! 
+ //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * N, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * N, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[N][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. 
Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (m < M) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = 0; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
+ #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const scalar_t* B_ = &B[(m + 0) * K + k_]; + bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if constexpr (YTILE >= 2) + bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if constexpr (YTILE >= 3) + bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if constexpr (YTILE >= 4) + bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if constexpr (YTILE >= 5) + bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if constexpr (YTILE >= 6) + bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if constexpr (YTILE >= 7) + bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if constexpr (YTILE >= 8) + bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int n = 0; n < N; n++) { + if (k_ + K * n < 32 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t n = 0; n < N; n++) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
+ #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if constexpr (YTILE >= 2) { + DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); + } + if constexpr (YTILE >= 3) { + DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); + } + if constexpr (YTILE >= 4) { + DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); + } + if constexpr (YTILE >= 5) { + DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); + } + if constexpr (YTILE >= 6) { + DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); + } + if constexpr (YTILE >= 7) { + DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); + } + if constexpr (YTILE >= 8) { + DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); + } + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } + } + + if (threadIdx.x == 
63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + } +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets big A[] cases, where it is much larger than LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + scalar8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ scalar_t s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! 
+ //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + #define PCML + #ifndef PCML + for (uint32_t k = 0; k < min(K * N, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * N, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + #endif + + #define TUC (THRDS * UNRL * A_CHUNK) + uint32_t kBase = 0; + // find biggest k size that fits in LDS + uint32_t kFit = (32 * 1024) / N; + // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple + // of TUC + kFit = (kFit % TUC == 0) + ? 
kFit + : (kFit - kFit % TUC); // round up to multiple of TUC + // if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + float sum[N][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + #ifdef PCML + int YW = (YTILE * _WvPrGrp); + uint32_t Mrndp = (M % YW == 0) ? M : (M - M % YW + YW); + while (m < Mrndp) { + #else + while (m < M) { + #endif + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = 0; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! 
+ // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #ifdef PCML + if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS + if (k1 != 0) kBase += kFit; + __syncthreads(); + for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { + uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (kBase + kOff >= K) break; + if (kOff >= kFit) break; + for (uint32_t n = 0; n < N; n++) { + uint32_t k_in = kBase + n * K + kOff; + uint32_t k_ot = n * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + if (m >= M) continue; + #endif + + // Fetch the weight matrix from memory! 
+ #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const scalar_t* B_ = &B[(m + 0) * K + k_]; + bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if constexpr (YTILE >= 2) + bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if constexpr (YTILE >= 3) + bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if constexpr (YTILE >= 4) + bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if constexpr (YTILE >= 5) + bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if constexpr (YTILE >= 6) + bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if constexpr (YTILE >= 7) + bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if constexpr (YTILE >= 8) + bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int n = 0; n < N; n++) { + #ifdef PCML + bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n]))); + #else + if (k_ + K * n < 32 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + #endif + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + #pragma unroll + for (uint32_t n = 0; n < N; n++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is 
happening for K-split of 64! + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if constexpr (YTILE >= 2) { + DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); + } + if constexpr (YTILE >= 3) { + DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); + } + if constexpr (YTILE >= 4) { + DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); + } + if constexpr (YTILE >= 5) { + DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); + } + if constexpr (YTILE >= 6) { + DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); + } + if constexpr (YTILE >= 7) { + DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); + } + if constexpr (YTILE >= 8) { + DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); + } + } + } + } + } + + #ifdef PCML + if (m >= M) { + m += CuCount * _WvPrGrp * YTILE; + kBase = 0; + continue; + } + #endif + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, 
%3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } + } + + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + kBase = 0; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + +torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, + const int64_t CuCount) { + auto M_in = in_a.size(0); + auto K_in = in_a.size(1); + auto N_in = in_b.size(0); + + TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0"); + TORCH_CHECK(in_a.dtype() == torch::kFloat16 || + in_a.dtype() == 
torch::kBFloat16); + + auto out_c = torch::empty( + {N_in, M_in}, + torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + + dim3 grid(CuCount); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + +#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSplitK_hf_sml_ \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSplitK_hf_ \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ + wvSplitK_hf_big_ \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } \ + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { + using fptype = typename scalar::type; + fptype* af4 = reinterpret_cast(in_a.data_ptr()); + const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); + fptype* c = reinterpret_cast(out_c.data_ptr()); + switch (N_in) { + case 1: + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + }); + return out_c; +} + +#if defined(__HIP__MI300__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B, + const fp8_t* __restrict__ A, scalar_t* C, + const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + 
const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK]; + char2 c2[A_CHUNK / 2]; + scalar_t h[A_CHUNK / 2]; + float f[A_CHUNK / 4]; + int i[A_CHUNK / 4]; + long l[A_CHUNK / 8]; + intx4 l2[A_CHUNK / 16]; + scalar8 h8; + }; + + __shared__ fp8_t s[1024 * 64]; + + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[N][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (m < M) { + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = {0.f}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + #pragma unroll + for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f}; + #pragma unroll + for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f}; + } + + // Fetch the weight matrix from memory! 
+ #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const fp8_t* B_ = &B[(m + 0) * Kp + k_]; + #pragma unroll + for (uint32_t y = 0; y < YTILE; ++y) { + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp]))); + } + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + for (int n = 0; n < N; n++) { + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + if (k >= K) break; + + for (uint32_t n = 0; n < N; n++) { + for (int i = 0; i < A_CHUNK; i += 8) { + for (int y = 0; y < YTILE; ++y) { + sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, + 0); + } + } + } + } + } + + // Final reduction + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[n][y][0]; + float accm16 = sum[n][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); + asm("v_add_f32 
%0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[n][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300__) TODO: Add NAVI support +template +__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, + const fp8_t* B, const fp8_t* __restrict__ A, + scalar_t* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + +#if defined(__HIP__MI300__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B, + const fp8_t* __restrict__ A, scalar_t* C, + const float* __restrict__ s_A, const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + 
using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK]; + char2 c2[A_CHUNK / 2]; + scalar_t h[A_CHUNK / 2]; + float f[A_CHUNK / 4]; + int i[A_CHUNK / 4]; + long l[A_CHUNK / 8]; + intx4 l2[A_CHUNK / 16]; + scalar8 h8; + }; + + __shared__ fp8_t s[1024 * 64]; + + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[N][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (m < M) { + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = {0}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const fp8_t* B_ = &B[(m + 0) * Kp + k_]; + for (int y = 0; y < YTILE; ++y) { + if (y + m >= M) break; // To avoid mem access fault. 
+ bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp]))); + } + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + for (int n = 0; n < N; n++) { + if (k_ + K * n < 64 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + for (uint32_t n = 0; n < N; n++) { + for (int i = 0; i < A_CHUNK; i += 8) { + for (int y = 0; y < YTILE; ++y) { + sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, + 0); + } + } + } + } + } + + // Final reduction + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[n][y][0]; + float accm16 = sum[n][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 
row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[n][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + if (y + m >= M) break; // To avoid mem access fault. + C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300__) TODO: Add NAVI support +template +__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, + const fp8_t* B, const fp8_t* __restrict__ A, + scalar_t* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + +void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, + const int64_t CuCount) { + static c10::ScalarType kFp8Type = is_fp8_ocp() + ? 
c10::ScalarType::Float8_e4m3fn + : c10::ScalarType::Float8_e4m3fnuz; + auto M_in = in_a.size(0); + auto K_in = in_a.size(1); + auto N_in = in_b.size(0); + auto Kp_in = in_a.stride(0); + TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0"); + TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type); + TORCH_CHECK(out_c.dtype() == torch::kFloat16 || + out_c.dtype() == torch::kBFloat16); + + dim3 grid(CuCount); + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + +#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSplitKQ_hf_sml_ \ + <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ + s_a, s_b, __wvPrGrp, CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSplitKQ_hf_ \ + <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ + s_a, s_b, __wvPrGrp, CuCount); \ + } \ + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] { + using fptype = typename scalar::type; + auto c_ptr = reinterpret_cast(out_c.data_ptr()); + auto s_a = scale_a.data_ptr(); + auto s_b = scale_b.data_ptr(); + VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] { + auto a_ptr = in_a.data_ptr(); + auto b_ptr = in_b.data_ptr(); + switch (N_in) { + case 1: + WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + }); + }); +} diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 
537e9357d52be7a5513b94e05815db8dbf6a81be..4ac6fd1e994081a5a88e2b6ac28bb1b45f07fdb6 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -14,6 +14,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // vLLM custom ops for rocm + // Custom gemm op for matrix-vector multiplication + rocm_ops.def( + "LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> " + "Tensor"); + rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1); + + // Custom gemm op for skinny matrix-matrix multiplication + rocm_ops.def( + "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> " + "Tensor"); + rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); + + // wvSplitK for fp8 + rocm_ops.def( + "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, " + " Tensor scale_b, int CuCount) -> ()"); + rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ); + // Custom attention op // Compute the attention between an input query and the cached // keys/values using PagedAttention. diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 940b953590de4fb1236f5bcdaefcfbe6fa2a01bf..7cc25d6813c803b4b8638ddabdefd82ec9a6b38c 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -294,6 +294,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ") -> ()"); ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); + // Compute MLA decode using cutlass. +// ops.def( +// "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," +// " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," +// " Tensor page_table, float scale) -> ()"); +// ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); + // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. 
ops.def( diff --git a/docker/Dockerfile b/docker/Dockerfile index d1ecef586d50bc70ccbaa34a1cd2344f07c587a9..1b28845d0ac04ac0b277360c58f82a0e6b1cf347 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -162,6 +162,9 @@ ENV UV_HTTP_TIMEOUT=500 COPY requirements/lint.txt requirements/lint.txt COPY requirements/test.txt requirements/test.txt COPY requirements/dev.txt requirements/dev.txt +# Workaround for #17068 +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system mamba-ssm==2.2.4 --no-build-isolation RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/dev.txt #################### DEV IMAGE #################### @@ -240,6 +243,8 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \ fi COPY examples examples +COPY benchmarks benchmarks +COPY ./vllm/collect_env.py . # Although we build Flashinfer with AOT mode, there's still # some issues w.r.t. JIT compilation. Therefore we need to @@ -263,6 +268,9 @@ ADD . 
/vllm-workspace/ ENV UV_HTTP_TIMEOUT=500 # install development dependencies (for testing) +# Workaround for #17068 +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system mamba-ssm==2.2.4 --no-build-isolation RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/dev.txt @@ -289,6 +297,7 @@ RUN mv vllm test_docs/ #################### OPENAI API SERVER #################### # base openai image with additional requirements, for any subsequent openai-style images FROM vllm-base AS vllm-openai-base +ARG TARGETPLATFORM # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 54d1ce86d0112d33c87495123804a321b1b295bc..c647d9036f40015a12d57d93e834dab356477e84 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -121,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ADD ./tests/ ./tests/ ADD ./examples/ ./examples/ ADD ./benchmarks/ ./benchmarks/ +ADD ./vllm/collect_env.py . 
# install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch new file mode 100644 index 0000000000000000000000000000000000000000..0063712e47818fa08cca5ae350024140961d9fd5 --- /dev/null +++ b/docker/Dockerfile.nightly_torch @@ -0,0 +1,307 @@ +# The vLLM Dockerfile is used to construct vLLM image against torch nightly that can be directly used for testing + +# for torch nightly, cuda >=12.6 is required, +# use 12.8 due to FlashAttention issue with cuda 12.6 (https://github.com/vllm-project/vllm/issues/15435#issuecomment-2775924628) +ARG CUDA_VERSION=12.8.0 +# +#################### BASE BUILD IMAGE #################### +# prepare basic build environment +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base +ARG CUDA_VERSION=12.8.0 +ARG PYTHON_VERSION=3.12 +ARG TARGETPLATFORM +ENV DEBIAN_FRONTEND=noninteractive +# Install Python and other dependencies +RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ + && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ + && apt-get update -y \ + && apt-get install -y ccache software-properties-common git curl sudo \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version \ + && python3 -m pip --version +# Install uv for faster pip installs +RUN --mount=type=cache,target=/root/.cache/uv \ + python3 -m pip install uv + +# This timeout (in seconds) is necessary when installing some dependencies via uv 
since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 +# as it was causing spam when compiling the CUTLASS kernels +RUN apt-get install -y gcc-10 g++-10 +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10 +RUN < torch_build_versions.txt +RUN cat torch_build_versions.txt + +# cuda arch list used by torch +# can be useful for `test` +# explicitly set the list to avoid issues with torch 2.2 +# see https://github.com/pytorch/pytorch/pull/123243 + +# Override the arch list for flash-attn to reduce the binary size +ARG vllm_fa_cmake_gpu_arches='80-real;90-real' +ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches} +#################### BASE BUILD IMAGE #################### + +#################### WHEEL BUILD IMAGE #################### +FROM base AS build +ARG TARGETPLATFORM + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +COPY . . + +RUN python3 use_existing_torch.py + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r requirements/build.txt + +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi + +# Max jobs used by Ninja to build extensions +ARG max_jobs=16 +ENV MAX_JOBS=${max_jobs} +ARG nvcc_threads=2 +ENV NVCC_THREADS=$nvcc_threads + +ARG USE_SCCACHE +ARG SCCACHE_BUCKET_NAME=vllm-build-sccache +ARG SCCACHE_REGION_NAME=us-west-2 +ARG SCCACHE_S3_NO_CREDENTIALS=0 + +# if USE_SCCACHE is set, use sccache to speed up compilation +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=.git,target=.git \ + if [ "$USE_SCCACHE" = "1" ]; then \ + echo "Installing sccache..." 
\ + && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \ + && tar -xzf sccache.tar.gz \ + && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \ + && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ + && export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \ + && export SCCACHE_REGION=${SCCACHE_REGION_NAME} \ + && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \ + && export SCCACHE_IDLE_TIMEOUT=0 \ + && export CMAKE_BUILD_TYPE=Release \ + && sccache --show-stats \ + && python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \ + && sccache --show-stats; \ + fi + +ENV CCACHE_DIR=/root/.cache/ccache +RUN --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=.git,target=.git \ + if [ "$USE_SCCACHE" != "1" ]; then \ + # Clean any existing CMake artifacts + rm -rf .deps && \ + mkdir -p .deps && \ + python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ + fi + +#################### WHEEL BUILD IMAGE #################### + +################### VLLM INSTALLED IMAGE #################### +# Setup clean environment for vLLM and its dependencies for test and api server using ubuntu22.04 with AOT flashinfer +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base +# prepare for environment starts +ARG CUDA_VERSION=12.8.0 +ARG PYTHON_VERSION=3.12 +WORKDIR /vllm-workspace +ENV DEBIAN_FRONTEND=noninteractive +ARG TARGETPLATFORM + +RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ + echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment + +# Install Python and other dependencies +RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ + && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ + && apt-get update -y \ + && apt-get install -y ccache 
software-properties-common git curl wget sudo vim python3-pip \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version + +RUN --mount=type=cache,target=/root/.cache/uv \ + python3 -m pip install uv + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +# Workaround for https://github.com/openai/triton/issues/2507 and +# https://github.com/pytorch/pytorch/issues/107960 -- hopefully +# this won't be needed for future versions of this docker image +# or future versions of triton. +RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. 
-f1,2)/compat/ + +# get the nightly torch version used in the build to make sure the version is the same +COPY --from=base /workspace/torch_build_versions.txt ./torch_build_versions.txt + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu128 + +# install the vllm wheel +RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm-dist \ + --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system vllm-dist/*.whl --verbose + +# install xformers again for the new environment +RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \ + --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose + +ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0' + +# install package for build flashinfer +# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738 +RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.post1 + + +# build flashinfer for torch nightly from source around 10 mins +# release version: v0.2.2.post1 +# todo(elainewy): cache flashinfer build result for faster build +ENV CCACHE_DIR=/root/.cache/ccache +RUN --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=cache,target=/root/.cache/uv \ + echo "git clone flashinfer..." \ + && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \ + && cd flashinfer \ + && git checkout v0.2.2.post1 \ + && git submodule update --init --recursive \ + && echo "finish git clone flashinfer..." \ + && rm -rf build \ + && export TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} \ + && FLASHINFER_ENABLE_AOT=1 python3 setup.py bdist_wheel --dist-dir=../flashinfer-dist --verbose \ + && cd .. 
\ + && rm -rf flashinfer + +# install flashinfer +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system flashinfer-dist/*.whl --verbose + +# install common packages +COPY requirements/common.txt requirements/common.txt +COPY use_existing_torch.py use_existing_torch.py +COPY pyproject.toml pyproject.toml + +COPY examples examples +COPY benchmarks benchmarks +COPY ./vllm/collect_env.py . + +RUN python3 use_existing_torch.py +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r requirements/common.txt + +################### VLLM INSTALLED IMAGE #################### + + +#################### UNITTEST IMAGE ############################# +FROM vllm-base as test +COPY tests/ tests/ + +# install build and runtime dependencies without stable torch version +COPY requirements/nightly_torch_test.txt requirements/nightly_torch_test.txt + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +# install development dependencies (for testing) +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -e tests/vllm_test_utils + +# enable fast downloads from hf (for testing) +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system hf_transfer +ENV HF_HUB_ENABLE_HF_TRANSFER 1 + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r requirements/nightly_torch_test.txt + +#################### UNITTEST IMAGE ############################# + diff --git a/docker/Dockerfile.ppc64le b/docker/Dockerfile.ppc64le index 4540af4e8cdc873b0e3da7e7cf455ca916132929..ec979227871c63a1c0bb7a345509e87caebddba9 100644 --- a/docker/Dockerfile.ppc64le +++ b/docker/Dockerfile.ppc64le @@ -126,13 +126,16 @@ RUN --mount=type=cache,target=/root/.cache/uv \ FROM base-builder AS cv-builder ARG MAX_JOBS -ARG OPENCV_VERSION=84 +ARG OPENCV_VERSION=86 +# patch for version 
4.11.0.86 +ARG OPENCV_PATCH=97f3f39 ARG ENABLE_HEADLESS=1 RUN --mount=type=cache,target=/root/.cache/uv \ source /opt/rh/gcc-toolset-13/enable && \ git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \ cd opencv-python && \ - sed -i 's/"setuptools==59.2.0",/"setuptools<70.0",/g' pyproject.toml && \ + sed -i -E -e 's/"setuptools.+",/"setuptools",/g' pyproject.toml && \ + cd opencv && git cherry-pick --no-commit $OPENCV_PATCH && cd .. && \ python -m build --wheel --installer=uv --outdir /opencvwheels/ ############################################################### @@ -148,9 +151,15 @@ COPY --from=arrow-builder /tmp/control /dev/null COPY --from=cv-builder /tmp/control /dev/null ARG VLLM_TARGET_DEVICE=cpu +ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1 # this step installs vllm and populates uv cache # with all the transitive dependencies +RUN --mount=type=cache,target=/root/.cache/uv \ + source /opt/rh/gcc-toolset-13/enable && \ + git clone https://github.com/huggingface/xet-core.git && cd xet-core/hf_xet/ && \ + uv pip install maturin && \ + uv build --wheel --out-dir /hf_wheels/ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \ --mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \ @@ -159,7 +168,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ source /opt/rh/gcc-toolset-13/enable && \ uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \ sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \ - uv pip install pandas pythran pybind11 && \ + uv pip install pandas pythran pybind11 /hf_wheels/*.whl && \ # sentencepiece.pc is in some pkgconfig inside uv cache export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \ uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \ @@ 
-247,8 +256,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \ --mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \ --mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \ + --mount=type=bind,from=vllmcache-builder,source=/hf_wheels/,target=/hf_wheels/,ro \ --mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \ - HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /vllmwheel/*.whl + HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /hf_wheels/*.whl /vllmwheel/*.whl COPY ./ /workspace/vllm WORKDIR /workspace/vllm diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index b8523fbc2a01c31f5ffadf7c5fd0833eac661361..1776b26d445ce0b9b404f3860b93782686e352cd 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="8970b25b" +ARG AITER_BRANCH="7e1ed08" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 5a84dc12d8f713c6b8fa6d1b955ad0b09bb7dbdd..128929ac333113fc94c67d160d7c96180ea9af4f 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -58,7 +58,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ cd ../../python && \ export PYARROW_PARALLEL=4 && \ export ARROW_BUILD_TYPE=release && \ - uv pip install -r requirements/build.txt && \ + uv pip install -r requirements-build.txt && \ python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel FROM python-install AS numa-build @@ -96,6 +96,22 @@ RUN 
--mount=type=cache,target=/root/.cache/uv \ uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \ python setup.py bdist_wheel +FROM python-install AS hf-xet-builder +# Install hf-xet +WORKDIR /tmp +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + git clone https://github.com/huggingface/xet-core.git && \ + cd xet-core/hf_xet/ && \ + uv pip install maturin patchelf && \ + python -m maturin build --release --out dist && \ + mkdir -p /tmp/hf-xet/dist && \ + cp dist/*.whl /tmp/hf-xet/dist/ + # Final build stage FROM python-install AS vllm-cpu ARG PYTHON_VERSION @@ -120,12 +136,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ --mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \ --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ + --mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \ sed -i '/^torch/d' requirements/build.txt && \ ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ + HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ + $HF_XET_WHL_FILE \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ @@ -149,4 +168,5 @@ USER 2000 WORKDIR /home/vllm # Set the default entrypoint -ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] \ No newline at end of file +ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] + 
diff --git a/docs/source/assets/deployment/anything-llm-chat-with-doc.png b/docs/source/assets/deployment/anything-llm-chat-with-doc.png new file mode 100644 index 0000000000000000000000000000000000000000..f9b57f5c3cecc92da660efaddb4e75d8f72160b3 Binary files /dev/null and b/docs/source/assets/deployment/anything-llm-chat-with-doc.png differ diff --git a/docs/source/assets/deployment/anything-llm-chat-without-doc.png b/docs/source/assets/deployment/anything-llm-chat-without-doc.png new file mode 100644 index 0000000000000000000000000000000000000000..952a43bcd677d23b8d78cdc23375c6a2c8621e8d Binary files /dev/null and b/docs/source/assets/deployment/anything-llm-chat-without-doc.png differ diff --git a/docs/source/assets/deployment/anything-llm-provider.png b/docs/source/assets/deployment/anything-llm-provider.png new file mode 100644 index 0000000000000000000000000000000000000000..bb699f7571f4034f4c26f42c96df213017d0eb9e Binary files /dev/null and b/docs/source/assets/deployment/anything-llm-provider.png differ diff --git a/docs/source/assets/deployment/anything-llm-upload-doc.png b/docs/source/assets/deployment/anything-llm-upload-doc.png new file mode 100644 index 0000000000000000000000000000000000000000..00c70e9c01f672cf4bc83fc4277b7f0c60ff3e55 Binary files /dev/null and b/docs/source/assets/deployment/anything-llm-upload-doc.png differ diff --git a/docs/source/assets/deployment/open_webui.png b/docs/source/assets/deployment/open_webui.png new file mode 100644 index 0000000000000000000000000000000000000000..fe9a7e15ea71d908c76eedc52d92e901bad9dae1 Binary files /dev/null and b/docs/source/assets/deployment/open_webui.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index a83ad764125c5a12027f31c67014333cbbda4837..c2ad6f9fa3a55d61fd499f3314911caeb1939351 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -177,6 +177,11 @@ def linkcode_resolve(domain, info): for part in info['fullname'].split('.'): obj = getattr(obj, part) + # Skip 
decorator wrappers by checking if the object is a function + # and has a __wrapped__ attribute (which decorators typically set) + while hasattr(obj, '__wrapped__'): + obj = obj.__wrapped__ + if not (inspect.isclass(obj) or inspect.isfunction(obj) or inspect.ismethod(obj)): obj = obj.__class__ # Get the class of the instance diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 03d830fe90f11a882a73610129e7227a8fb0f41f..b42536f054d76a94bdf16ff77270748822c632f0 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -128,11 +128,9 @@ HF processing as well as memory profiling. ### For memory profiling -Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs` -to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of -the model so that vLLM can reserve the correct amount of memory for it. +Override the abstract methods {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_text` and {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_mm_data` to construct dummy inputs for memory profiling. These dummy inputs should result in the worst-case memory usage of the model so that vLLM can reserve the correct amount of memory for it. -Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed to maximize the number of output embeddings, which is the same number as placeholder feature tokens. +Assuming that the memory usage increases with the number of tokens, the dummy inputs can be constructed to maximize the number of output embeddings, which is the same number as placeholder feature tokens. ::::{tab-set} :::{tab-item} Basic example: LLaVA @@ -244,38 +242,45 @@ def get_num_image_tokens( ``` Notice that the number of image tokens doesn't depend on the image width and height. 
-We can simply use a dummy `image_size`: +We can simply use a dummy `image_size` to calculate the multimodal profiling data: ```python +# NOTE: In actuality, this is usually implemented as part of the +# model's subclass of `BaseProcessingInfo`, but we show it as is +# here for simplicity. def get_image_size_with_most_features(self) -> ImageSize: hf_config = self.get_hf_config() width = height = hf_config.image_size return ImageSize(width=width, height=height) -def get_dummy_processor_inputs( +def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], -) -> ProcessorInputs: +) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - processor = self.info.get_hf_processor() - image_token = processor.image_token - - hf_config = self.get_hf_config() - target_width, target_height = self.info.get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } +``` - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) +For the text, we simply expand the multimodal image token from the model config to match the desired number of images. + +```python +def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images ``` ::: @@ -412,29 +417,30 @@ def get_image_size_with_most_features(self) -> ImageSize: Fuyu does not expect image placeholders in the inputs to HF processor, so the dummy prompt text is empty regardless of the number of images. 
-Otherwise, the logic of this method is very similar to LLaVA: ```python -def get_dummy_processor_inputs( +def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" +``` + +For the multimodal image profiling data, the logic is very similar to LLaVA: + +```python +def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], -) -> ProcessorInputs: +) -> MultiModalDataDict: target_width, target_height = \ self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + height=target_height, + num_images=num_images) } - - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) ``` ::: diff --git a/docs/source/deployment/docker.md b/docs/source/deployment/docker.md index 6b794db656c0530976daf2d078f823dfe285426b..ca56710bc2ef2d9822687262eeaff315b2dd8e49 100644 --- a/docs/source/deployment/docker.md +++ b/docs/source/deployment/docker.md @@ -19,6 +19,18 @@ $ docker run --runtime nvidia --gpus all \ --model mistralai/Mistral-7B-v0.1 ``` +This image can also be used with other container engines such as [Podman](https://podman.io/). + +```console +$ podman run --gpus all \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \ + -p 8000:8000 \ + --ipc=host \ + vllm/vllm-openai:latest \ + --model mistralai/Mistral-7B-v0.1 +``` + You can add any other you need after the image tag (`vllm/vllm-openai:latest`). 
:::{note} diff --git a/docs/source/deployment/frameworks/anything-llm.md b/docs/source/deployment/frameworks/anything-llm.md new file mode 100644 index 0000000000000000000000000000000000000000..d430c170ef5419984f4344118bfb86d7b2d25d55 --- /dev/null +++ b/docs/source/deployment/frameworks/anything-llm.md @@ -0,0 +1,47 @@ +(deployment-anything-llm)= + +# Anything LLM + +[Anything LLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. + +It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. + +## Prerequisites + +- Setup vLLM environment + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 +``` + +- Download and install [Anything LLM desktop](https://anythingllm.com/desktop). 
+ +- On the bottom left of open settings, AI Providers --> LLM: + - LLM Provider: Generic OpenAI + - Base URL: http://{vllm server host}:{vllm server port}/v1 + - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` + +:::{image} /assets/deployment/anything-llm-provider.png +::: + +- Back to home page, New Workspace --> create `vllm` workspace, and start to chat: + +:::{image} /assets/deployment/anything-llm-chat-without-doc.png +::: + +- Click the upload button: + - upload the doc + - select the doc and move to the workspace + - save and embed + +:::{image} /assets/deployment/anything-llm-upload-doc.png +::: + +- Chat again: + +:::{image} /assets/deployment/anything-llm-chat-with-doc.png +::: diff --git a/docs/source/deployment/frameworks/index.md b/docs/source/deployment/frameworks/index.md index cb758d3e6d2e412be93953b29d27cceb4448020e..a1b405386b77aa1b0f6a48685eb2c703a960f40b 100644 --- a/docs/source/deployment/frameworks/index.md +++ b/docs/source/deployment/frameworks/index.md @@ -3,12 +3,14 @@ :::{toctree} :maxdepth: 1 +anything-llm bentoml cerebrium dstack helm lws modal +open-webui skypilot triton ::: diff --git a/docs/source/deployment/frameworks/open-webui.md b/docs/source/deployment/frameworks/open-webui.md new file mode 100644 index 0000000000000000000000000000000000000000..83e5303a00ef2d15f58c998cbd4946759c407a72 --- /dev/null +++ b/docs/source/deployment/frameworks/open-webui.md @@ -0,0 +1,29 @@ +(deployment-open-webui)= + +# Open WebUI + +1. Install the [Docker](https://docs.docker.com/engine/install/) + +2. Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve qwen/Qwen1.5-0.5B-Chat +``` + +1. 
Start the [Open WebUI](https://github.com/open-webui/open-webui) docker container (replace the vllm serve host and vllm serve port): + +```console +docker run -d -p 3000:8080 \ +--name open-webui \ +-v open-webui:/app/backend/data \ +-e OPENAI_API_BASE_URL=http://<vllm serve host>:<vllm serve port>/v1 \ +--restart always \ +ghcr.io/open-webui/open-webui:main +``` + +1. Open it in the browser: + +On the top of the web page, you can see the model `qwen/Qwen1.5-0.5B-Chat`. + +:::{image} /assets/deployment/open_webui.png +::: diff --git a/docs/source/deployment/integrations/production-stack.md b/docs/source/deployment/integrations/production-stack.md index e66e8e6a16b294b1d946561b646c7bdc2412b438..05f1568306cc927a18457d62e8a89735a17b80f7 100644 --- a/docs/source/deployment/integrations/production-stack.md +++ b/docs/source/deployment/integrations/production-stack.md @@ -16,7 +16,7 @@ Ensure that you have a running Kubernetes environment with GPU (you can follow [ ## Deployment using vLLM production stack -The standard vLLM production stack install uses a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/tutorials/install-helm.sh) to install Helm on your GPU server. +The standard vLLM production stack is installed using a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/utils/install-helm.sh) to install Helm on your GPU server. To install the vLLM production stack, run the following commands on your desktop: diff --git a/docs/source/deployment/security.md b/docs/source/deployment/security.md new file mode 100644 index 0000000000000000000000000000000000000000..e2ef8196c16711ca8837e5a9966ff75465d457fd --- /dev/null +++ b/docs/source/deployment/security.md @@ -0,0 +1,58 @@ +# Security Guide + +## Inter-Node Communication + +All communications between nodes in a multi-node vLLM deployment are **insecure by default** and must be protected by placing the nodes on an isolated network. This includes: + +1. 
PyTorch Distributed communications +2. KV cache transfer communications +3. Tensor, Pipeline, and Data parallel communications + +### Configuration Options for Inter-Node Communications + +The following options control inter-node communications in vLLM: + +1. **Environment Variables:** + - `VLLM_HOST_IP`: Sets the IP address for vLLM processes to communicate on + +2. **KV Cache Transfer Configuration:** + - `--kv-ip`: The IP address for KV cache transfer communications (default: 127.0.0.1) + - `--kv-port`: The port for KV cache transfer communications (default: 14579) + +3. **Data Parallel Configuration:** + - `data_parallel_master_ip`: IP of the data parallel master (default: 127.0.0.1) + - `data_parallel_master_port`: Port of the data parallel master (default: 29500) + +### Notes on PyTorch Distributed + +vLLM uses PyTorch's distributed features for some inter-node communication. For +detailed information about PyTorch Distributed security considerations, please +refer to the [PyTorch Security +Guide](https://github.com/pytorch/pytorch/security/policy#using-distributed-features). + +Key points from the PyTorch security guide: +- PyTorch Distributed features are intended for internal communication only +- They are not built for use in untrusted environments or networks +- No authorization protocol is included for performance reasons +- Messages are sent unencrypted +- Connections are accepted from anywhere without checks + +### Security Recommendations + +1. **Network Isolation:** + - Deploy vLLM nodes on a dedicated, isolated network + - Use network segmentation to prevent unauthorized access + - Implement appropriate firewall rules + +2. **Configuration Best Practices:** + - Always set `VLLM_HOST_IP` to a specific IP address rather than using defaults + - Configure firewalls to only allow necessary ports between nodes + +3. 
**Access Control:** + - Restrict physical and network access to the deployment environment + - Implement proper authentication and authorization for management interfaces + - Follow the principle of least privilege for all system components + +## Reporting Security Vulnerabilities + +If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md). diff --git a/docs/source/design/mm_processing.md b/docs/source/design/mm_processing.md index 0947c1da1e547e4952b52d10c5a8029c556e30e2..dc92a3c2c511e018fa19464ef357fb7d54543e45 100644 --- a/docs/source/design/mm_processing.md +++ b/docs/source/design/mm_processing.md @@ -47,7 +47,7 @@ Moreover, since the tokenized text has not passed through the HF processor, we h ### Dummy text -We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. +We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_text`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. 
(mm-automatic-prompt-updating)= diff --git a/docs/source/design/v1/metrics.md b/docs/source/design/v1/metrics.md index b3981b2dc24a7adafcff2a623b5fb6c80f78a1fb..3f96290798a334c768f95996db5d68478a6dd6f1 100644 --- a/docs/source/design/v1/metrics.md +++ b/docs/source/design/v1/metrics.md @@ -66,8 +66,8 @@ vLLM also provides [a reference example](https://docs.vllm.ai/en/latest/getting_ The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important: - `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds -- `vllm:prompt_tokens_total` - Prompt Tokens/Sec -- `vllm:generation_tokens_total` - Generation Tokens/Sec +- `vllm:prompt_tokens_total` - Prompt Tokens +- `vllm:generation_tokens_total` - Generation Tokens - `vllm:time_per_output_token_seconds` - Inter token latency (Time Per Output Token, TPOT) in second. - `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds. - `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in RUNNING, WAITING, and SWAPPED state @@ -86,6 +86,17 @@ See [the PR which added this Dashboard](gh-pr:2316) for interesting and useful b Prometheus support was initially added [using the aioprometheus library](gh-pr:1890), but a switch was made quickly to [prometheus_client](gh-pr:2730). The rationale is discussed in both linked PRs. 
+With the switch to `prometheus_client`, we lost a `MetricsMiddleware` to track HTTP metrics, but this was reinstated [using prometheus_fastapi_instrumentator](gh-pr:15657): + +```bash +$ curl http://0.0.0.0:8000/metrics 2>/dev/null | grep -P '^http_(?!.*(_bucket|_created|_sum)).*' +http_requests_total{handler="/v1/completions",method="POST",status="2xx"} 201.0 +http_request_size_bytes_count{handler="/v1/completions"} 201.0 +http_response_size_bytes_count{handler="/v1/completions"} 201.0 +http_request_duration_highr_seconds_count 201.0 +http_request_duration_seconds_count{handler="/v1/completions",method="POST"} 201.0 +``` + ### Multi-process Mode In v0, metrics are collected in the engine core process and we use multi-process mode to make them available in the API server process. See . diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index 57dba680b97c6e0783032869fe55e60ded09943f..7920131643c26153d5ca4b7305ffb02c31835fab 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -99,7 +99,7 @@ This time, Inductor compilation is completely bypassed, and we will load from di The above example just uses Inductor to compile for a general shape (i.e. symbolic shape). We can also use Inductor to compile for some of the specific shapes, for example: -`VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'compile_sizes': [1, 2, 4, 8]}"` +`vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'compile_sizes': [1, 2, 4, 8]}"` Then it will also compile a specific kernel just for batch size `1, 2, 4, 8`. At this time, all of the shapes in the computation graph are static and known, and we will turn on auto-tuning to tune for max performance. This can be slow when you run it for the first time, but the next time you run it, we can directly bypass the tuning and run the tuned kernel. 
@@ -134,6 +134,6 @@ The cudagraphs are captured and managed by the compiler backend, and replayed wh By default, vLLM will try to determine a set of sizes to capture cudagraph. You can also override it using the config `cudagraph_capture_sizes`: -`VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"` +`vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"` Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture. diff --git a/docs/source/features/disagg_prefill.md b/docs/source/features/disagg_prefill.md index 52d253b9c2b18831782443ef55a6332b4d3919d8..2fa20140c086d3109c681d4a0136edae54f6da56 100644 --- a/docs/source/features/disagg_prefill.md +++ b/docs/source/features/disagg_prefill.md @@ -21,11 +21,11 @@ Disaggregated prefill DOES NOT improve throughput. ## Usage example -Please refer to `examples/online_serving/disaggregated_prefill.sh` for the example usage of disaggregated prefilling. +Please refer to <gh-file:examples/online_serving/disaggregated_prefill.sh> for the example usage of disaggregated prefilling. ## Benchmarks -Please refer to `benchmarks/disagg_benchmarks/` for disaggregated prefilling benchmarks. +Please refer to <gh-dir:benchmarks/disagg_benchmarks> for disaggregated prefilling benchmarks. ## Development diff --git a/docs/source/features/lora.md b/docs/source/features/lora.md index a71da72e4360ae683803c47a206b57f45902f902..b5b51095b3a75656d69e3303200a80346052f19c 100644 --- a/docs/source/features/lora.md +++ b/docs/source/features/lora.md @@ -106,19 +106,18 @@ curl http://localhost:8000/v1/completions \ ## Dynamically serving LoRA Adapters -In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading -LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility -to change models on-the-fly is needed. 
+In addition to serving LoRA adapters at server startup, the vLLM server supports dynamically configuring LoRA adapters at runtime through dedicated API endpoints and plugins. This feature can be particularly useful when the flexibility to change models on-the-fly is needed. Note: Enabling this feature in production environments is risky as users may participate in model adapter management. -To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING` -is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active. +To enable dynamic LoRA configuration, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING` +is set to `True`. ```bash export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True ``` +### Using API Endpoints Loading a LoRA Adapter: To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary @@ -153,6 +152,58 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \ }' ``` +### Using Plugins +Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adapters. LoRAResolver plugins enable you to load LoRA adapters from both local and remote sources such as local file system and S3. On every request, when there's a new model name that hasn't been loaded yet, the LoRAResolver will try to resolve and load the corresponding LoRA adapter. + +You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds. + +You can either install existing plugins or implement your own. + +Steps to implement your own LoRAResolver plugin: +1. Implement the LoRAResolver interface. 
+ + Example of a simple S3 LoRAResolver implementation: + + ```python + import os + import s3fs + from vllm.lora.request import LoRARequest + from vllm.lora.resolver import LoRAResolver + + class S3LoRAResolver(LoRAResolver): + def __init__(self): + self.s3 = s3fs.S3FileSystem() + self.s3_path_format = os.getenv("S3_PATH_TEMPLATE") + self.local_path_format = os.getenv("LOCAL_PATH_TEMPLATE") + + async def resolve_lora(self, base_model_name, lora_name): + s3_path = self.s3_path_format.format(base_model_name=base_model_name, lora_name=lora_name) + local_path = self.local_path_format.format(base_model_name=base_model_name, lora_name=lora_name) + + # Download the LoRA from S3 to the local path + await self.s3._get( + s3_path, local_path, recursive=True, maxdepth=1 + ) + + lora_request = LoRARequest( + lora_name=lora_name, + lora_path=local_path, + lora_int_id=abs(hash(lora_name)) + ) + return lora_request + ``` + +2. Register LoRAResolver plugin. + + ```python + from vllm.lora.resolver import LoRAResolverRegistry + + s3_resolver = S3LoRAResolver() + LoRAResolverRegistry.register_resolver("s3_resolver", s3_resolver) + ``` + + For more details, refer to the [vLLM's Plugins System](../design/plugin_system.md). + ## New format for `--lora-modules` In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example: diff --git a/docs/source/features/quantization/auto_awq.md b/docs/source/features/quantization/auto_awq.md index b703d0195319305185c957f2274f5cddf68d8ba5..b4ac597f5a79c2c7f5f66f88470752739fb8115a 100644 --- a/docs/source/features/quantization/auto_awq.md +++ b/docs/source/features/quantization/auto_awq.md @@ -6,13 +6,13 @@ To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github Quantization reduces the model's precision from BF16/FP16 to INT4 which effectively reduces the total model memory footprint. The main benefits are lower latency and memory usage. 
-You can quantize your own models by installing AutoAWQ or picking one of the [6500+ models on Huggingface](https://huggingface.co/models?sort=trending&search=awq). +You can quantize your own models by installing AutoAWQ or picking one of the [6500+ models on Huggingface](https://huggingface.co/models?search=awq). ```console pip install autoawq ``` -After installing AutoAWQ, you are ready to quantize a model. Please refer to the `AutoAWQ documentation `_ for further details. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`: +After installing AutoAWQ, you are ready to quantize a model. Please refer to the [AutoAWQ documentation](https://casper-hansen.github.io/AutoAWQ/examples/#basic-quantization) for further details. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`: ```python from awq import AutoAWQForCausalLM diff --git a/docs/source/features/quantization/bitblas.md b/docs/source/features/quantization/bitblas.md new file mode 100644 index 0000000000000000000000000000000000000000..d0b2bf858c9b6b107f2399b1bde41f5a1aa862d1 --- /dev/null +++ b/docs/source/features/quantization/bitblas.md @@ -0,0 +1,48 @@ +(bitblas)= + +# BitBLAS + +vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more efficient and flexible model inference. Compared to other quantization frameworks, BitBLAS provides more precision combinations. + +:::{note} +Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). +Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. +For details see [supported hardware](https://docs.vllm.ai/en/latest/features/quantization/supported_hardware.html). +::: + +Below are the steps to utilize BitBLAS with vLLM. + +```console +pip install bitblas>=0.1.0 +``` + +vLLM reads the model's config file and supports pre-quantized checkpoints. 
+ +You can find pre-quantized models on: + +- [Hugging Face (BitBLAS)](https://huggingface.co/models?search=bitblas) +- [Hugging Face (GPTQ)](https://huggingface.co/models?search=gptq) + +Usually, these repositories have a `quantize_config.json` file that includes a `quantization_config` section. + +## Read bitblas format checkpoint + +```python +from vllm import LLM +import torch + +# "hxbgsyxh/llama-13b-4bit-g-1-bitblas" is a pre-quantized checkpoint. +model_id = "hxbgsyxh/llama-13b-4bit-g-1-bitblas" +llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, quantization="bitblas") +``` + +## Read gptq format checkpoint + +```python +from vllm import LLM +import torch + +# "hxbgsyxh/llama-13b-4bit-g-1" is a pre-quantized checkpoint. +model_id = "hxbgsyxh/llama-13b-4bit-g-1" +llm = LLM(model=model_id, dtype=torch.float16, trust_remote_code=True, quantization="bitblas", max_model_len=1024) +``` diff --git a/docs/source/features/quantization/bnb.md b/docs/source/features/quantization/bnb.md index e356b99d85cdf7a0d4337789c32f6762c73e779b..1843a33a3dfdd6ae1607a649d7657893ec345580 100644 --- a/docs/source/features/quantization/bnb.md +++ b/docs/source/features/quantization/bnb.md @@ -14,7 +14,7 @@ pip install bitsandbytes>=0.45.3 vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint. -You can find bitsandbytes quantized models on . +You can find bitsandbytes quantized models on . And usually, these repositories have a config.json file that includes a quantization_config section. 
## Read quantized checkpoint diff --git a/docs/source/features/quantization/gptqmodel.md b/docs/source/features/quantization/gptqmodel.md index 34adf6512b7e215c326368bd44875990a155b1d2..9771d5a4fe9ee7882bdda0df907f4091ef6cf96b 100644 --- a/docs/source/features/quantization/gptqmodel.md +++ b/docs/source/features/quantization/gptqmodel.md @@ -16,12 +16,16 @@ GPTQModel is one of the few quantization toolkits in the world that allows `Dyna is fully integrated into vLLM and backed up by support from the ModelCloud.AI team. Please refer to [GPTQModel readme](https://github.com/ModelCloud/GPTQModel?tab=readme-ov-file#dynamic-quantization-per-module-quantizeconfig-override) for more details on this and other advanced features. -You can quantize your own models by installing [GPTQModel](https://github.com/ModelCloud/GPTQModel) or picking one of the [5000+ models on Huggingface](https://huggingface.co/models?sort=trending&search=gptq). +## Installation + +You can quantize your own models by installing [GPTQModel](https://github.com/ModelCloud/GPTQModel) or picking one of the [5000+ models on Huggingface](https://huggingface.co/models?search=gptq). ```console pip install -U gptqmodel --no-build-isolation -v ``` +## Quantizing a model + After installing GPTQModel, you are ready to quantize a model. Please refer to the [GPTQModel readme](https://github.com/ModelCloud/GPTQModel/?tab=readme-ov-file#quantization) for further details. 
Here is an example of how to quantize `meta-llama/Llama-3.2-1B-Instruct`: @@ -49,12 +53,16 @@ model.quantize(calibration_dataset, batch_size=2) model.save(quant_path) ``` +## Running a quantized model with vLLM + To run an GPTQModel quantized model with vLLM, you can use [DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2](https://huggingface.co/ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2) with the following command: ```console -python examples/offline_inference/llm_engine_example.py --model DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2 +python examples/offline_inference/llm_engine_example.py --model ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2 ``` +## Using GPTQModel with vLLM's Python API + GPTQModel quantized models are also supported directly through the LLM entrypoint: ```python @@ -67,17 +75,22 @@ prompts = [ "The capital of France is", "The future of AI is", ] + # Create a sampling params object. sampling_params = SamplingParams(temperature=0.6, top_p=0.9) # Create an LLM. -llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2") +llm = LLM(model="ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2") + # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) + # Print the outputs. 
+print("-"*50) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-"*50) ``` diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md index 6f539f6e3f486b7363bdb74396428fa5dc83d54b..c7c8aeb662a56d189111fff7c0c1df334b385335 100644 --- a/docs/source/features/quantization/index.md +++ b/docs/source/features/quantization/index.md @@ -11,6 +11,7 @@ Quantization trades off model precision for smaller memory footprint, allowing l supported_hardware auto_awq bnb +bitblas gguf gptqmodel int4 diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md index 2cbe8779dd8a18c5e45a1fea11472dd1e77014a0..984e6626e2417fc095850a14d1894a288d2a37ac 100644 --- a/docs/source/features/quantization/supported_hardware.md +++ b/docs/source/features/quantization/supported_hardware.md @@ -74,6 +74,17 @@ The table below shows the compatibility of various quantization implementations * ❌ * ❌ * ❌ +- * BitBLAS (GPTQ) + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ❌ + * ❌ + * ❌ + * ❌ - * AQLM * ✅︎ * ✅︎ diff --git a/docs/source/features/quantization/torchao.md b/docs/source/features/quantization/torchao.md index 9a85f0bab9ec7101bdca3e07ed33d37fa3b9cf75..82100c6ddcac0befe12ec4db9b7ead6dfaaa5ec5 100644 --- a/docs/source/features/quantization/torchao.md +++ b/docs/source/features/quantization/torchao.md @@ -30,5 +30,4 @@ tokenizer.push_to_hub(hub_repo) quantized_model.push_to_hub(hub_repo, safe_serialization=False) ``` -Alternatively, you can use the TorchAO Quantization space for quantizing models with a simple UI. 
-See: https://huggingface.co/spaces/medmekk/TorchAO_Quantization +Alternatively, you can use the [TorchAO Quantization space](https://huggingface.co/spaces/medmekk/TorchAO_Quantization) for quantizing models with a simple UI. diff --git a/docs/source/features/structured_outputs.md b/docs/source/features/structured_outputs.md index de3c5bf5e7ab96048c3321d1dc66e0f0a9677141..03119ec7441c90bcd2d227807bf447c4ade5a826 100644 --- a/docs/source/features/structured_outputs.md +++ b/docs/source/features/structured_outputs.md @@ -2,8 +2,11 @@ # Structured Outputs -vLLM supports the generation of structured outputs using [outlines](https://github.com/dottxt-ai/outlines), [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer), or [xgrammar](https://github.com/mlc-ai/xgrammar) as backends for the guided decoding. -This document shows you some examples of the different options that are available to generate structured outputs. +vLLM supports the generation of structured outputs using +[xgrammar](https://github.com/mlc-ai/xgrammar) or +[guidance](https://github.com/guidance-ai/llguidance) as backends. +This document shows you some examples of the different options that are +available to generate structured outputs. ## Online Serving (OpenAI API) @@ -15,10 +18,17 @@ The following parameters are supported, which must be added as extra parameters: - `guided_regex`: the output will follow the regex pattern. - `guided_json`: the output will follow the JSON schema. - `guided_grammar`: the output will follow the context free grammar. -- `guided_whitespace_pattern`: used to override the default whitespace pattern for guided json decoding. -- `guided_decoding_backend`: used to select the guided decoding backend to use. Additional backend-specific options can be supplied in a comma separated list following a colon after the backend name. For example `"xgrammar:no-fallback"` will not allow vLLM to fallback to a different backend on error. 
+- `structural_tag`: Follow a JSON schema within a set of specified tags within the generated text. -You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server)page. +You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server) page. + +Structured outputs are supported by default in the OpenAI-Compatible Server. You +may choose to specify the backend to use by setting the +`--guided-decoding-backend` flag to `vllm serve`. The default backend is `auto`, +which will try to choose an appropriate backend based on the details of the +request. You may also choose a specific backend, along with +some options. A full set of options is available in the `vllm serve --help` +text. Now let´s see an example for each of the cases, starting with the `guided_choice`, as it´s the easiest one: @@ -50,7 +60,7 @@ completion = client.chat.completions.create( "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com and new line. Example result: alan.turing@enigma.com\n", } ], - extra_body={"guided_regex": "\w+@\w+\.com\n", "stop": ["\n"]}, + extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]}, ) print(completion.choices[0].message.content) ``` @@ -96,26 +106,29 @@ print(completion.choices[0].message.content) ``` :::{tip} -While not strictly necessary, normally it´s better to indicate in the prompt that a JSON needs to be generated and which fields and how should the LLM fill them. -This can improve the results notably in most cases. +While not strictly necessary, normally it´s better to indicate in the prompt the +JSON schema and how the fields should be populated. This can improve the +results notably in most cases. ::: -Finally we have the `guided_grammar`, which probably is the most difficult one to use but it´s really powerful, as it allows us to define complete languages like SQL queries. 
-It works by using a context free EBNF grammar, which for example we can use to define a specific format of simplified SQL queries, like in the example below: +Finally we have the `guided_grammar` option, which is probably the most +difficult to use, but it's really powerful. It allows us to define complete +languages like SQL queries. It works by using a context free EBNF grammar. +As an example, we can use it to define a specific format of simplified SQL queries: ```python simplified_sql_grammar = """ - ?start: select_statement + root ::= select_statement - ?select_statement: "SELECT " column_list " FROM " table_name + select_statement ::= "SELECT " column " from " table " where " condition - ?column_list: column_name ("," column_name)* + column ::= "col_1 " | "col_2 " - ?table_name: identifier + table ::= "table_1 " | "table_2 " - ?column_name: identifier + condition ::= column "= " number - ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ + number ::= "1 " | "2 " """ completion = client.chat.completions.create( @@ -226,6 +239,8 @@ Step #2: explanation="Next, let's isolate 'x' by dividing both sides of the equa Answer: x = -29/8 ``` +An example of using `structural_tag` can be found here: + ## Offline Inference Offline inference allows for the same types of guided decoding. @@ -236,11 +251,11 @@ The main available options inside `GuidedDecodingParams` are: - `regex` - `choice` - `grammar` -- `backend` -- `whitespace_pattern` +- `structural_tag` -These parameters can be used in the same way as the parameters from the Online Serving examples above. -One example for the usage of the `choices` parameter is shown below: +These parameters can be used in the same way as the parameters from the Online +Serving examples above. 
One example for the usage of the `choice` parameter is +shown below: ```python from vllm import LLM, SamplingParams diff --git a/docs/source/features/tool_calling.md b/docs/source/features/tool_calling.md index 8b8bbd28d3483b21943bde7313cb7af2b5492a46..f98ec6108cea616068ab726c5f1bbd569467156e 100644 --- a/docs/source/features/tool_calling.md +++ b/docs/source/features/tool_calling.md @@ -152,12 +152,14 @@ Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_cha Supported models: -* `meta-llama/Meta-Llama-3.1-8B-Instruct` -* `meta-llama/Meta-Llama-3.1-70B-Instruct` -* `meta-llama/Meta-Llama-3.1-405B-Instruct` -* `meta-llama/Meta-Llama-3.1-405B-Instruct-FP8` +All Llama 3.1, 3.2 and 4 models should be supported. + +* `meta-llama/Llama-3.1-*` +* `meta-llama/Llama-3.2-*` +* `meta-llama/Llama-4-*` + +The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. -The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) in Llama-3.2 models, see the `pythonic` tool parser below. Other tool calling formats like the built in python tool calling or custom tool calling are not supported. Known issues: @@ -166,10 +168,20 @@ Known issues: 2. The model can generate parameters with a wrong format, such as generating an array serialized as string instead of an array. 
-The `tool_chat_template_llama3_json.jinja` file contains the "official" Llama chat template, but tweaked so that -it works better with vLLM. +vLLM provides two JSON based chat templates for Llama 3.1 and 3.2: + +* `examples/tool_chat_template_llama3.1_json.jinja` - this is the "official" chat template for the Llama 3.1 +models, but tweaked so that it works better with vLLM. +* `examples/tool_chat_template_llama3.2_json.jinja` - this extends upon the Llama 3.1 chat template by adding support for +images. + +Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}` + +vLLM also provides a JSON based chat template for Llama 4: +* `examples/tool_chat_template_llama4_json.jinja` - this is based on the "official" chat template for the Llama 4 +models, but tweaked so that it works better with vLLM. -Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` +For Llama 4 use `--tool-call-parser llama4_json --chat-template examples/tool_chat_template_llama4_json.jinja`. #### IBM Granite diff --git a/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md b/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md index e3046f35ee15fdc09dc4ed3fe1cea4580151c794..78938de317c48075f93d2474ac6e8c7da78fa4bd 100644 --- a/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md +++ b/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md @@ -13,11 +13,11 @@ There are no pre-built wheels or images for this device, so you must build vLLM - Intel Gaudi accelerator - Intel Gaudi software version 1.18.0 -Please follow the instructions provided in the [Gaudi Installation -Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) +Please follow the instructions provided in the +[Gaudi Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) to set up the execution environment. 
To achieve the best performance, -please follow the methods outlined in the [Optimizing Training Platform -Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html). +please follow the methods outlined in the +[Optimizing Training Platform Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html). ## Configure a new environment @@ -32,15 +32,13 @@ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloade pip list | grep neural # verify that neural_compressor is installed ``` -Refer to [Intel Gaudi Software Stack -Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) +Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) for more details. ### Run Docker Image It is highly recommended to use the latest Docker image from Intel Gaudi -vault. Refer to the [Intel Gaudi -documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) +vault. Refer to the [Intel Gaudi documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) for more details. Use the following commands to run a Docker image: @@ -278,8 +276,9 @@ Lower value corresponds to less usable graph memory reserved for prefill stage, ::: User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: -\- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. 
`(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode -\- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt + +- `max_bs` - graph capture queue will be sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode +- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. @@ -326,8 +325,7 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi - We recommend running inference on Gaudi 2 with `block_size` of 128 for BF16 data type. Using default values (16, 32) might lead to sub-optimal performance due to Matrix Multiplication Engine - under-utilization (see [Gaudi - Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). + under-utilization (see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). 
- For max throughput on Llama 7B, we recommend running with batch size of 128 or 256 and max context length of 2048 with HPU Graphs enabled. If you encounter out-of-memory issues, see troubleshooting section. @@ -336,11 +334,11 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi **Diagnostic and profiling knobs:** -- `VLLM_PROFILER_ENABLED`: if `true`, high level profiler will be enabled. Resulting JSON traces can be viewed in [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). Disabled by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: if `true`, will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside `PT_HPU_METRICS_GC_DETAILS=1`. Disabled by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: if `true`, will log graph compilations per each vLLM engine step, always, even if there were none. Disabled by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: if `true`, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: if `true`, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default. +- `VLLM_PROFILER_ENABLED`: If `true`, enable the high level profiler. Resulting JSON traces can be viewed in [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). `false` by default. +- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: If `true`, log graph compilations for each vLLM engine step when any occurs. Highly recommended to use with `PT_HPU_METRICS_GC_DETAILS=1`. `false` by default. +- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: If `true`, always log graph compilations for each vLLM engine step even if none occurred. `false` by default. +- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: If `true`, log CPU fallbacks for each vLLM engine step when any occurs. `false` by default. 
+- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: If `true`, always log CPU fallbacks for each vLLM engine step even if none occurred. `false` by default. **Performance tuning knobs:** @@ -381,7 +379,7 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: -- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used, `1` is default +- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used; if `1`, PyTorch Lazy backend for Gaudi will be used. `1` is default. - `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor parallel inference with HPU Graphs ## Troubleshooting: tweaking HPU graphs diff --git a/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md b/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md index beb803cf059782140f19d2ba4706ceadfb71ea80..8beb92ef7da0a3e50b7be726b78dd9631e9f95a6 100644 --- a/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md +++ b/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md @@ -44,7 +44,7 @@ There are no pre-built wheels for this device, so you must either use the pre-bu You can provision Cloud TPUs using the [Cloud TPU API](https://cloud.google.com/tpu/docs/reference/rest) or the [queued resources](https://cloud.google.com/tpu/docs/queued-resources) -API. This section shows how to create TPUs using the queued resource API. For +API (preferred). This section shows how to create TPUs using the queued resource API. For more information about using the Cloud TPU API, see [Create a Cloud TPU using the Create Node API](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#create-node-api). Queued resources enable you to request Cloud TPU resources in a queued manner. 
When you request queued resources, the request is added to a queue maintained by @@ -97,10 +97,10 @@ gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ `TPU regions and zones `_ - * ACCELERATOR_TYPE * The TPU version you want to use. Specify the TPU version, for example - `v5litepod-4` specifies a v5e TPU with 4 cores. For more information, - see `TPU versions `_. + `v5litepod-4` specifies a v5e TPU with 4 cores, `v6e-1` specifies a v6e TPU with 1 core. For more information, + see [TPU versions](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions). - * RUNTIME_VERSION - * The TPU VM runtime version to use. For more information see `TPU VM images `_. + * The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s). For more information see [TPU VM images](https://cloud.google.com/tpu/docs/runtimes). - * SERVICE_ACCOUNT * The email address for your service account. You can find it in the IAM Cloud Console under *Service Accounts*. For example: diff --git a/docs/source/getting_started/installation/cpu.md b/docs/source/getting_started/installation/cpu.md index db22ef79c926a2da20f1cf43e44873439f8b1ca4..2c0ec60d7100f637348a5972fb0ee9265743b976 100644 --- a/docs/source/getting_started/installation/cpu.md +++ b/docs/source/getting_started/installation/cpu.md @@ -272,7 +272,7 @@ $ python examples/offline_inference/basic/basic.py - Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance. 
-- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is a option for better performance. +- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is an option for better performance. - Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving: diff --git a/docs/source/getting_started/installation/cpu/build.inc.md b/docs/source/getting_started/installation/cpu/build.inc.md index 39d9dfbd2b2e2464d37610f4eaf520f56cbf0b2e..f385f3d5b19842c126d733737a7f3de1ac19bd62 100644 --- a/docs/source/getting_started/installation/cpu/build.inc.md +++ b/docs/source/getting_started/installation/cpu/build.inc.md @@ -2,7 +2,7 @@ First, install recommended compiler. We recommend to use `gcc/g++ >= 12.3.0` as ```console sudo apt-get update -y -sudo apt-get install -y gcc-12 g++-12 libnuma-dev +sudo apt-get install -y gcc-12 g++-12 libnuma-dev python3-dev sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 ``` @@ -26,3 +26,9 @@ Finally, build and install vLLM CPU backend: ```console VLLM_TARGET_DEVICE=cpu python setup.py install ``` + +If you want to develop vllm, install it in editable mode instead. 
+ +```console +VLLM_TARGET_DEVICE=cpu python setup.py develop +``` diff --git a/docs/source/getting_started/installation/gpu/cuda.inc.md b/docs/source/getting_started/installation/gpu/cuda.inc.md index d3e375aec10cb990ccad112102f34e9024550c7f..46bdb08ebb77c921bbce06193f269c6ff11c7591 100644 --- a/docs/source/getting_started/installation/gpu/cuda.inc.md +++ b/docs/source/getting_started/installation/gpu/cuda.inc.md @@ -46,7 +46,7 @@ LLM inference is a fast-evolving field, and the latest code may contain bug fixe ##### Install the latest code using `pip` ```console -pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly +pip install -U vllm --pre --extra-index-url https://wheels.vllm.ai/nightly ``` `--pre` is required for `pip` to consider pre-released versions. @@ -65,9 +65,11 @@ Note that the wheels are built with Python 3.8 ABI (see [PEP 425](https://peps.p Another way to install the latest code is to use `uv`: ```console -uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly +uv pip install -U vllm --extra-index-url https://wheels.vllm.ai/nightly ``` +##### Install specific revisions using `uv` + If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL: ```console @@ -151,7 +153,7 @@ git clone https://github.com/vllm-project/vllm.git cd vllm python use_existing_torch.py pip install -r requirements/build.txt -pip install -e . --no-build-isolation +pip install --no-build-isolation -e . 
``` ##### Use the local cutlass for compilation diff --git a/docs/source/getting_started/installation/gpu/xpu.inc.md b/docs/source/getting_started/installation/gpu/xpu.inc.md index c41905f250f8391caa0d4bf012f342dacdddb7b2..fbf5421eeec5b358732a2c73246708121dcbb2d6 100644 --- a/docs/source/getting_started/installation/gpu/xpu.inc.md +++ b/docs/source/getting_started/installation/gpu/xpu.inc.md @@ -23,6 +23,8 @@ Currently, there are no pre-built XPU wheels. - Second, install Python packages for vLLM XPU backend building: ```console +git clone https://github.com/vllm-project/vllm.git +cd vllm pip install --upgrade pip pip install -v -r requirements/xpu.txt ``` diff --git a/docs/source/getting_started/troubleshooting.md b/docs/source/getting_started/troubleshooting.md index 87fa442e9a4893844d6eb6185d1e322c1e18a3ac..a4744827f2268865ecc9923746a56693dec79922 100644 --- a/docs/source/getting_started/troubleshooting.md +++ b/docs/source/getting_started/troubleshooting.md @@ -24,7 +24,7 @@ To isolate the model downloading and loading issue, you can use the `--load-form ## Out of memory -If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider [using tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using . The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism. +If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider adopting [these options](#reducing-memory-usage) to reduce the memory consumption. 
## Generation quality changed diff --git a/docs/source/getting_started/v1_user_guide.md b/docs/source/getting_started/v1_user_guide.md index a87484c3bb042590672d5d18b6959b14f6abe2fa..de90b8a7851e69f6da7735d4a4e3cbe70cb0f9bf 100644 --- a/docs/source/getting_started/v1_user_guide.md +++ b/docs/source/getting_started/v1_user_guide.md @@ -44,8 +44,8 @@ This living user guide outlines a few known **important changes and limitations* |-----------------|-----------------------------------------------------------------------------------| | **Prefix Caching** | 🚀 Optimized | | **Chunked Prefill** | 🚀 Optimized | +| **LoRA** | 🚀 Optimized | | **Logprobs Calculation** | 🟢 Functional | -| **LoRA** | 🟢 Functional ([PR #13096](https://github.com/vllm-project/vllm/pull/13096))| | **Multimodal Models** | 🟢 Functional | | **FP8 KV Cache** | 🟢 Functional on Hopper devices ([PR #15191](https://github.com/vllm-project/vllm/pull/15191))| | **Spec Decode** | 🚧 WIP ([PR #13933](https://github.com/vllm-project/vllm/pull/13933))| @@ -121,11 +121,6 @@ Although we have re-implemented and partially optimized many features and models These features are already supported in vLLM V1, but their optimization is still in progress. -- **LoRA**: LoRA is functionally working on vLLM V1 but its performance is - inferior to that of V0. The team is actively working on improving its - performance -(e.g., see [PR #13096](https://github.com/vllm-project/vllm/pull/13096)). - - **Spec Decode**: Currently, only ngram-based spec decode is supported in V1. There will be follow-up work to support other types of spec decode (e.g., see [PR #13933](https://github.com/vllm-project/vllm/pull/13933)). We will prioritize the support for Eagle, MTP compared to draft model based spec decode. 
diff --git a/docs/source/index.md b/docs/source/index.md index 28dc0f67d7746a89cab6b5a6563437484700a059..43b330e4b432e5c83a2282809db1b498621625bb 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -132,6 +132,7 @@ serving/integrations/index :caption: Deployment :maxdepth: 1 +deployment/security deployment/docker deployment/k8s deployment/nginx diff --git a/docs/source/models/extensions/fastsafetensor.md b/docs/source/models/extensions/fastsafetensor.md index 66cd710c97e9fa5494732d0dd3f665f9f8d1e600..531d58690014ee0ec79c3dac7c33e7650dbf4b5a 100644 --- a/docs/source/models/extensions/fastsafetensor.md +++ b/docs/source/models/extensions/fastsafetensor.md @@ -1,5 +1,5 @@ Loading Model weights with fastsafetensors =================================================================== -Using fastsafetensor library enables loading model weights to GPU memory by leveraging GPU direct storage. See https://github.com/foundation-model-stack/fastsafetensors for more details. +Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details. For enabling this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true`` diff --git a/docs/source/models/extensions/runai_model_streamer.md b/docs/source/models/extensions/runai_model_streamer.md index 99c37876a01b3bddbbfeae8632d2db7c43d8df68..e0daa6f86dde4e45d507dfa19bd24583e60d9854 100644 --- a/docs/source/models/extensions/runai_model_streamer.md +++ b/docs/source/models/extensions/runai_model_streamer.md @@ -51,3 +51,29 @@ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer - :::{note} For further instructions about tunable parameters and additional parameters configurable through environment variables, read the [Environment Variables Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md). 
::: + +## Sharded Model Loading + +vLLM also supports loading sharded models using Run:ai Model Streamer. This is particularly useful for large models that are split across multiple files. To use this feature, use the `--load-format runai_streamer_sharded` flag: + +```console +vllm serve /path/to/sharded/model --load-format runai_streamer_sharded +``` + +The sharded loader expects model files to follow the same naming pattern as the regular sharded state loader: `model-rank-{rank}-part-{part}.safetensors`. You can customize this pattern using the `pattern` parameter in `--model-loader-extra-config`: + +```console +vllm serve /path/to/sharded/model --load-format runai_streamer_sharded --model-loader-extra-config '{"pattern":"custom-model-rank-{rank}-part-{part}.safetensors"}' +``` + +To create sharded model files, you can use the script provided in . This script demonstrates how to save a model in the sharded format that is compatible with the Run:ai Model Streamer sharded loader. + +The sharded loader supports all the same tunable parameters as the regular Run:ai Model Streamer, including `concurrency` and `memory_limit`. These can be configured in the same way: + +```console +vllm serve /path/to/sharded/model --load-format runai_streamer_sharded --model-loader-extra-config '{"concurrency":16, "memory_limit":5368709120}' +``` + +:::{note} +The sharded loader is particularly efficient for tensor or pipeline parallel models where each worker only needs to read its own shard rather than the entire checkpoint. +::: diff --git a/docs/source/models/generative_models.md b/docs/source/models/generative_models.md index 63fc53b0e7c55c416ffe65f86fa1a255673b3b3b..3291006ed668ca9bd0c95999740bbd5ad1be54bf 100644 --- a/docs/source/models/generative_models.md +++ b/docs/source/models/generative_models.md @@ -59,7 +59,7 @@ A code example can be found here: ]}` (offline) or `--hf_overrides '{"is_matryoshka": true}'`, `--hf_overrides '{"matryoshka_dimensions": []}'`(online). 
+ +Here is an example to serve a model with Matryoshka Embeddings enabled. + +```text +vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}' +``` + +### Offline Inference + +You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter in {class}`~vllm.PoolingParams`. + +```python +from vllm import LLM, PoolingParams + +model = LLM(model="jinaai/jina-embeddings-v3", + task="embed", + trust_remote_code=True) +outputs = model.embed(["Follow the white rabbit."], + pooling_params=PoolingParams(dimensions=32)) +print(outputs[0].outputs) +``` + +A code example can be found here: + +### Online Inference + +Use the following command to start vllm server. + +```text +vllm serve jinaai/jina-embeddings-v3 --trust-remote-code +``` + +You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter. + +```text +curl http://127.0.0.1:8000/v1/embeddings \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "input": "Follow the white rabbit.", + "model": "jinaai/jina-embeddings-v3", + "encoding_format": "float", + "dimensions": 32 + }' +``` + +Expected output: + +```json +{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}} +``` + +A openai client example 
can be found here: diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 2e02bcfc64e02bf0ba6b720168a944c2b35f745c..bc68e34832ccb116826f4c1d5a69b26a3bf58952 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -40,29 +40,37 @@ You can force the use of `TransformersForCausalLM` by setting `model_impl="trans vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. ::: -#### Supported features +#### Custom models -The Transformers modeling backend explicitly supports the following features: +If a model is neither supported natively by vLLM nor Transformers, it can still be used in vLLM! -- (except GGUF) -- -- +For a model to be compatible with the Transformers backend for vLLM it must: -#### Remote Code +- be a Transformers compatible custom model (see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)): + * The model directory must have the correct structure (e.g. `config.json` is present). + * `config.json` must contain `auto_map.AutoModel`. +- be a Transformers backend for vLLM compatible model (see [Writing custom models](#writing-custom-models)): + * Customisation should be done in the base model (e.g. in `MyModel`, not `MyModelForCausalLM`). -If your model is neither supported natively by vLLM or Transformers, you can still run it in vLLM! +If the compatible model is: -Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers. -Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM! +- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for or `--trust-remote-code` for the . 
+- in a local directory, simply pass directory path to `model=` for or `vllm serve ` for the . 
-```python -from vllm import LLM -llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model -llm.apply_model(lambda model: print(model.__class__)) -``` +This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM! + +(writing-custom-models)= + +#### Writing custom models + +This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers backend for vLLM. (We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)). To make your model compatible with the Transformers backend, it needs: +1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`. +2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention. +3. `MyModel` must contain `_supports_attention_backend = True`. + ```{code-block} python :caption: modeling_my_model.py @@ -71,7 +79,7 @@ from torch import nn class MyAttention(nn.Module): - def forward(self, hidden_states, **kwargs): # <- kwargs are required + def forward(self, hidden_states, **kwargs): ... attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -87,11 +95,11 @@ class MyModel(PreTrainedModel): _supports_attention_backend = True ``` -Here is what happens in the background: +Here is what happens in the background when this model is loaded: -1. The config is loaded -2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`. -3. The `TransformersForCausalLM` backend is used. See , which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`. +1. The config is loaded. +2. 
`MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. +3. `MyModel` is loaded into `TransformersForCausalLM` (see ) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. That's it! @@ -129,7 +137,7 @@ class MyConfig(PretrainedConfig): ### Hugging Face Hub -By default, vLLM loads models from [Hugging Face (HF) Hub](https://huggingface.co/models). +By default, vLLM loads models from [Hugging Face (HF) Hub](https://huggingface.co/models). To change the download path for models, you can set the `HF_HOME` environment variable; for more details, refer to [their official documentation](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome). To determine whether a given model is natively supported, you can check the `config.json` file inside the HF repository. If the `"architectures"` field contains a model architecture listed below, then it should be natively supported. @@ -213,6 +221,16 @@ output = llm.encode("Hello, my name is") print(output) ``` +(feature-status-legend)= + +## Feature Status Legend + +- ✅︎ indicates that the feature is supported for the model. + +- 🚧 indicates that the feature is planned but not yet supported for the model. + +- ⚠️ indicates that the feature is available but may have known issues or limitations. + (supported-text-models)= ## List of Text-only Language Models @@ -314,7 +332,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `GemmaForCausalLM` * Gemma - * `google/gemma-2b`, `google/gemma-7b`, etc. + * `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. * ✅︎ * ✅︎ - * `Gemma2ForCausalLM` @@ -497,6 +515,11 @@ See [this page](#generative-models) for more information on how to use generativ * `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. * * ✅︎ +- * `Plamo2ForCausalLM` + * PLaMo2 + * `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. 
+ * + * - * `QWenLMHeadModel` * Qwen * `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. @@ -735,6 +758,11 @@ If your model is not in the above list, we will try to automatically convert the * `BAAI/bge-reranker-v2-m3`, etc. * * +- * `ModernBertForSequenceClassification` + * ModernBert-based + * `Alibaba-NLP/gte-reranker-modernbert-base`, etc. + * + * ::: (supported-mm-models)= @@ -765,6 +793,8 @@ or `--limit-mm-per-prompt` (online serving). For example, to enable passing up t Offline inference: ```python +from vllm import LLM + llm = LLM( model="Qwen/Qwen2-VL-7B-Instruct", limit_mm_per_prompt={"image": 4}, @@ -774,7 +804,7 @@ llm = LLM( Online serving: ```bash -vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4 +vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}' ``` **This is no longer required if you are using vLLM V1.** @@ -865,6 +895,13 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ * ✅︎ * ✅︎ +- * `GraniteSpeechForConditionalGeneration` + * Granite Speech + * T + A + * `ibm-granite/granite-speech-3.3-8b` + * ✅︎ + * ✅︎ + * ✅︎ - * `H2OVLChatModel` * H2OVL * T + IE+ @@ -886,6 +923,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `KimiVLForConditionalGeneration` + * Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking + * T + I+ + * `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` + * + * + * ✅︎ - * `Llama4ForConditionalGeneration` * Llama 4 * T + I+ @@ -990,7 +1034,7 @@ See [this page](#generative-models) for more information on how to use generativ * `microsoft/Phi-4-multimodal-instruct`, etc. 
* ✅︎ * - * + * ✅︎ - * `PixtralForConditionalGeneration` * Pixtral * T + I+ @@ -1026,6 +1070,13 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ * ✅︎ * ✅︎ +- * `Qwen2_5OmniThinkerForConditionalGeneration` + * Qwen2.5-Omni + * T + IE+ + VE+ + A+ + * `Qwen/Qwen2.5-Omni-7B` + * + * ✅︎ + * ✅︎\* - * `SkyworkR1VChatModel` * Skywork-R1V-38B * T + I @@ -1057,7 +1108,7 @@ See [this page](#generative-models) for more information on how to use generativ :::{important} Pan-and-scan image pre-processing is currently supported on V0 (but not V1). -You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`. +You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": true}'`. ::: :::{warning} @@ -1072,7 +1123,7 @@ V0 correctly implements the model's attention pattern: V1 currently uses a simplified attention pattern: - Uses causal attention for all tokens, including image tokens -- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": True}` +- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}` - Will be updated in the future to support the correct behavior This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. @@ -1086,6 +1137,36 @@ This limitation exists because the model's mixed attention pattern (bidirectiona To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. ::: +:::{warning} +The output quality of `AllenAI/Molmo-7B-D-0924` (especially in object localization tasks) has deteriorated in recent updates. 
+ +For the best results, we recommend using the following dependency versions (tested on A10 and L40): + +```text +# Core vLLM-compatible dependencies with Molmo accuracy setup (tested on L40) +torch==2.5.1 +torchvision==0.20.1 +transformers==4.48.1 +tokenizers==0.21.0 +tiktoken==0.7.0 +vllm==0.7.0 + +# Optional but recommended for improved performance and stability +triton==3.1.0 +xformers==0.0.28.post3 +uvloop==0.21.0 +protobuf==5.29.3 +openai==1.60.2 +opencv-python-headless==4.11.0.86 +pillow==10.4.0 + +# Installed FlashAttention (for float16 only) +flash-attn>=2.5.6 # Not used in float32, but should be documented +``` + +**Note:** Make sure you understand the security implications of using outdated packages. +::: + :::{note} The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. For more details, please see: @@ -1095,6 +1176,14 @@ For more details, please see: Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. ::: +:::{note} +To use Qwen2.5-Omni, you have to install the Hugging Face Transformers library from source via +`pip install git+https://github.com/huggingface/transformers.git`. + +Reading audio from video during pre-processing is currently supported on V0 (but not V1), because overlapping modalities are not yet supported in V1. +You can enable it on V0 by passing `--mm-processor-kwargs '{"use_audio_in_video": true}'`. +::: + ### Pooling Models See [this page](pooling-models) for more information on how to use pooling models. diff --git a/docs/source/serving/distributed_serving.md b/docs/source/serving/distributed_serving.md index 591acc2c9b753ca4498a450bbf9698f5eb3cc9b4..c285ef3e8e1c13e91f7e3bdabad5c45131462a4f 100644 --- a/docs/source/serving/distributed_serving.md +++ b/docs/source/serving/distributed_serving.md @@ -77,6 +77,10 @@ bash run_cluster.sh \ Then you get a ray cluster of **containers**. Note that you need to keep the shells running these commands alive to hold the cluster.
Any shell disconnect will terminate the cluster. In addition, please note that the argument `ip_of_head_node` should be the IP address of the head node, which is accessible by all the worker nodes. The IP addresses of each worker node should be specified in the `VLLM_HOST_IP` environment variable, and should be different for each worker node. Please check the network configuration of your cluster to make sure the nodes can communicate with each other through the specified IP addresses. +:::{warning} +It is considered best practice to set `VLLM_HOST_IP` to an address on a private network segment for the vLLM cluster. The traffic sent here is not encrypted. The endpoints are also exchanging data in a format that could be exploited to execute arbitrary code should a malicious party gain access to the network. Please ensure that this network is not reachable by any untrusted parties. +::: + :::{warning} Since this is a ray cluster of **containers**, all the following commands should be executed in the **containers**, otherwise you are executing the commands on the host machine, which is not connected to the ray cluster. To enter the container, you can use `docker exec -it node /bin/bash`. 
::: diff --git a/docs/source/serving/engine_args.md b/docs/source/serving/engine_args.md index e9943571a40a123c747404a778de31822b322593..97ea01cd3b2e66c35cc675a50478740935021b37 100644 --- a/docs/source/serving/engine_args.md +++ b/docs/source/serving/engine_args.md @@ -16,6 +16,7 @@ Below, you can find an explanation of every engine argument: :func: _engine_args_parser :prog: vllm serve :nodefaultconst: + :markdownhelp: ``` ## Async Engine Arguments @@ -29,4 +30,5 @@ Additional arguments are available to the asynchronous engine which is used for :func: _async_engine_args_parser :prog: vllm serve :nodefaultconst: + :markdownhelp: ``` diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index f45d36c3ccaca830eeab2dc0e1fe35c2d8732f95..d9a093e8d145d2facb9438212efc9a8dfbb806ae 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server: ```bash vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ - --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' ``` Then, you can use the OpenAI client as follows: diff --git a/docs/source/serving/offline_inference.md b/docs/source/serving/offline_inference.md index 85f2cafacdd3870e6ed2fd6b0b4f5dc7d7339e9d..894878ed14e764c814af965d6bdcacf06b1d8155 100644 --- a/docs/source/serving/offline_inference.md +++ b/docs/source/serving/offline_inference.md @@ -28,6 +28,8 @@ Please refer to the above pages for more details about each API. [API Reference](/api/offline_inference/index) ::: +(configuration-options)= + ## Configuration Options This section lists the most common options for running the vLLM engine. @@ -59,6 +61,8 @@ model = LLM( Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM. 
+(reducing-memory-usage)= + ### Reducing memory usage Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem. @@ -81,6 +85,12 @@ before initializing vLLM. Otherwise, you may run into an error like `RuntimeErro To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable. ::: +:::{note} +With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). + +You can convert the model checkpoint to a sharded checkpoint using . The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism. +::: + #### Quantization Quantized models take less memory at the cost of lower precision. @@ -103,6 +113,39 @@ llm = LLM(model="adept/fuyu-8b", max_num_seqs=2) ``` +#### Reduce CUDA Graphs + +By default, we optimize model inference using CUDA graphs which take up extra memory in the GPU. + +:::{important} +CUDA graph capture takes up more memory in V1 than in V0. 
+::: + +You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage: + +```python +from vllm import LLM +from vllm.config import CompilationConfig, CompilationLevel + +llm = LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + # By default, it goes up to max_num_seqs + cudagraph_capture_sizes=[1, 2, 4, 8, 16], + ), +) +``` + +You can disable graph capturing completely via the `enforce_eager` flag: + +```python +from vllm import LLM + +llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", + enforce_eager=True) +``` + #### Adjust cache size If you run out of CPU RAM, try the following options: @@ -110,16 +153,25 @@ If you run out of CPU RAM, try the following options: - (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). -#### Disable unused modalities +#### Multi-modal input limits -You can disable unused modalities (except for text) by setting its limit to zero. +You can allow a smaller number of multi-modal items per prompt to reduce the memory footprint of the model: + +```python +from vllm import LLM + +# Accept up to 3 images and 1 video per prompt +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={"image": 3, "video": 1}) +``` +You can go a step further and disable unused modalities completely by setting its limit to zero. For example, if your application only accepts image input, there is no need to allocate any memory for videos. 
```python from vllm import LLM -# Accept images but not videos +# Accept any number of images but no videos llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", limit_mm_per_prompt={"video": 0}) ``` @@ -134,6 +186,29 @@ llm = LLM(model="google/gemma-3-27b-it", limit_mm_per_prompt={"image": 0}) ``` +#### Multi-modal processor arguments + +For certain models, you can adjust the multi-modal processor arguments to +reduce the size of the processed multi-modal inputs, which in turn saves memory. + +Here are some examples: + +```python +from vllm import LLM + +# Available for Qwen2-VL series models +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_kwargs={ + "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 + }) + +# Available for InternVL series models +llm = LLM(model="OpenGVLab/InternVL2-2B", + mm_processor_kwargs={ + "max_dynamic_patch": 4, # Default is 12 + }) +``` + ### Performance optimization and tuning You can potentially improve the performance of vLLM by finetuning various options. diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 11ca571c684a109796fa1c68bb4625a63bc7e3a3..34382c87a484b8721a4de1a8d8fbc65f76e0e13b 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -33,11 +33,13 @@ print(completion.choices[0].message) vLLM supports some parameters that are not supported by OpenAI, `top_k` for example. You can pass these parameters to vLLM using the OpenAI client in the `extra_body` parameter of your requests, i.e. `extra_body={"top_k": 50}` for `top_k`. ::: + :::{important} By default, the server applies `generation_config.json` from the Hugging Face model repository if it exists. This means the default values of certain sampling parameters can be overridden by those recommended by the model creator. To disable this behavior, please pass `--generation-config vllm` when launching the server. 
::: + ## Supported APIs We currently support the following OpenAI APIs: @@ -172,6 +174,12 @@ print(completion._request_id) The `vllm serve` command is used to launch the OpenAI-compatible server. +:::{tip} +The vast majority of command-line arguments are based on those for offline inference. + +See [here](configuration-options) for some common options. +::: + :::{argparse} :module: vllm.entrypoints.openai.cli_args :func: create_parser_for_docs @@ -394,9 +402,26 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`. ::: +Code example: -Code example: +#### Extra Parameters + +The following [sampling parameters](#sampling-params) are supported. + +:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-transcription-sampling-params +:end-before: end-transcription-sampling-params +::: + +The following extra parameters are supported: + +:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-transcription-extra-params +:end-before: end-transcription-extra-params +::: (tokenizer-api)= diff --git a/examples/lmcache/README.md b/examples/lmcache/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7d0c23f529bb2173ba926706a68874499419b9b7 --- /dev/null +++ b/examples/lmcache/README.md @@ -0,0 +1,56 @@ +# LMCache Examples + +This folder demonstrates how to use LMCache for disaggregated prefilling, CPU offloading and KV cache sharing. + +## 1. Disaggregated Prefill in vLLM v1 + +This example demonstrates how to run LMCache with disaggregated prefill using NIXL on a single node. + +### Prerequisites + +- Install [LMCache](https://github.com/LMCache/LMCache). You can simply run `pip install lmcache`. +- Install [NIXL](https://github.com/ai-dynamo/nixl). 
+- At least 2 GPUs +- A valid Hugging Face token (HF_TOKEN) for Llama 3.1 8B Instruct. + +### Usage + +Run +`cd disagg_prefill_lmcache_v1` +to get into the `disagg_prefill_lmcache_v1` folder, and then run + +```bash +bash disagg_example_nixl.sh +``` + +to run disaggregated prefill and benchmark the performance. + +### Components + +#### Server Scripts +- `disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh` - Launches individual vLLM servers for prefill/decode, and also launches the proxy server. +- `disagg_prefill_lmcache_v1/disagg_proxy_server.py` - FastAPI proxy server that coordinates between prefiller and decoder +- `disagg_prefill_lmcache_v1/disagg_example_nixl.sh` - Main script to run the example + +#### Configuration +- `disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml` - Configuration for prefiller server +- `disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml` - Configuration for decoder server + +#### Log Files +The main script generates several log files: +- `prefiller.log` - Logs from the prefill server +- `decoder.log` - Logs from the decode server +- `proxy.log` - Logs from the proxy server + +## 2. CPU Offload Examples + +- `cpu_offload_lmcache_v0.py` - CPU offloading implementation for vLLM v0 +- `cpu_offload_lmcache_v1.py` - CPU offloading implementation for vLLM v1 + +## 3. KV Cache Sharing + +The `kv_cache_sharing_lmcache_v1.py` example demonstrates how to share KV caches between vLLM v1 instances. + +## 4. Disaggregated Prefill in vLLM v0 + +The `disagg_prefill_lmcache_v0.py` example demonstrates how to run disaggregated prefill in vLLM v0.
diff --git a/examples/lmcache/cpu_offload_lmcache_v0.py b/examples/lmcache/cpu_offload_lmcache_v0.py new file mode 100644 index 0000000000000000000000000000000000000000..37aea281032fd8b826c79182f3d28d72f49f87e1 --- /dev/null +++ b/examples/lmcache/cpu_offload_lmcache_v0.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of cpu offloading +with LMCache. + +Note that `lmcache` is needed to run this example. +Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1 +Learn more about LMCache environment setup, please refer to: +https://docs.lmcache.ai/getting_started/installation.html +""" +import contextlib +import os +import time + +from lmcache.experimental.cache_engine import LMCacheEngineBuilder +from lmcache.integration.vllm.utils import ENGINE_NAME + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def setup_environment_variables(): + # LMCache-related environment variables + # Use experimental features in LMCache + os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" + # LMCache is set to use 256 tokens per chunk + os.environ["LMCACHE_CHUNK_SIZE"] = "256" + # Enable local CPU backend in LMCache + os.environ["LMCACHE_LOCAL_CPU"] = "True" + # Set local CPU memory limit to 5.0 GB + os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" + + +@contextlib.contextmanager +def build_llm_with_lmcache(): + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}') + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. + # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392). 
+ llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + enable_chunked_prefill=True, + gpu_memory_utilization=0.8) + + try: + yield llm + finally: + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def print_output( + llm: LLM, + prompt: list[str], + sampling_params: SamplingParams, + req_str: str, +): + start = time.time() + outputs = llm.generate(prompt, sampling_params) + print("-" * 50) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print(f"Generation took {time.time() - start:.2f} seconds, " + f"{req_str} request done.") + print("-" * 50) + + +def main(): + setup_environment_variables() + + with build_llm_with_lmcache() as llm: + + # This example script runs two requests with a shared prefix. + # Define the shared prompt and specific prompts + shared_prompt = "Hello, how are you?" * 1000 + first_prompt = [ + shared_prompt + "Hello, my name is", + ] + second_prompt = [ + shared_prompt + "Tell me a very long story", + ] + + sampling_params = SamplingParams(temperature=0, + top_p=0.95, + max_tokens=10) + + # Print the first output + print_output(llm, first_prompt, sampling_params, "first") + + time.sleep(1) + + # print the second output + print_output(llm, second_prompt, sampling_params, "second") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/cpu_offload_lmcache.py b/examples/lmcache/cpu_offload_lmcache_v1.py similarity index 76% rename from examples/offline_inference/cpu_offload_lmcache.py rename to examples/lmcache/cpu_offload_lmcache_v1.py index 8211629b24ecce19517649b68537a008293d3813..f44075a36965fc12e677bad2d69a4d0797378b05 100644 --- a/examples/offline_inference/cpu_offload_lmcache.py +++ b/examples/lmcache/cpu_offload_lmcache_v1.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 """ This file demonstrates the example usage of cpu offloading -with LMCache. 
+with LMCache in vLLM v1. -Note that `pip install lmcache` is needed to run this example. +Note that lmcache needs to be installed to run this example. Learn more about LMCache in https://github.com/LMCache/LMCache. """ import os -import time from lmcache.experimental.cache_engine import LMCacheEngineBuilder from lmcache.integration.vllm.utils import ENGINE_NAME @@ -37,29 +36,22 @@ second_prompt = [ sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}') + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. # Note that LMCache is not compatible with chunked prefill for now. -llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", +llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", kv_transfer_config=ktc, max_model_len=8000, - enable_chunked_prefill=False, gpu_memory_utilization=0.8) +# Should be able to see logs like the following: +# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0` +# This indicates that the KV cache has been stored in LMCache. 
outputs = llm.generate(first_prompt, sampling_params) for output in outputs: generated_text = output.outputs[0].text print(f"Generated text: {generated_text!r}") -print("First request done.") - -time.sleep(1) - -outputs = llm.generate(second_prompt, sampling_params) -for output in outputs: - generated_text = output.outputs[0].text - print(f"Generated text: {generated_text!r}") -print("Second request done.") # Clean up lmcache backend LMCacheEngineBuilder.destroy(ENGINE_NAME) diff --git a/examples/offline_inference/disaggregated_prefill_lmcache.py b/examples/lmcache/disagg_prefill_lmcache_v0.py similarity index 98% rename from examples/offline_inference/disaggregated_prefill_lmcache.py rename to examples/lmcache/disagg_prefill_lmcache_v0.py index 5c84bbfc92c53c68b87fbbcda6990e8e0c6101b9..7da6fb7aaa230fbcc56983a9da4dfad818737f92 100644 --- a/examples/offline_inference/disaggregated_prefill_lmcache.py +++ b/examples/lmcache/disagg_prefill_lmcache_v0.py @@ -38,6 +38,10 @@ os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}" # `naive` indicates using raw bytes of the tensor without any compression os.environ["LMCACHE_REMOTE_SERDE"] = "naive" +prompts = [ + "Hello, how are you?" * 1000, +] + def run_prefill(prefill_done, prompts): # We use GPU 0 for prefill node. @@ -106,12 +110,7 @@ def run_lmcache_server(port): return server_proc -if __name__ == "__main__": - - prompts = [ - "Hello, how are you?" 
* 1000, - ] - +def main(): prefill_done = Event() prefill_process = Process(target=run_prefill, args=(prefill_done, prompts)) decode_process = Process(target=run_decode, args=(prefill_done, prompts)) @@ -128,3 +127,7 @@ if __name__ == "__main__": prefill_process.terminate() lmcache_server_process.terminate() lmcache_server_process.wait() + + +if __name__ == "__main__": + main() diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3f5a0ae69c061c132ce27c48446bc9314a30473 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "receiver" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "cuda" +nixl_enable_gc: True diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8b0e82958a64c2aeb111178331b485e866b775e2 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "sender" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "cuda" +nixl_enable_gc: True diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh new file mode 100644 index 
0000000000000000000000000000000000000000..df8a41293504908016d3710ff5fa713805c55fee --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh @@ -0,0 +1,136 @@ +#!/bin/bash + +echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change." + + +PIDS=() + +# Switch to the directory of the current script +cd "$(dirname "${BASH_SOURCE[0]}")" + +check_hf_token() { + if [ -z "$HF_TOKEN" ]; then + echo "HF_TOKEN is not set. Please set it to your Hugging Face token." + exit 1 + fi + if [[ "$HF_TOKEN" != hf_* ]]; then + echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token." + exit 1 + fi + echo "HF_TOKEN is set and valid." +} + +check_num_gpus() { + # Use nvidia-smi to check that at least 2 GPUs are available. + num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + if [ "$num_gpus" -lt 2 ]; then + echo "You need at least 2 GPUs to run disaggregated prefill." + exit 1 + else + echo "Found $num_gpus GPUs." + fi +} + +ensure_python_library_installed() { + echo "Checking if $1 is installed..." + python -c "import $1" > /dev/null 2>&1 + if [ $? -ne 0 ]; then + if [ "$1" == "nixl" ]; then + echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation." + else + echo "$1 is not installed. Please install it via pip install $1." + fi + exit 1 + else + echo "$1 is installed." + fi +} + +cleanup() { + echo "Stopping everything…" + trap - INT TERM # prevent re-entrancy + kill -- -$$ # negative PID == “this whole process-group” + wait # reap children so we don't leave zombies + exit 0 +} + +wait_for_server() { + local port=$1 + local timeout_seconds=1200 + local start_time=$(date +%s) + + echo "Waiting for server on port $port..."
+ + while true; do + if curl -s "localhost:${port}/v1/completions" > /dev/null; then + return 0 + fi + + local now=$(date +%s) + if (( now - start_time >= timeout_seconds )); then + echo "Timeout waiting for server" + return 1 + fi + + sleep 1 + done +} + + +main() { + check_hf_token + check_num_gpus + ensure_python_library_installed lmcache + ensure_python_library_installed nixl + ensure_python_library_installed pandas + ensure_python_library_installed datasets + ensure_python_library_installed vllm + + trap cleanup INT + trap cleanup USR1 + trap cleanup TERM + + echo "Launching prefiller, decoder and proxy..." + echo "Please check prefiller.log, decoder.log and proxy.log for logs." + + bash disagg_vllm_launcher.sh prefiller \ + > >(tee prefiller.log) 2>&1 & + prefiller_pid=$! + PIDS+=($prefiller_pid) + + bash disagg_vllm_launcher.sh decoder \ + > >(tee decoder.log) 2>&1 & + decoder_pid=$! + PIDS+=($decoder_pid) + + python3 disagg_proxy_server.py \ + --host localhost \ + --port 9000 \ + --prefiller-host localhost \ + --prefiller-port 8100 \ + --decoder-host localhost \ + --decoder-port 8200 \ + > >(tee proxy.log) 2>&1 & + proxy_pid=$! + PIDS+=($proxy_pid) + + wait_for_server 8100 + wait_for_server 8200 + wait_for_server 9000 + + echo "All servers are up. Starting benchmark..." + + # begin benchmark + cd ../../../benchmarks/ + python benchmark_serving.py --port 9000 --seed $(date +%s) \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --dataset-name random --random-input-len 7500 --random-output-len 200 \ + --num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log + + echo "Benchmarking done. Cleaning up..." 
+ + cleanup + +} + +main \ No newline at end of file diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py new file mode 100644 index 0000000000000000000000000000000000000000..8db93bc8931b2ef7cb87788c6de2748b4b5bc0bf --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import time +from contextlib import asynccontextmanager + +import httpx +import numpy as np +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. + """ + # Startup: Initialize clients + prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' + decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' + + app.state.prefill_client = httpx.AsyncClient(timeout=None, + base_url=prefiller_base_url) + app.state.decode_client = httpx.AsyncClient(timeout=None, + base_url=decoder_base_url) + + yield + + # Shutdown: Close clients + await app.state.prefill_client.aclose() + await app.state.decode_client.aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +class StatsCalculator: + + def __init__(self): + self._stats = [] + self._last_log_time = time.time() + + def add(self, value): + self._stats.append(value) + if time.time() - self._last_log_time > 5: + self._log_stats() + self._last_log_time = time.time() + + def _log_stats(self): + # Print average, median, and 99th percentile + np_arr = np.array(self._stats) + output_str = f"\nNum requests: {len(self._stats)}" + \ + "\nPrefill node TTFT stats:" + \ + f"\n - Average (ms): {np.mean(np_arr)}" + \ + f"\n - Median (ms): {np.median(np_arr)}" + \ + f"\n - 99th Percentile (ms): 
{np.percentile(np_arr, 99)}\n" + print("===============================", output_str, + "===============================") + + +stats_calculator = StatsCalculator() +counter = 0 + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-host", type=str, default="localhost") + parser.add_argument("--prefiller-port", type=int, default=8100) + parser.add_argument("--decoder-host", type=str, default="localhost") + parser.add_argument("--decoder-port", type=int, default=8200) + args = parser.parse_args() + return args + + +# Initialize variables to hold the persistent clients +app.state.prefill_client = None +app.state.decode_client = None + + +async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Send a request to a service using a persistent client. + """ + req_data = req_data.copy() + req_data['max_tokens'] = 1 + if 'max_completion_tokens' in req_data: + req_data['max_completion_tokens'] = 1 + + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + response = await client.post(endpoint, json=req_data, headers=headers) + response.raise_for_status() + return response + + +async def stream_service_response(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Asynchronously stream the response from a service using a persistent client. 
+ """ + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, "/completions", + req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, + "/chat/completions", req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/chat/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server " + 
" - chat completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh new file mode 100644 index 0000000000000000000000000000000000000000..831ef0bb574bf1ffdce6803db2336ac27ccbd051 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [model]" + exit 1 +fi + +if [[ $# -eq 1 ]]; then + echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" + MODEL="meta-llama/Llama-3.1-8B-Instruct" +else + echo "Using model: $2" + MODEL=$2 +fi + + +if [[ $1 == "prefiller" ]]; then + # Prefiller listens on port 8100 + prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$prefill_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=0 \ + vllm serve $MODEL \ + --port 8100 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' + + +elif [[ $1 == "decoder" ]]; then + # Decoder listens on port 8200 + decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$decode_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=1 \ + vllm serve $MODEL \ + --port 8200 \ + 
--disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' + + +else + echo "Invalid role: $1" + echo "Should be either prefill, decode" + exit 1 +fi diff --git a/examples/lmcache/kv_cache_sharing_lmcache_v1.py b/examples/lmcache/kv_cache_sharing_lmcache_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..af1b4351dd54c8cd91f2cda5d0724d199ae63b2e --- /dev/null +++ b/examples/lmcache/kv_cache_sharing_lmcache_v1.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of remote KV cache sharing +with LMCache. +We will launch 2 vllm instances, and launch an additional LMCache server. +KV cache is transferred in the following manner: +(1) vLLM instance 1 -> LMCache server (KV cache store). +(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve). + +Note that lmcache needs to be installed to run this example. +Learn more about LMCache in https://github.com/LMCache/LMCache. 
+""" +import os +import subprocess +import time +from multiprocessing import Event, Process + +from lmcache.experimental.cache_engine import LMCacheEngineBuilder +from lmcache.integration.vllm.utils import ENGINE_NAME + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# LMCache-related environment variables +# The port to start LMCache server +port = 8100 +# Use experimental features in LMCache +os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" +# LMCache is set to use 256 tokens per chunk +os.environ["LMCACHE_CHUNK_SIZE"] = "256" +# Disable local CPU backend in LMCache +os.environ["LMCACHE_LOCAL_CPU"] = "False" +# Set local CPU memory buffer limit to 5.0 GB +os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" +# Set the remote URL for LMCache server +os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}" +# Set the serializer/deserializer between vllm and LMCache server +# `naive` indicates using raw bytes of the tensor without any compression +os.environ["LMCACHE_REMOTE_SERDE"] = "naive" + +prompts = [ + "Hello, how are you?" * 1000, +] + + +def run_store(store_done, prompts): + # We use GPU 0 for KV cache store process. + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. 
+ llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print("KV cache store is finished.") + store_done.set() + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_retrieve(store_done, prompts, timeout=1): + # We use GPU 1 for KV cache retrieve process. + os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # of memory. Reduce the value if your GPU has less memory. + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + print("Waiting for KV cache store to finish...") + store_done.wait() + time.sleep(timeout) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_lmcache_server(port): + server_proc = subprocess.Popen([ + "python", "-m", "lmcache.experimental.server", "localhost", + str(port) + ]) + return server_proc + + +def main(): + store_done = Event() + store_process = Process(target=run_store, args=(store_done, prompts)) + retrieve_process = Process(target=run_retrieve, args=(store_done, prompts)) + lmcache_server_process = run_lmcache_server(port) + + # Start KV cache store process + store_process.start() + + # Start KV cache retrieve process + retrieve_process.start() + + # Clean up the processes + store_process.join() + 
+ retrieve_process.terminate() + lmcache_server_process.terminate() + lmcache_server_process.wait() + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 248090474de66bfebfa1c383fa0d81d3eb67cf5c..bab41c915c32d9ff9528accde96a635348462ebf 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -38,6 +38,37 @@ class ModelRequestData(NamedTuple): # Unless specified, these settings have been tested to work on a single L4. +# Granite Speech +def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: + # NOTE - the settings in this example are somewhat different from what is + # optimal for granite speech, and it is generally recommended to use beam + # search. Check the model README for suggested settings. + # https://huggingface.co/ibm-granite/granite-speech-3.3-8b + model_name = "ibm-granite/granite-speech-3.3-8b" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=2048, + max_num_seqs=2, + enable_lora=True, + max_lora_rank=64, + limit_mm_per_prompt={"audio": audio_count}, + ) + + # The model has an audio-specific lora directly in its model dir; + # it should be enabled whenever you pass audio inputs to the model. + speech_lora_path = model_name + audio_placeholder = "<|audio|>" * audio_count + prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM.
You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501 + + return ModelRequestData( + engine_args=engine_args, + prompt=prompts, + lora_requests=[LoRARequest("speech", 1, speech_lora_path)], + ) + + # MiniCPM-O def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model_name = "openbmb/MiniCPM-o-2_6" @@ -89,7 +120,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData: engine_args = EngineArgs( model=model_path, trust_remote_code=True, - max_model_len=4096, + max_model_len=12800, max_num_seqs=2, enable_lora=True, max_lora_rank=320, @@ -130,6 +161,36 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData: ) +# Qwen2.5-Omni +def run_qwen2_5_omni(question: str, audio_count: int): + model_name = "Qwen/Qwen2.5-Omni-7B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}, + ) + + audio_in_prompt = "".join([ + "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) + ]) + + default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech.") + + prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n") + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) + + # Ultravox 0.5-1B def run_ultravox(question: str, audio_count: int) -> ModelRequestData: model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" @@ -179,14 +240,43 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: model_example_map = { + "granite_speech": run_granite_speech, "minicpmo": run_minicpmo, "phi4_mm": run_phi4mm, "qwen2_audio": run_qwen2_audio, + "qwen2_5_omni": 
run_qwen2_5_omni, "ultravox": run_ultravox, "whisper": run_whisper, } +def parse_args(): + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'audio language models') + parser.add_argument('--model-type', + '-m', + type=str, + default="ultravox", + choices=model_example_map.keys(), + help='Huggingface "model_type".') + parser.add_argument('--num-prompts', + type=int, + default=1, + help='Number of prompts to run.') + parser.add_argument("--num-audios", + type=int, + default=1, + choices=[0, 1, 2], + help="Number of audio items per prompt.") + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") + + return parser.parse_args() + + def main(args): model = args.model_type if model not in model_example_map: @@ -240,28 +330,5 @@ def main(args): if __name__ == "__main__": - parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'audio language models') - parser.add_argument('--model-type', - '-m', - type=str, - default="ultravox", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument('--num-prompts', - type=int, - default=1, - help='Number of prompts to run.') - parser.add_argument("--num-audios", - type=int, - default=1, - choices=[0, 1, 2], - help="Number of audio items per prompt.") - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") - - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 28ef8252552c0ccbade68a314cc5d7b151ff8d12..0a67374271fb1deca7667b4a989adf4cd1b6ac21 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -2,20 +2,22 @@ from vllm import LLM, SamplingParams -if __name__ == '__main__': - # Sample prompts. 
- prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - # Create a sampling params object. - sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16) +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16) + +def main(): # Create an LLM. llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True) - # Generate texts from the prompts. The output is a list of RequestOutput objects + # Generate texts from the prompts. + # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -26,3 +28,8 @@ if __name__ == '__main__': print(f"Prompt: {prompt!r}") print(f"Output: {generated_text!r}") print("-" * 60) + + +if __name__ == "__main__": + main() + diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py index 2dea45f843cf3daa9f24c1b8e319f8bbf3ba6e71..6857c6e9e31dfaac9f0fe472fc78d34f689d7be1 100644 --- a/examples/offline_inference/basic/chat.py +++ b/examples/offline_inference/basic/chat.py @@ -4,6 +4,24 @@ from vllm import LLM, EngineArgs from vllm.utils import FlexibleArgumentParser +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args + engine_group = parser.add_argument_group("Engine arguments") + EngineArgs.add_cli_args(engine_group) + engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--temperature", type=float) + 
sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + # Add example params + parser.add_argument("--chat-template-path", type=str) + + return parser + + def main(args: dict): # Pop arguments not used by LLM max_tokens = args.pop("max_tokens") @@ -82,18 +100,6 @@ def main(args: dict): if __name__ == "__main__": - parser = FlexibleArgumentParser() - # Add engine args - engine_group = parser.add_argument_group("Engine arguments") - EngineArgs.add_cli_args(engine_group) - engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") - # Add sampling params - sampling_group = parser.add_argument_group("Sampling parameters") - sampling_group.add_argument("--max-tokens", type=int) - sampling_group.add_argument("--temperature", type=float) - sampling_group.add_argument("--top-p", type=float) - sampling_group.add_argument("--top-k", type=int) - # Add example params - parser.add_argument("--chat-template-path", type=str) + parser = create_parser() args: dict = vars(parser.parse_args()) main(args) diff --git a/examples/offline_inference/basic/classify.py b/examples/offline_inference/basic/classify.py index 72c29e4c77c3093bcb45a8f74f8b3baadd4f54e8..5b6dcb41eee1c2804fc274de58d466e1ae07c8f8 100644 --- a/examples/offline_inference/basic/classify.py +++ b/examples/offline_inference/basic/classify.py @@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs from vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach", + task="classify", + enforce_eager=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. 
prompts = [ @@ -34,11 +44,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach", - task="classify", - enforce_eager=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py index 0283909a2a84a14fe8057015afc26384e717cf93..cb5f923ffb697e3c92589384ef39d9086d432c40 100644 --- a/examples/offline_inference/basic/embed.py +++ b/examples/offline_inference/basic/embed.py @@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs from vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="intfloat/e5-mistral-7b-instruct", + task="embed", + enforce_eager=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. 
prompts = [ @@ -34,11 +44,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="intfloat/e5-mistral-7b-instruct", - task="embed", - enforce_eager=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/basic/generate.py b/examples/offline_inference/basic/generate.py index 93f4f2a36fac6ffcf48c51d87e91be88e25014fc..54b52b22a45a977cccfe19fc8435c963977e481b 100644 --- a/examples/offline_inference/basic/generate.py +++ b/examples/offline_inference/basic/generate.py @@ -4,6 +4,22 @@ from vllm import LLM, EngineArgs from vllm.utils import FlexibleArgumentParser +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args + engine_group = parser.add_argument_group("Engine arguments") + EngineArgs.add_cli_args(engine_group) + engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--temperature", type=float) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + + return parser + + def main(args: dict): # Pop arguments not used by LLM max_tokens = args.pop("max_tokens") @@ -35,23 +51,15 @@ def main(args: dict): ] outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
+ print("-" * 50) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) if __name__ == "__main__": - parser = FlexibleArgumentParser() - # Add engine args - engine_group = parser.add_argument_group("Engine arguments") - EngineArgs.add_cli_args(engine_group) - engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") - # Add sampling params - sampling_group = parser.add_argument_group("Sampling parameters") - sampling_group.add_argument("--max-tokens", type=int) - sampling_group.add_argument("--temperature", type=float) - sampling_group.add_argument("--top-p", type=float) - sampling_group.add_argument("--top-k", type=int) + parser = create_parser() args: dict = vars(parser.parse_args()) main(args) diff --git a/examples/offline_inference/basic/score.py b/examples/offline_inference/basic/score.py index 83b8253f4e257cc76e58860bbad3662d01cfad8c..d2bda8b3180c35ad979f65e60e03219e276e1956 100644 --- a/examples/offline_inference/basic/score.py +++ b/examples/offline_inference/basic/score.py @@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs from vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="BAAI/bge-reranker-v2-m3", + task="score", + enforce_eager=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. text_1 = "What is the capital of France?" 
@@ -30,11 +40,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="BAAI/bge-reranker-v2-m3", - task="score", - enforce_eager=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/batch_llm_inference.py b/examples/offline_inference/batch_llm_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6548857b6d11122729fd39c9662b5f7f16e8505a --- /dev/null +++ b/examples/offline_inference/batch_llm_inference.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to use Ray Data for data parallel batch inference. + +Ray Data is a data processing framework that can handle large datasets +and integrates tightly with vLLM for data-parallel inference. + +As of Ray 2.44, Ray Data has a native integration with +vLLM (under ray.data.llm). + +Ray Data provides functionality for: +* Reading and writing to cloud storage (S3, GCS, etc.) +* Automatic sharding and load-balancing across a cluster +* Optimized configuration of vLLM using continuous batching +* Compatible with tensor/pipeline parallel inference as well. + +Learn more about Ray Data's LLM integration: +https://docs.ray.io/en/latest/data/working-with-llms.html +""" +import ray +from packaging.version import Version +from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig + +assert Version(ray.__version__) >= Version( + "2.44.1"), "Ray version must be at least 2.44.1" + +# Uncomment to reduce clutter in stdout +# ray.init(log_to_driver=False) +# ray.data.DataContext.get_current().enable_progress_bars = False + +# Read one text file from S3. Ray Data supports reading multiple files +# from cloud storage (such as JSONL, Parquet, CSV, binary format). 
+ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt") +print(ds.schema()) + +size = ds.count() +print(f"Size of dataset: {size} prompts") + +# Configure vLLM engine. +config = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs={ + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4096, + "max_model_len": 16384, + }, + concurrency=1, # set the number of parallel vLLM replicas + batch_size=64, +) + +# Create a Processor object, which will be used to +# do batch inference on the dataset +vllm_processor = build_llm_processor( + config, + preprocess=lambda row: dict( + messages=[{ + "role": "system", + "content": "You are a bot that responds with haikus." + }, { + "role": "user", + "content": row["text"] + }], + sampling_params=dict( + temperature=0.3, + max_tokens=250, + )), + postprocess=lambda row: dict( + answer=row["generated_text"], + **row # This will return all the original columns in the dataset. + ), +) + +ds = vllm_processor(ds) + +# Peek first 10 results. +# NOTE: This is for local testing and debugging. For production use case, +# one should write full result out as shown below. +outputs = ds.take(limit=10) + +for output in outputs: + prompt = output["prompt"] + generated_text = output["generated_text"] + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}") + +# Write inference output data out as Parquet files to S3. +# Multiple files would be written to the output destination, +# and each task would write one or more files separately. 
+# +# ds.write_parquet("s3://") diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 04a79e2f8ae66690a64bf684b3a757abe5f78a82..965915beaf58f073698eab081d4734a1501cf4cd 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -34,6 +34,40 @@ from vllm import LLM, SamplingParams from vllm.utils import get_open_port +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="Data Parallel Inference") + parser.add_argument("--model", + type=str, + default="ibm-research/PowerMoE-3b", + help="Model name or path") + parser.add_argument("--dp-size", + type=int, + default=2, + help="Data parallel size") + parser.add_argument("--tp-size", + type=int, + default=2, + help="Tensor parallel size") + parser.add_argument("--node-size", + type=int, + default=1, + help="Total number of nodes") + parser.add_argument("--node-rank", + type=int, + default=0, + help="Rank of the current node") + parser.add_argument("--master-addr", + type=str, + default="", + help="Master node IP address") + parser.add_argument("--master-port", + type=int, + default=0, + help="Master node port") + return parser.parse_args() + + def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) @@ -95,37 +129,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser(description="Data Parallel Inference") - parser.add_argument("--model", - type=str, - default="ibm-research/PowerMoE-3b", - help="Model name or path") - parser.add_argument("--dp-size", - type=int, - default=2, - help="Data parallel size") - parser.add_argument("--tp-size", - type=int, - default=2, - help="Tensor parallel size") - parser.add_argument("--node-size", - type=int, - default=1, - help="Total number of nodes") - 
parser.add_argument("--node-rank", - type=int, - default=0, - help="Rank of the current node") - parser.add_argument("--master-addr", - type=str, - default="", - help="Master node IP address") - parser.add_argument("--master-port", - type=int, - default=0, - help="Master node port") - args = parser.parse_args() + + args = parse_args() dp_size = args.dp_size tp_size = args.tp_size diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py new file mode 100644 index 0000000000000000000000000000000000000000..66efbc0c9deecf083047c27aa0fd6db16f95ad7f --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# Read prompts from output.txt +prompts = [] +try: + with open("output.txt") as f: + for line in f: + prompts.append(line.strip()) + print(f"Loaded {len(prompts)} prompts from output.txt") +except FileNotFoundError: + print("Error: output.txt file not found") + exit(-1) + +sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + +llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + max_num_batched_tokens=64, + max_num_seqs=16, + kv_transfer_config=KVTransferConfig.from_cli( + '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",' + '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}' + )) #, max_model_len=2048, max_num_batched_tokens=2048) + +# 2ND generation (decode instance) +outputs = llm.generate(prompts, sampling_params) + +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py
b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py new file mode 100644 index 0000000000000000000000000000000000000000..f7cbf6557d54f8ea0c25394fc86bb649a0ff4b33 --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +context = "Hi " * 1000 +context2 = "Hey " * 500 +prompts = [ + context + "Hello, my name is", + context + "The capital of France is", + context2 + "Your name is", + context2 + "The capital of China is", +] + +sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + +llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig.from_cli( + '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", ' + '"kv_connector_extra_config": ' + '{"shared_storage_path": "local_storage"}}') + ) #, max_model_len=2048, max_num_batched_tokens=2048) + +# 1ST generation (prefill instance) +outputs = llm.generate( + prompts, + sampling_params, +) + +new_prompts = [] +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +# Write new_prompts to output.txt +with open("output.txt", "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") +print(f"Saved {len(new_prompts)} prompts to output.txt") diff --git a/examples/offline_inference/disaggregated-prefill-v1/run.sh b/examples/offline_inference/disaggregated-prefill-v1/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..0ebf45a1586a012d9075219cb37318d68fce00bc --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/run.sh @@ -0,0 +1,5 @@ +rm -rf local_storage/ +rm output.txt + +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 
python3 prefill_example.py +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index 36ee24bf7f18b324973de6dfcfffa35b96b6ea1d..d60985146c5c9172483b422597dc21183c5a6488 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -95,7 +95,7 @@ def run_decode(prefill_done): print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -if __name__ == "__main__": +def main(): prefill_done = Event() prefill_process = Process(target=run_prefill, args=(prefill_done, )) decode_process = Process(target=run_decode, args=(prefill_done, )) @@ -109,3 +109,7 @@ if __name__ == "__main__": # Terminate the prefill node when decode is finished decode_process.join() prefill_process.terminate() + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/distributed.py b/examples/offline_inference/distributed.py deleted file mode 100644 index e890c6dad8bd1c5f92a4eb7dc41bf897dcc12edb..0000000000000000000000000000000000000000 --- a/examples/offline_inference/distributed.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -This example shows how to use Ray Data for running offline batch inference -distributively on a multi-nodes cluster. - -Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html -""" - -from typing import Any - -import numpy as np -import ray -from packaging.version import Version -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - -from vllm import LLM, SamplingParams - -assert Version(ray.__version__) >= Version( - "2.22.0"), "Ray version must be at least 2.22.0" - -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Set tensor parallelism per instance. -tensor_parallel_size = 1 - -# Set number of instances. 
Each instance will use tensor_parallel_size GPUs. -num_instances = 1 - - -# Create a class to do batch inference. -class LLMPredictor: - - def __init__(self): - # Create an LLM. - self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", - tensor_parallel_size=tensor_parallel_size) - - def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, list]: - # Generate texts from the prompts. - # The output is a list of RequestOutput objects that contain the prompt, - # generated text, and other information. - outputs = self.llm.generate(batch["text"], sampling_params) - prompt: list[str] = [] - generated_text: list[str] = [] - for output in outputs: - prompt.append(output.prompt) - generated_text.append(' '.join([o.text for o in output.outputs])) - return { - "prompt": prompt, - "generated_text": generated_text, - } - - -# Read one text file from S3. Ray Data supports reading multiple files -# from cloud storage (such as JSONL, Parquet, CSV, binary format). -ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt") - - -# For tensor_parallel_size > 1, we need to create placement groups for vLLM -# to use. Every actor has to have its own placement group. -def scheduling_strategy_fn(): - # One bundle per tensor parallel worker - pg = ray.util.placement_group( - [{ - "GPU": 1, - "CPU": 1 - }] * tensor_parallel_size, - strategy="STRICT_PACK", - ) - return dict(scheduling_strategy=PlacementGroupSchedulingStrategy( - pg, placement_group_capture_child_tasks=True)) - - -resources_kwarg: dict[str, Any] = {} -if tensor_parallel_size == 1: - # For tensor_parallel_size == 1, we simply set num_gpus=1. - resources_kwarg["num_gpus"] = 1 -else: - # Otherwise, we have to set num_gpus=0 and provide - # a function that will create a placement group for - # each instance. - resources_kwarg["num_gpus"] = 0 - resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn - -# Apply batch inference for all input data. 
-ds = ds.map_batches( - LLMPredictor, - # Set the concurrency to the number of LLM instances. - concurrency=num_instances, - # Specify the batch size for inference. - batch_size=32, - **resources_kwarg, -) - -# Peek first 10 results. -# NOTE: This is for local testing and debugging. For production use case, -# one should write full result out as shown below. -outputs = ds.take(limit=10) -for output in outputs: - prompt = output["prompt"] - generated_text = output["generated_text"] - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - -# Write inference output data out as Parquet files to S3. -# Multiple files would be written to the output destination, -# and each task would write one or more files separately. -# -# ds.write_parquet("s3://") diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 453ae7b6f56faefdd869ee9de973a59a6e2b5e64..474b745a610607da3c0626c3c931ffe1a666babe 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -27,7 +27,7 @@ def load_prompts(dataset_path, num_prompts): return prompts[:num_prompts] -def main(): +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "--dataset", @@ -45,10 +45,15 @@ def main(): parser.add_argument("--enable_chunked_prefill", action='store_true') parser.add_argument("--max_num_batched_tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) - args = parser.parse_args() + return parser.parse_args() + + +def main(): - model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" - eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" + args = parse_args() + + model_dir = "meta-llama/Llama-3.1-8B-Instruct" + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" max_model_len = 2048 @@ -76,7 +81,7 @@ def main(): max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config={ - "method": "eagle", + "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle", "model": 
eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "draft_tensor_parallel_size": args.draft_tp, @@ -90,6 +95,9 @@ def main(): outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + if not hasattr(outputs, "metrics") or outputs.metrics is None: + return + # calculate the average number of accepted tokens per forward pass, +1 is # to account for the token from the target model that's always going to be # accepted @@ -104,6 +112,11 @@ def main(): {sum(acceptance_counts) / acceptance_counts[0]:.2f}") print("-" * 50) + # print acceptance at each token position + for i in range(len(acceptance_counts)): + print(f"acceptance at token {i}:" + f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}") + if __name__ == "__main__": main() diff --git a/examples/offline_inference/embed_jina_embeddings_v3.py b/examples/offline_inference/embed_jina_embeddings_v3.py index f7d9e47e7953e3e2244bbd27fcc3d102d37b4e25..b347ddbf3197a37f018554a8301e6a3900f94524 100644 --- a/examples/offline_inference/embed_jina_embeddings_v3.py +++ b/examples/offline_inference/embed_jina_embeddings_v3.py @@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs from vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="jinaai/jina-embeddings-v3", + task="embed", + trust_remote_code=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. 
prompts = [ @@ -40,11 +50,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="jinaai/jina-embeddings-v3", - task="embed", - trust_remote_code=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/embed_matryoshka_fy.py b/examples/offline_inference/embed_matryoshka_fy.py index ab71fbe73e6aac9da344fff1da455e6a3c75e7ba..7a6cb02556d9aff2486c406985b5b61e7efb9064 100644 --- a/examples/offline_inference/embed_matryoshka_fy.py +++ b/examples/offline_inference/embed_matryoshka_fy.py @@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs, PoolingParams from vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="jinaai/jina-embeddings-v3", + task="embed", + trust_remote_code=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. 
prompts = [ @@ -38,11 +48,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="jinaai/jina-embeddings-v3", - task="embed", - trust_remote_code=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py index c6ccfd42ec85b9d2cf712b1663de4bf7797003b1..c4916e00f473c7a778d95a09e326605cfa7a6b4c 100644 --- a/examples/offline_inference/encoder_decoder.py +++ b/examples/offline_inference/encoder_decoder.py @@ -8,94 +8,112 @@ from vllm import LLM, SamplingParams from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt, zip_enc_dec_prompts) -dtype = "float" - -# Create a BART encoder/decoder model instance -llm = LLM( - model="facebook/bart-large-cnn", - dtype=dtype, -) - -# Get BART tokenizer -tokenizer = llm.llm_engine.get_tokenizer_group() - -# Test prompts -# -# This section shows all of the valid ways to prompt an -# encoder/decoder model. -# -# - Helpers for building prompts -text_prompt_raw = "Hello, my name is" -text_prompt = TextPrompt(prompt="The president of the United States is") -tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( - prompt="The capital of France is")) -# - Pass a single prompt to encoder/decoder model -# (implicitly encoder input prompt); -# decoder input prompt is assumed to be None - -single_text_prompt_raw = text_prompt_raw # Pass a string directly -single_text_prompt = text_prompt # Pass a TextPrompt -single_tokens_prompt = tokens_prompt # Pass a TokensPrompt - -# - Pass explicit encoder and decoder input prompts within one data structure. -# Encoder and decoder prompts can both independently be text or tokens, with -# no requirement that they be the same prompt type. 
Some example prompt-type -# combinations are shown below, note that these are not exhaustive. - -enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt string directly, & - # pass decoder prompt tokens - encoder_prompt=single_text_prompt_raw, - decoder_prompt=single_tokens_prompt, -) -enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( - # Pass TextPrompt to encoder, and - # pass decoder prompt string directly - encoder_prompt=single_text_prompt, - decoder_prompt=single_text_prompt_raw, -) -enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt tokens directly, and - # pass TextPrompt to decoder - encoder_prompt=single_tokens_prompt, - decoder_prompt=single_text_prompt, -) - -# - Finally, here's a useful helper function for zipping encoder and -# decoder prompts together into a list of ExplicitEncoderDecoderPrompt -# instances -zipped_prompt_list = zip_enc_dec_prompts( - ['An encoder prompt', 'Another encoder prompt'], - ['A decoder prompt', 'Another decoder prompt']) - -# - Let's put all of the above example prompts together into one list -# which we will pass to the encoder/decoder LLM. -prompts = [ - single_text_prompt_raw, single_text_prompt, single_tokens_prompt, - enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 -] + zipped_prompt_list + +def create_prompts(tokenizer): + # Test prompts + # + # This section shows all of the valid ways to prompt an + # encoder/decoder model. 
+ # + # - Helpers for building prompts + text_prompt_raw = "Hello, my name is" + text_prompt = TextPrompt(prompt="The president of the United States is") + tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( + prompt="The capital of France is")) + # - Pass a single prompt to encoder/decoder model + # (implicitly encoder input prompt); + # decoder input prompt is assumed to be None + + single_text_prompt_raw = text_prompt_raw # Pass a string directly + single_text_prompt = text_prompt # Pass a TextPrompt + single_tokens_prompt = tokens_prompt # Pass a TokensPrompt + + # ruff: noqa: E501 + # - Pass explicit encoder and decoder input prompts within one data structure. + # Encoder and decoder prompts can both independently be text or tokens, with + # no requirement that they be the same prompt type. Some example prompt-type + # combinations are shown below, note that these are not exhaustive. + + enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt string directly, & + # pass decoder prompt tokens + encoder_prompt=single_text_prompt_raw, + decoder_prompt=single_tokens_prompt, + ) + enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( + # Pass TextPrompt to encoder, and + # pass decoder prompt string directly + encoder_prompt=single_text_prompt, + decoder_prompt=single_text_prompt_raw, + ) + enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt tokens directly, and + # pass TextPrompt to decoder + encoder_prompt=single_tokens_prompt, + decoder_prompt=single_text_prompt, + ) + + # - Finally, here's a useful helper function for zipping encoder and + # decoder prompts together into a list of ExplicitEncoderDecoderPrompt + # instances + zipped_prompt_list = zip_enc_dec_prompts( + ['An encoder prompt', 'Another encoder prompt'], + ['A decoder prompt', 'Another decoder prompt']) + + # - Let's put all of the above example prompts together into one list + # which we will pass to the encoder/decoder LLM. 
+ return [ + single_text_prompt_raw, single_text_prompt, single_tokens_prompt, + enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 + ] + zipped_prompt_list + # Create a sampling params object. -sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - min_tokens=0, - max_tokens=20, -) - -# Generate output tokens from the prompts. The output is a list of -# RequestOutput objects that contain the prompt, generated -# text, and other information. -outputs = llm.generate(prompts, sampling_params) +def create_sampling_params(): + return SamplingParams( + temperature=0, + top_p=1.0, + min_tokens=0, + max_tokens=20, + ) + # Print the outputs. -print("-" * 50) -for i, output in enumerate(outputs): - prompt = output.prompt - encoder_prompt = output.encoder_prompt - generated_text = output.outputs[0].text - print(f"Output {i+1}:") - print(f"Encoder prompt: {encoder_prompt!r}\n" - f"Decoder prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}") +def print_outputs(outputs): print("-" * 50) + for i, output in enumerate(outputs): + prompt = output.prompt + encoder_prompt = output.encoder_prompt + generated_text = output.outputs[0].text + print(f"Output {i+1}:") + print(f"Encoder prompt: {encoder_prompt!r}\n" + f"Decoder prompt: {prompt!r}\n" + f"Generated text: {generated_text!r}") + print("-" * 50) + + +def main(): + dtype = "float" + + # Create a BART encoder/decoder model instance + llm = LLM( + model="facebook/bart-large-cnn", + dtype=dtype, + ) + + # Get BART tokenizer + tokenizer = llm.llm_engine.get_tokenizer_group() + + prompts = create_prompts(tokenizer) + sampling_params = create_sampling_params() + + # Generate output tokens from the prompts. The output is a list of + # RequestOutput objects that contain the prompt, generated + # text, and other information. 
+ outputs = llm.generate(prompts, sampling_params) + + print_outputs(outputs) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 456ee60eaabf369240074d4b82aa131671b8473e..2883c37ca23607ac5f9cfd2dcab2844133e00ef8 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -22,7 +22,7 @@ class ModelRequestData(NamedTuple): def run_florence2(): engine_args = EngineArgs( model="microsoft/Florence-2-large", - tokenizer="facebook/bart-large", + tokenizer="Isotr0py/Florence-2-tokenizer", max_num_seqs=8, trust_remote_code=True, limit_mm_per_prompt={"image": 1}, @@ -126,6 +126,23 @@ model_example_map = { } +def parse_args(): + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models for text generation') + parser.add_argument('--model-type', + '-m', + type=str, + default="mllama", + choices=model_example_map.keys(), + help='Huggingface "model_type".') + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") + return parser.parse_args() + + def main(args): model = args.model_type if model not in model_example_map: @@ -148,6 +165,7 @@ def main(args): temperature=0, top_p=1.0, max_tokens=64, + skip_special_tokens=False, ) start = time.time() @@ -171,19 +189,5 @@ def main(args): if __name__ == "__main__": - parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'vision language models for text generation') - parser.add_argument('--model-type', - '-m', - type=str, - default="mllama", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") - - args = parser.parse_args() + args = parse_args() 
main(args) diff --git a/examples/offline_inference/llm_engine_example.py b/examples/offline_inference/llm_engine_example.py index abff90d1c0cb6bdfe62be33809dd4e31f9d9155a..d84cd9ee9f52bc2e8080364d1ca06bc38c85509f 100644 --- a/examples/offline_inference/llm_engine_example.py +++ b/examples/offline_inference/llm_engine_example.py @@ -50,6 +50,13 @@ def initialize_engine(args: argparse.Namespace) -> LLMEngine: return LLMEngine.from_engine_args(engine_args) +def parse_args(): + parser = FlexibleArgumentParser( + description='Demo on using the LLMEngine class directly') + parser = EngineArgs.add_cli_args(parser) + return parser.parse_args() + + def main(args: argparse.Namespace): """Main function that sets up and runs the prompt processing.""" engine = initialize_engine(args) @@ -58,8 +65,5 @@ def main(args: argparse.Namespace): if __name__ == '__main__': - parser = FlexibleArgumentParser( - description='Demo on using the LLMEngine class directly') - parser = EngineArgs.add_cli_args(parser) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py index efa1aa5b03692e525d6981f48df326f7cb2c4b28..37c3181dc5fafd7d7c26a1bc227c1271b1a0815b 100644 --- a/examples/offline_inference/mistral-small.py +++ b/examples/offline_inference/mistral-small.py @@ -16,11 +16,11 @@ from vllm.sampling_params import SamplingParams # # Mistral format # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \ # --tokenizer-mode mistral --config-format mistral --load-format mistral \ -# --limit-mm-per-prompt 'image=4' --max-model-len 16384 +# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384 # # # HF format # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \ -# --limit-mm-per-prompt 'image=4' --max-model-len 16384 +# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384 # ``` # # - Client: @@ -62,6 +62,7 @@ def run_simple_demo(args: argparse.Namespace): 
tokenizer_mode="mistral" if args.format == "mistral" else "auto", config_format="mistral" if args.format == "mistral" else "auto", load_format="mistral" if args.format == "mistral" else "auto", + limit_mm_per_prompt={"image": 1}, max_model_len=4096, max_num_seqs=2, tensor_parallel_size=2, @@ -168,7 +169,7 @@ def run_advanced_demo(args: argparse.Namespace): print("-" * 50) -def main(): +def parse_args(): parser = argparse.ArgumentParser( description="Run a demo in simple or advanced mode.") @@ -187,8 +188,11 @@ def main(): '--disable-mm-preprocessor-cache', action='store_true', help='If True, disables caching of multi-modal preprocessor/mapper.') + return parser.parse_args() + - args = parser.parse_args() +def main(): + args = parse_args() if args.mode == "simple": print("Running simple demo...") diff --git a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py index a2a984b04e005c9ad1bbce67c860c61f407b8398..53c58a76d9dc10c27b36977b1e7621493afa1978 100644 --- a/examples/offline_inference/mlpspeculator.py +++ b/examples/offline_inference/mlpspeculator.py @@ -34,8 +34,7 @@ def time_generation(llm: LLM, prompts: list[str], print("-" * 50) -if __name__ == "__main__": - +def main(): template = ( "Below is an instruction that describes a task. 
Write a response " "that appropriately completes the request.\n\n### Instruction:\n{}" @@ -66,3 +65,7 @@ if __name__ == "__main__": ) time_generation(llm, prompts, sampling_params, "With speculation") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 3ae507cac5ce1cd9dfb8afefea74e120245bf8b5..f97a1f32e6210d6678828078dc02e208af8091ac 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -417,6 +417,38 @@ def run_model(input_data, return pred_imgs +def parse_args(): + parser = argparse.ArgumentParser("MAE run inference", add_help=False) + + parser.add_argument( + "--data_file", + type=str, + default="./India_900498_S2Hand.tif", + help="Path to the file.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Path to the directory where to save outputs.", + ) + parser.add_argument( + "--input_indices", + default=[1, 2, 3, 8, 11, 12], + type=int, + nargs="+", + help= + "0-based indices of the six Prithvi channels to be selected from the " + "input. By default selects [1,2,3,8,11,12] for S2L1C data.", + ) + parser.add_argument( + "--rgb_outputs", + action="store_true", + help="If present, output files will only contain RGB channels. 
" + "Otherwise, all bands will be saved.", + ) + + def main( data_file: str, output_dir: str, @@ -496,35 +528,7 @@ def main( if __name__ == "__main__": - parser = argparse.ArgumentParser("MAE run inference", add_help=False) - parser.add_argument( - "--data_file", - type=str, - default="./India_900498_S2Hand.tif", - help="Path to the file.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="output", - help="Path to the directory where to save outputs.", - ) - parser.add_argument( - "--input_indices", - default=[1, 2, 3, 8, 11, 12], - type=int, - nargs="+", - help= - "0-based indices of the six Prithvi channels to be selected from the " - "input. By default selects [1,2,3,8,11,12] for S2L1C data.", - ) - parser.add_argument( - "--rgb_outputs", - action="store_true", - help="If present, output files will only contain RGB channels. " - "Otherwise, all bands will be saved.", - ) - args = parser.parse_args() + args = parse_args() main(**vars(args)) diff --git a/examples/offline_inference/profiling.py b/examples/offline_inference/profiling.py index 6e1d4722440a537a8e8b1e19b11bceece361fa85..9c818d0757345e7024adfa34c0ba2d3cfce13315 100644 --- a/examples/offline_inference/profiling.py +++ b/examples/offline_inference/profiling.py @@ -359,7 +359,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], f" in folder {context.save_chrome_traces_folder}") -if __name__ == "__main__": +def parse_args(): parser = FlexibleArgumentParser(description=""" Profile a model @@ -449,7 +449,10 @@ Profile a model EngineArgs.add_cli_args(parser) - args = parser.parse_args() + return parser.parse_args() + + +def main(args): context = ProfileContext( engine_args=EngineArgs.from_cli_args(args), **{ @@ -458,3 +461,8 @@ Profile a model if k in inspect.signature(ProfileContext).parameters }) run_profile(context, csv_output=args.csv, json_output=args.json) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git 
a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c30541a598cee2d698a684ff69c55695f7783e89 --- /dev/null +++ b/examples/offline_inference/qwen2_5_omni/README.md @@ -0,0 +1,32 @@ +# Qwen2.5-Omni Offline Inference Examples + +This folder provides several example scripts on how to inference Qwen2.5-Omni offline. + +## Thinker Only + +```bash +# Audio + image + video +python examples/offline_inference/qwen2_5_omni/only_thinker.py -q mixed_modalities + +# Read vision and audio inputs from a single video file +# NOTE: V1 engine does not support interleaved modalities yet. +VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video + +# Multiple audios +VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q multi_audios +``` + +This script will run the thinker part of Qwen2.5-Omni, and generate text response. + +You can also test Qwen2.5-Omni on a single modality: + +```bash +# Process audio inputs +python examples/offline_inference/audio_language.py --model-type qwen2_5_omni + +# Process image inputs +python examples/offline_inference/vision_language.py --modality image --model-type qwen2_5_omni + +# Process video inputs +python examples/offline_inference/vision_language.py --modality video --model-type qwen2_5_omni +``` diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py new file mode 100644 index 0000000000000000000000000000000000000000..c75a990120e0741aad71853c0a85b42a84cfd70c --- /dev/null +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to use vLLM for running offline inference +with the correct prompt format on Qwen2.5-Omni (thinker only). 
+""" + +from typing import NamedTuple + +import vllm.envs as envs +from vllm import LLM, SamplingParams +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.utils import FlexibleArgumentParser + + +class QueryResult(NamedTuple): + inputs: dict + limit_mm_per_prompt: dict[str, int] + + +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. + +default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech.") + + +def get_mixed_modalities_query() -> QueryResult: + question = ("What is recited in the audio? " + "What is the content of this image? Why is this video funny?") + prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|vision_bos|><|IMAGE|><|vision_eos|>" + "<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n") + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": + AudioAsset("mary_had_lamb").audio_and_sample_rate, + "image": + ImageAsset("cherry_blossom").pil_image.convert("RGB"), + "video": + VideoAsset(name="sample_demo_1.mp4", + num_frames=16).np_ndarrays, + }, + }, + limit_mm_per_prompt={ + "audio": 1, + "image": 1, + "video": 1 + }, + ) + + +def get_use_audio_in_video_query() -> QueryResult: + question = ("Describe the content of the video, " + "then convert what the baby say into text.") + prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n") + asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16) + audio = asset.get_audio(sampling_rate=16000) + assert 
not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. " + "Please launch this example with " + "`VLLM_USE_V1=0`.") + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "video": asset.np_ndarrays, + "audio": audio, + }, + "mm_processor_kwargs": { + "use_audio_in_video": True, + }, + }, + limit_mm_per_prompt={ + "audio": 1, + "video": 1 + }, + ) + + +def get_multi_audios_query() -> QueryResult: + question = "Are these two audio clips the same?" + prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|audio_bos|><|AUDIO|><|audio_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n") + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": [ + AudioAsset("winning_call").audio_and_sample_rate, + AudioAsset("mary_had_lamb").audio_and_sample_rate, + ], + }, + }, + limit_mm_per_prompt={ + "audio": 2, + }, + ) + + +query_map = { + "mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, + "multi_audios": get_multi_audios_query, +} + + +def main(args): + model_name = "Qwen/Qwen2.5-Omni-7B" + query_result = query_map[args.query_type]() + + llm = LLM(model=model_name, + max_model_len=5632, + max_num_seqs=5, + limit_mm_per_prompt=query_result.limit_mm_per_prompt, + seed=args.seed) + + # We set temperature to 0.2 so that outputs can be different + # even when all prompts are identical when running batch inference. 
+ sampling_params = SamplingParams(temperature=0.2, max_tokens=64) + + outputs = llm.generate(query_result.inputs, + sampling_params=sampling_params) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'audio language models') + parser.add_argument('--query-type', + '-q', + type=str, + default="mixed_modalities", + choices=query_map.keys(), + help='Query type.') + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") + + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference/save_sharded_state.py b/examples/offline_inference/save_sharded_state.py index 6aac9b75c59cf5343ac51ae19c34d03670cc1533..338380cc9684175c3dd147786ee4a33cb124d2b4 100644 --- a/examples/offline_inference/save_sharded_state.py +++ b/examples/offline_inference/save_sharded_state.py @@ -29,20 +29,23 @@ from pathlib import Path from vllm import LLM, EngineArgs from vllm.utils import FlexibleArgumentParser -parser = FlexibleArgumentParser() -EngineArgs.add_cli_args(parser) -parser.add_argument("--output", - "-o", - required=True, - type=str, - help="path to output checkpoint") -parser.add_argument("--file-pattern", - type=str, - help="string pattern of saved filenames") -parser.add_argument("--max-file-size", - type=str, - default=5 * 1024**3, - help="max size (in bytes) of each safetensors file") + +def parse_args(): + parser = FlexibleArgumentParser() + EngineArgs.add_cli_args(parser) + parser.add_argument("--output", + "-o", + required=True, + type=str, + help="path to output checkpoint") + parser.add_argument("--file-pattern", + type=str, + help="string pattern of saved filenames") + parser.add_argument("--max-file-size", + type=str, + default=5 * 1024**3, + help="max size (in bytes) of each safetensors file") + return parser.parse_args() def main(args): @@ 
-87,5 +90,5 @@ def main(args): if __name__ == "__main__": - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/simple_profiling.py b/examples/offline_inference/simple_profiling.py index 6a8e3a5a3e75700d39be4bd77811e5eb97ebc3d0..d583110c8e69bc5615be361c277aba30cd44de7f 100644 --- a/examples/offline_inference/simple_profiling.py +++ b/examples/offline_inference/simple_profiling.py @@ -18,8 +18,8 @@ prompts = [ # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -if __name__ == "__main__": +def main(): # Create an LLM. llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1) @@ -42,3 +42,7 @@ if __name__ == "__main__": # Add a buffer to wait for profiler in the background process # (in case MP is on) to finish writing profiling output. time.sleep(10) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index f51cef95e85968a3120fb0aef5974e121a9d2311..d02ac17cfdd68ba6db4ec70df87f5d2aa05f1db0 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -150,7 +150,7 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model="microsoft/Florence-2-large", - tokenizer="facebook/bart-large", + tokenizer="Isotr0py/Florence-2-tokenizer", max_model_len=4096, max_num_seqs=2, trust_remote_code=True, @@ -364,6 +364,29 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: ) +# Kimi-VL +def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [ + "<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>" + f"<|media_pad|><|media_end|>{question}<|im_end|>" + "<|im_assistant|>assistant<|im_middle|>" for question in questions + ] + + engine_args = EngineArgs( + model="moonshotai/Kimi-VL-A3B-Instruct", 
+ trust_remote_code=True, + max_model_len=4096, + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # LLaVA-1.5 def run_llava(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -791,10 +814,13 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model=model_path, trust_remote_code=True, - max_model_len=4096, + max_model_len=5120, max_num_seqs=2, + max_num_batched_tokens=12800, enable_lora=True, max_lora_rank=320, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs={"dynamic_hd": 16}, limit_mm_per_prompt={"image": 1}, ) @@ -918,6 +944,42 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: ) +# Qwen2.5-Omni +def run_qwen2_5_omni(questions: list[str], modality: str): + model_name = "Qwen/Qwen2.5-Omni-7B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": [1], + }, + limit_mm_per_prompt={"image": 1}, + ) + + if modality == "image": + placeholder = "<|IMAGE|>" + elif modality == "video": + placeholder = "<|VIDEO|>" + + default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech.") + + prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n" + f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n") for question in questions] + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # SkyworkR1V def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -966,6 +1028,7 @@ model_example_map = { "h2ovl_chat": run_h2ovl, "idefics3": run_idefics3, 
"internvl_chat": run_internvl, + "kimi_vl": run_kimi_vl, "llava": run_llava, "llava-next": run_llava_next, "llava-next-video": run_llava_next_video, @@ -986,6 +1049,7 @@ model_example_map = { "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, + "qwen2_5_omni": run_qwen2_5_omni, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, } @@ -1073,6 +1137,59 @@ def time_counter(enable: bool): yield +def parse_args(): + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models for text generation') + parser.add_argument('--model-type', + '-m', + type=str, + default="llava", + choices=model_example_map.keys(), + help='Huggingface "model_type".') + parser.add_argument('--num-prompts', + type=int, + default=4, + help='Number of prompts to run.') + parser.add_argument('--modality', + type=str, + default="image", + choices=['image', 'video'], + help='Modality of the input.') + parser.add_argument('--num-frames', + type=int, + default=16, + help='Number of frames to extract from the video.') + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") + + parser.add_argument( + '--image-repeat-prob', + type=float, + default=None, + help='Simulates the hit-ratio for multi-modal preprocessor cache' + ' (if enabled)') + + parser.add_argument( + '--disable-mm-preprocessor-cache', + action='store_true', + help='If True, disables caching of multi-modal preprocessor/mapper.') + + parser.add_argument( + '--time-generate', + action='store_true', + help='If True, then print the total generate() call time') + + parser.add_argument( + '--use-different-prompt-per-request', + action='store_true', + help='If True, then use different prompt (with the same multi-modal ' + 'data) for each request.') + return parser.parse_args() + + def main(args): model = args.model_type if model not in model_example_map: @@ -1151,55 +1268,5 @@ def main(args): if 
__name__ == "__main__": - parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'vision language models for text generation') - parser.add_argument('--model-type', - '-m', - type=str, - default="llava", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument('--num-prompts', - type=int, - default=4, - help='Number of prompts to run.') - parser.add_argument('--modality', - type=str, - default="image", - choices=['image', 'video'], - help='Modality of the input.') - parser.add_argument('--num-frames', - type=int, - default=16, - help='Number of frames to extract from the video.') - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") - - parser.add_argument( - '--image-repeat-prob', - type=float, - default=None, - help='Simulates the hit-ratio for multi-modal preprocessor cache' - ' (if enabled)') - - parser.add_argument( - '--disable-mm-preprocessor-cache', - action='store_true', - help='If True, disables caching of multi-modal preprocessor/mapper.') - - parser.add_argument( - '--time-generate', - action='store_true', - help='If True, then print the total generate() call time') - - parser.add_argument( - '--use-different-prompt-per-request', - action='store_true', - help='If True, then use different prompt (with the same multi-modal ' - 'data) for each request.') - - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/vision_language_embedding.py b/examples/offline_inference/vision_language_embedding.py index ad3c5ae0627b37ad98277e21c0de522a1ea433bf..2637949551a1add8b7a1b6abe1c954c0b6a32091 100644 --- a/examples/offline_inference/vision_language_embedding.py +++ b/examples/offline_inference/vision_language_embedding.py @@ -156,16 +156,13 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): print("-" * 50) -def main(args: Namespace): - 
run_encode(args.model_name, args.modality, args.seed) - - model_example_map = { "e5_v": run_e5_v, "vlm2vec": run_vlm2vec, } -if __name__ == "__main__": + +def parse_args(): parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' 'vision language models for multimodal embedding') @@ -184,6 +181,13 @@ if __name__ == "__main__": type=int, default=None, help="Set the seed when initializing `vllm.LLM`.") + return parser.parse_args() - args = parser.parse_args() + +def main(args: Namespace): + run_encode(args.model_name, args.modality, args.seed) + + +if __name__ == "__main__": + args = parse_args() main(args) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 89818f8b33ee6e68672adc58acb28d2cf869cd5b..7f6608559f9c4a31d27818a272d54e520605447a 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -326,6 +326,44 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "moonshotai/Kimi-VL-A3B-Instruct" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=4, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [{ + "role": + "user", + "content": [ + *placeholders, + { + "type": "text", + "text": question + }, + ], + }] + + processor = AutoProcessor.from_pretrained(model_name, + trust_remote_code=True) + + prompt = processor.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData: 
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @@ -465,11 +503,13 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: engine_args = EngineArgs( model=model_path, trust_remote_code=True, - max_model_len=10000, + max_model_len=4096, max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}, enable_lora=True, max_lora_rank=320, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs={"dynamic_hd": 4}, ) placeholders = "".join(f"<|image_{i}|>" @@ -640,6 +680,7 @@ model_example_map = { "h2ovl_chat": load_h2ovl, "idefics3": load_idefics3, "internvl_chat": load_internvl, + "kimi_vl": load_kimi_vl, "llama4": load_llama4, "mistral3": load_mistral3, "mllama": load_mllama, @@ -727,22 +768,7 @@ def run_chat(model: str, question: str, image_urls: list[str], print("-" * 50) -def main(args: Namespace): - model = args.model_type - method = args.method - seed = args.seed - - image_urls = IMAGE_URLS[:args.num_images] - - if method == "generate": - run_generate(model, QUESTION, image_urls, seed) - elif method == "chat": - run_chat(model, QUESTION, image_urls, seed) - else: - raise ValueError(f"Invalid method: {method}") - - -if __name__ == "__main__": +def parse_args(): parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' 'vision language models that support multi-image input for text ' @@ -765,9 +791,29 @@ if __name__ == "__main__": parser.add_argument( "--num-images", "-n", - choices=list(range(1, 13)), # 12 is the max number of images + type=int, + choices=list(range(1, + len(IMAGE_URLS) + 1)), # the max number of images default=2, help="Number of images to use for the demo.") + return parser.parse_args() - args = parser.parse_args() + +def main(args: Namespace): + model = args.model_type + method = args.method + seed = args.seed + + image_urls = IMAGE_URLS[:args.num_images] + + if method == "generate": + run_generate(model, QUESTION, image_urls, seed) + 
elif method == "chat": + run_chat(model, QUESTION, image_urls, seed) + else: + raise ValueError(f"Invalid method: {method}") + + +if __name__ == "__main__": + args = parse_args() main(args) diff --git a/examples/online_serving/api_client.py b/examples/online_serving/api_client.py index 60e4bccb7517c45cd9d960800900a1d391fc4f76..36079ff11d07e1c7a8e4093fb0d2da45f37c9be7 100644 --- a/examples/online_serving/api_client.py +++ b/examples/online_serving/api_client.py @@ -58,6 +58,16 @@ def get_response(response: requests.Response) -> list[str]: return output +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--n", type=int, default=1) + parser.add_argument("--prompt", type=str, default="San Francisco is a") + parser.add_argument("--stream", action="store_true") + return parser.parse_args() + + def main(args: Namespace): prompt = args.prompt api_url = f"http://{args.host}:{args.port}/generate" @@ -82,11 +92,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--n", type=int, default=1) - parser.add_argument("--prompt", type=str, default="San Francisco is a") - parser.add_argument("--stream", action="store_true") - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/online_serving/cohere_rerank_client.py b/examples/online_serving/cohere_rerank_client.py index fc434ada1d15625ddb8706f37d78ed188b0e08ff..c2d4ef08ddbbe695986b442c8b19338cc195df58 100644 --- a/examples/online_serving/cohere_rerank_client.py +++ b/examples/online_serving/cohere_rerank_client.py @@ -2,32 +2,46 @@ """ Example of using the OpenAI entrypoint's rerank API which is compatible with the Cohere SDK: https://github.com/cohere-ai/cohere-python +Note that 
`pip install cohere` is needed to run this example. run: vllm serve BAAI/bge-reranker-base """ +from typing import Union + import cohere +from cohere import Client, ClientV2 + +model = "BAAI/bge-reranker-base" + +query = "What is the capital of France?" + +documents = [ + "The capital of France is Paris", "Reranking is fun!", + "vLLM is an open-source framework for fast AI serving" +] + + +def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str, + documents: list[str]) -> dict: + return client.rerank(model=model, query=query, documents=documents) + + +def main(): + # cohere v1 client + cohere_v1 = cohere.Client(base_url="http://localhost:8000", + api_key="sk-fake-key") + rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents) + print("-" * 50) + print("rerank_v1_result:\n", rerank_v1_result) + print("-" * 50) + + # or the v2 + cohere_v2 = cohere.ClientV2("sk-fake-key", + base_url="http://localhost:8000") + rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents) + print("rerank_v2_result:\n", rerank_v2_result) + print("-" * 50) + -# cohere v1 client -co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key") -rerank_v1_result = co.rerank( - model="BAAI/bge-reranker-base", - query="What is the capital of France?", - documents=[ - "The capital of France is Paris", "Reranking is fun!", - "vLLM is an open-source framework for fast AI serving" - ]) - -print(rerank_v1_result) - -# or the v2 -co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000") - -v2_rerank_result = co2.rerank( - model="BAAI/bge-reranker-base", - query="What is the capital of France?", - documents=[ - "The capital of France is Paris", "Reranking is fun!", - "vLLM is an open-source framework for fast AI serving" - ]) - -print(v2_rerank_result) +if __name__ == "__main__": + main() diff --git a/examples/online_serving/gradio_openai_chatbot_webserver.py b/examples/online_serving/gradio_openai_chatbot_webserver.py index 
ee01e1eae6281e6064a5ef19d380080e7cc20d23..314f1c5b7395161dd177c51d0c0b7b58fec37a47 100644 --- a/examples/online_serving/gradio_openai_chatbot_webserver.py +++ b/examples/online_serving/gradio_openai_chatbot_webserver.py @@ -1,52 +1,32 @@ # SPDX-License-Identifier: Apache-2.0 +"""Example for starting a Gradio OpenAI Chatbot Webserver +Start vLLM API server: + vllm serve meta-llama/Llama-2-7b-chat-hf +Start Gradio OpenAI Chatbot Webserver: + python examples/online_serving/gradio_openai_chatbot_webserver.py \ + -m meta-llama/Llama-2-7b-chat-hf + +Note that `pip install --upgrade gradio` is needed to run this example. +More details: https://github.com/gradio-app/gradio + +If your antivirus software blocks the download of frpc for gradio, +you can install it manually by following these steps: + +1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc +""" import argparse import gradio as gr from openai import OpenAI -# Argument parser setup -parser = argparse.ArgumentParser( - description='Chatbot Interface with Customizable Parameters') -parser.add_argument('--model-url', - type=str, - default='http://localhost:8000/v1', - help='Model URL') -parser.add_argument('-m', - '--model', - type=str, - required=True, - help='Model name for the chatbot') -parser.add_argument('--temp', - type=float, - default=0.8, - help='Temperature for text generation') -parser.add_argument('--stop-token-ids', - type=str, - default='', - help='Comma-separated stop token IDs') -parser.add_argument("--host", type=str, default=None) -parser.add_argument("--port", type=int, default=8001) - -# Parse the arguments -args = parser.parse_args() - -# Set OpenAI's API key and API base to use vLLM's API server. 
-openai_api_key = "EMPTY" -openai_api_base = args.model_url - -# Create an OpenAI client to interact with the API server -client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, -) - - -def predict(message, history): - # Convert chat history to OpenAI format + +def format_history_to_openai(history): history_openai_format = [{ "role": "system", - "content": "You are a great ai assistant." + "content": "You are a great AI assistant." }] for human, assistant in history: history_openai_format.append({"role": "user", "content": human}) @@ -54,31 +34,92 @@ def predict(message, history): "role": "assistant", "content": assistant }) + return history_openai_format + + +def predict(message, history, client, model_name, temp, stop_token_ids): + # Format history to OpenAI chat format + history_openai_format = format_history_to_openai(history) history_openai_format.append({"role": "user", "content": message}) - # Create a chat completion request and send it to the API server + # Send request to OpenAI API (vLLM server) stream = client.chat.completions.create( - model=args.model, # Model name to use - messages=history_openai_format, # Chat history - temperature=args.temp, # Temperature for text generation - stream=True, # Stream response + model=model_name, + messages=history_openai_format, + temperature=temp, + stream=True, extra_body={ 'repetition_penalty': 1, - 'stop_token_ids': [ - int(id.strip()) for id in args.stop_token_ids.split(',') - if id.strip() - ] if args.stop_token_ids else [] + 'stop_token_ids': + [int(id.strip()) + for id in stop_token_ids.split(',')] if stop_token_ids else [] }) - # Read and return generated text from response stream - partial_message = "" + # Collect all chunks and concatenate them into a full message + full_message = "" for chunk in stream: - partial_message += (chunk.choices[0].delta.content or "") - yield partial_message + full_message += (chunk.choices[0].delta.content or "") + + # Return the full message as a single 
response + return full_message + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Chatbot Interface with Customizable Parameters') + parser.add_argument('--model-url', + type=str, + default='http://localhost:8000/v1', + help='Model URL') + parser.add_argument('-m', + '--model', + type=str, + required=True, + help='Model name for the chatbot') + parser.add_argument('--temp', + type=float, + default=0.8, + help='Temperature for text generation') + parser.add_argument('--stop-token-ids', + type=str, + default='', + help='Comma-separated stop token IDs') + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8001) + return parser.parse_args() + + +def build_gradio_interface(client, model_name, temp, stop_token_ids): + + def chat_predict(message, history): + return predict(message, history, client, model_name, temp, + stop_token_ids) + + return gr.ChatInterface(fn=chat_predict, + title="Chatbot Interface", + description="A simple chatbot powered by vLLM") + + +def main(): + # Parse the arguments + args = parse_args() + + # Set OpenAI's API key and API base to use vLLM's API server + openai_api_key = "EMPTY" + openai_api_base = args.model_url + + # Create an OpenAI client + client = OpenAI(api_key=openai_api_key, base_url=openai_api_base) + + # Define the Gradio chatbot interface using the predict function + gradio_interface = build_gradio_interface(client, args.model, args.temp, + args.stop_token_ids) + + gradio_interface.queue().launch(server_name=args.host, + server_port=args.port, + share=True) -# Create and launch a chat interface with Gradio -gr.ChatInterface(predict).queue().launch(server_name=args.host, - server_port=args.port, - share=True) +if __name__ == "__main__": + main() diff --git a/examples/online_serving/gradio_webserver.py b/examples/online_serving/gradio_webserver.py index 85a9119c6aa2f3510217d5bb3265db55ab093a16..2e7c2a0c5838c6d2c34694e35521dca2f7b5fa50 100644 --- 
a/examples/online_serving/gradio_webserver.py +++ b/examples/online_serving/gradio_webserver.py @@ -1,5 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 +"""Example for starting a Gradio Webserver +Start vLLM API server: + python -m vllm.entrypoints.api_server \ + --model meta-llama/Llama-2-7b-chat-hf +Start Webserver: + python examples/online_serving/gradio_webserver.py + +Note that `pip install --upgrade gradio` is needed to run this example. +More details: https://github.com/gradio-app/gradio + +If your antivirus software blocks the download of frpc for gradio, +you can install it manually by following these steps: + +1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc +""" import argparse import json @@ -39,16 +56,23 @@ def build_demo(): return demo -if __name__ == "__main__": +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8001) parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate") - args = parser.parse_args() + return parser.parse_args() + +def main(args): demo = build_demo() demo.queue().launch(server_name=args.host, server_port=args.port, share=True) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/jinaai_rerank_client.py b/examples/online_serving/jinaai_rerank_client.py index 3e760e1717883d250a63d669307bb33748778c79..3076bba765ce53ece15775f97cae489b823d7648 100644 --- a/examples/online_serving/jinaai_rerank_client.py +++ b/examples/online_serving/jinaai_rerank_client.py @@ -23,12 +23,19 @@ data = { "The capital of France is Paris.", "Horses and cows are both animals" ] } -response = requests.post(url, headers=headers, json=data) - -# Check the response -if response.status_code == 200: - 
print("Request successful!") - print(json.dumps(response.json(), indent=2)) -else: - print(f"Request failed with status code: {response.status_code}") - print(response.text) + + +def main(): + response = requests.post(url, headers=headers, json=data) + + # Check the response + if response.status_code == 200: + print("Request successful!") + print(json.dumps(response.json(), indent=2)) + else: + print(f"Request failed with status code: {response.status_code}") + print(response.text) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_client.py b/examples/online_serving/openai_chat_completion_client.py index a81562041130962c2d103a547c6c395d9ce054fb..74e0c045d6214b87670659efaae7582b65681adc 100644 --- a/examples/online_serving/openai_chat_completion_client.py +++ b/examples/online_serving/openai_chat_completion_client.py @@ -1,38 +1,49 @@ # SPDX-License-Identifier: Apache-2.0 - +"""Example Python client for OpenAI Chat Completion using vLLM API server +NOTE: start a supported chat completion model server with `vllm serve`, e.g. + vllm serve meta-llama/Llama-2-7b-chat-hf +""" from openai import OpenAI # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - -chat_completion = client.chat.completions.create( - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" 
- }], - model=model, -) - -print("Chat completion results:") -print(chat_completion) +messages = [{ + "role": "system", + "content": "You are a helpful assistant." +}, { + "role": "user", + "content": "Who won the world series in 2020?" +}, { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020." +}, { + "role": "user", + "content": "Where was it played?" +}] + + +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + chat_completion = client.chat.completions.create( + messages=messages, + model=model, + ) + + print("-" * 50) + print("Chat completion results:") + print(chat_completion) + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index ecfcf05a90d1699eb9f000a69d8612785f94c626..70db4d95e64941a163440b32545d55704e495966 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -9,7 +9,7 @@ vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja (multi-image inference with Phi-3.5-vision-instruct) vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ - --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' (audio inference with Ultravox) vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096 @@ -303,12 +303,7 @@ example_function_map = { } -def main(args) -> None: - chat_type = args.chat_type - example_function_map[chat_type]() - - -if __name__ == "__main__": +def parse_args(): parser = FlexibleArgumentParser( description='Demo on using OpenAI client for online serving with ' 
'multimodal language models served with vLLM.') @@ -318,5 +313,14 @@ if __name__ == "__main__": default="single-image", choices=list(example_function_map.keys()), help='Conversation type with multimodal data.') - args = parser.parse_args() + return parser.parse_args() + + +def main(args) -> None: + chat_type = args.chat_type + example_function_map[chat_type]() + + +if __name__ == "__main__": + args = parse_args() main(args) diff --git a/examples/online_serving/openai_chat_completion_client_with_tools.py b/examples/online_serving/openai_chat_completion_client_with_tools.py index 416fb61ca8bb58329c72862553c685e9f41929f7..c25203860ff398176eb02c4c6adbe025b86cf987 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools.py +++ b/examples/online_serving/openai_chat_completion_client_with_tools.py @@ -17,6 +17,7 @@ vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \ --enable-auto-tool-choice --tool-call-parser hermes """ import json +from typing import Any from openai import OpenAI @@ -24,15 +25,6 @@ from openai import OpenAI openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - tools = [{ "type": "function", "function": { @@ -78,86 +70,123 @@ messages = [{ "Can you tell me what the temperate will be in Dallas, in fahrenheit?" 
}] -chat_completion = client.chat.completions.create(messages=messages, - model=model, - tools=tools) - -print("Chat completion results:") -print(chat_completion) -print("\n\n") - -tool_calls_stream = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - stream=True) - -chunks = [] -for chunk in tool_calls_stream: - chunks.append(chunk) - if chunk.choices[0].delta.tool_calls: - print(chunk.choices[0].delta.tool_calls[0]) - else: - print(chunk.choices[0].delta) - -arguments = [] -tool_call_idx = -1 -for chunk in chunks: - - if chunk.choices[0].delta.tool_calls: - tool_call = chunk.choices[0].delta.tool_calls[0] - - if tool_call.index != tool_call_idx: - if tool_call_idx >= 0: - print( - f"streamed tool call arguments: {arguments[tool_call_idx]}" - ) - tool_call_idx = chunk.choices[0].delta.tool_calls[0].index - arguments.append("") - if tool_call.id: - print(f"streamed tool call id: {tool_call.id} ") - - if tool_call.function: - if tool_call.function.name: - print(f"streamed tool call name: {tool_call.function.name}") - - if tool_call.function.arguments: - arguments[tool_call_idx] += tool_call.function.arguments - -if len(arguments): - print(f"streamed tool call arguments: {arguments[-1]}") - -print("\n\n") - -messages.append({ - "role": "assistant", - "tool_calls": chat_completion.choices[0].message.tool_calls -}) - -# Now, simulate a tool call def get_current_weather(city: str, state: str, unit: 'str'): return ("The weather in Dallas, Texas is 85 degrees fahrenheit. 
It is " "partly cloudly, with highs in the 90's.") -available_tools = {"get_current_weather": get_current_weather} - -completion_tool_calls = chat_completion.choices[0].message.tool_calls -for call in completion_tool_calls: - tool_to_call = available_tools[call.function.name] - args = json.loads(call.function.arguments) - result = tool_to_call(**args) - print(result) +def handle_tool_calls_stream( + client: OpenAI, + messages: list[dict[str, str]], + model: str, + tools: list[dict[str, Any]], +) -> list[Any]: + tool_calls_stream = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=True) + chunks = [] + print("chunks: ") + for chunk in tool_calls_stream: + chunks.append(chunk) + if chunk.choices[0].delta.tool_calls: + print(chunk.choices[0].delta.tool_calls[0]) + else: + print(chunk.choices[0].delta) + return chunks + + +def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]: + arguments = [] + tool_call_idx = -1 + print("arguments: ") + for chunk in chunks: + if chunk.choices[0].delta.tool_calls: + tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.index != tool_call_idx: + if tool_call_idx >= 0: + print(f"streamed tool call arguments: " + f"{arguments[tool_call_idx]}") + tool_call_idx = chunk.choices[0].delta.tool_calls[0].index + arguments.append("") + if tool_call.id: + print(f"streamed tool call id: {tool_call.id} ") + + if tool_call.function: + if tool_call.function.name: + print( + f"streamed tool call name: {tool_call.function.name}") + + if tool_call.function.arguments: + arguments[tool_call_idx] += tool_call.function.arguments + + return arguments + + +def main(): + # Initialize OpenAI client + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + # Get available models and select one + models = client.models.list() + model = models.data[0].id + + chat_completion = client.chat.completions.create(messages=messages, + 
model=model, + tools=tools) + + print("-" * 70) + print("Chat completion results:") + print(chat_completion) + print("-" * 70) + + # Stream tool calls + chunks = handle_tool_calls_stream(client, messages, model, tools) + print("-" * 70) + + # Handle arguments from streamed tool calls + arguments = handle_tool_calls_arguments(chunks) + + if len(arguments): + print(f"streamed tool call arguments: {arguments[-1]}\n") + + print("-" * 70) + + # Add tool call results to the conversation messages.append({ - "role": "tool", - "content": result, - "tool_call_id": call.id, - "name": call.function.name + "role": "assistant", + "tool_calls": chat_completion.choices[0].message.tool_calls }) -chat_completion_2 = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - stream=False) -print("\n\n") -print(chat_completion_2) + # Now, simulate a tool call + available_tools = {"get_current_weather": get_current_weather} + + completion_tool_calls = chat_completion.choices[0].message.tool_calls + for call in completion_tool_calls: + tool_to_call = available_tools[call.function.name] + args = json.loads(call.function.arguments) + result = tool_to_call(**args) + print("tool_to_call result: ", result) + messages.append({ + "role": "tool", + "content": result, + "tool_call_id": call.id, + "name": call.function.name + }) + + chat_completion_2 = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=False) + print("Chat completion2 results:") + print(chat_completion_2) + print("-" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_client_with_tools_required.py b/examples/online_serving/openai_chat_completion_client_with_tools_required.py index 779369d1634425c8ceb2d2192823e0f5521bff46..97d900bb75f1aaf76676d9ee543ca9af323b268b 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools_required.py +++ 
b/examples/online_serving/openai_chat_completion_client_with_tools_required.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ -To run this example, you can start the vLLM server +To run this example, you can start the vLLM server without any specific flags: ```bash @@ -8,7 +8,7 @@ VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \ --guided-decoding-backend outlines ``` -This example demonstrates how to generate chat completions +This example demonstrates how to generate chat completions using the OpenAI Python client library. """ @@ -18,15 +18,6 @@ from openai import OpenAI openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - tools = [ { "type": "function", @@ -116,21 +107,36 @@ messages = [ }, ] -chat_completion = client.chat.completions.create( - messages=messages, - model=model, - tools=tools, - tool_choice="required", - stream=True # Enable streaming response -) -for chunk in chat_completion: - if chunk.choices and chunk.choices[0].delta.tool_calls: - print(chunk.choices[0].delta.tool_calls) +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + chat_completion = client.chat.completions.create( + messages=messages, + model=model, + tools=tools, + tool_choice="required", + stream=True # Enable streaming response + ) + + for chunk in chat_completion: + if chunk.choices and chunk.choices[0].delta.tool_calls: + print(chunk.choices[0].delta.tool_calls) + + chat_completion = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + tool_choice="required") + + print(chat_completion.choices[0].message.tool_calls) -chat_completion = client.chat.completions.create(messages=messages, - 
model=model, - tools=tools, - tool_choice="required") -print(chat_completion.choices[0].message.tool_calls) +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_structured_outputs.py b/examples/online_serving/openai_chat_completion_structured_outputs.py index 986ff500e586e8918bd22ae366ebbb8cae22ff4c..f71162e36efd20415589f139903665b670fa92b4 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs.py @@ -1,43 +1,49 @@ # SPDX-License-Identifier: Apache-2.0 +""" +To run this example, you need to start the vLLM server: + +```bash +vllm serve Qwen/Qwen2.5-3B-Instruct +``` +""" from enum import Enum from openai import BadRequestError, OpenAI from pydantic import BaseModel -client = OpenAI( - base_url="http://localhost:8000/v1", - api_key="-", -) # Guided decoding by Choice (list of possible options) -completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", - messages=[{ - "role": "user", - "content": "Classify this sentiment: vLLM is wonderful!" - }], - extra_body={"guided_choice": ["positive", "negative"]}, -) -print(completion.choices[0].message.content) +def guided_choice_completion(client: OpenAI, model: str): + completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": "Classify this sentiment: vLLM is wonderful!" + }], + extra_body={"guided_choice": ["positive", "negative"]}, + ) + return completion.choices[0].message.content + # Guided decoding by Regex -prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. 
Example result:" - "alan.turing@enigma.com\n") - -completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={ - "guided_regex": "\w+@\w+\.com\n", - "stop": ["\n"] - }, -) -print(completion.choices[0].message.content) +def guided_regex_completion(client: OpenAI, model: str): + prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") + + completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": r"\w+@\w+\.com\n", + "stop": ["\n"] + }, + ) + return completion.choices[0].message.content # Guided decoding by JSON using Pydantic schema @@ -54,66 +60,100 @@ class CarDescription(BaseModel): car_type: CarType -json_schema = CarDescription.model_json_schema() - -prompt = ("Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's") -completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={"guided_json": json_schema}, -) -print(completion.choices[0].message.content) +def guided_json_completion(client: OpenAI, model: str): + json_schema = CarDescription.model_json_schema() -# Guided decoding by Grammar -simplified_sql_grammar = """ - ?start: select_statement + prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") + completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_json": json_schema}, + ) + return completion.choices[0].message.content - ?select_statement: "SELECT " column_list " FROM " table_name - ?column_list: column_name ("," column_name)* +# Guided decoding by Grammar +def guided_grammar_completion(client: OpenAI, model: str): 
+ simplified_sql_grammar = """ + root ::= select_statement - ?table_name: identifier + select_statement ::= "SELECT " column " from " table " where " condition - ?column_name: identifier + column ::= "col_1 " | "col_2 " - ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ -""" + table ::= "table_1 " | "table_2 " -prompt = ("Generate an SQL query to show the 'username' and 'email'" - "from the 'users' table.") -completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={"guided_grammar": simplified_sql_grammar}, -) -print(completion.choices[0].message.content) + condition ::= column "= " number -# Extra backend options -prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. Example result:" - "alan.turing@enigma.com\n") + number ::= "1 " | "2 " + """ -try: - # The no-fallback option forces vLLM to use xgrammar, so when it fails - # you get a 400 with the reason why + prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", + model=model, messages=[{ "role": "user", "content": prompt, }], - extra_body={ - "guided_regex": "\w+@\w+\.com\n", - "stop": ["\n"], - "guided_decoding_backend": "xgrammar:no-fallback" - }, + extra_body={"guided_grammar": simplified_sql_grammar}, ) -except BadRequestError as e: - print("This error is expected:", e) + return completion.choices[0].message.content + + +# Extra backend options +def extra_backend_options_completion(client: OpenAI, model: str): + prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. 
Example result:" + "alan.turing@enigma.com\n") + + try: + # The no-fallback option forces vLLM to use xgrammar, so when it fails + # you get a 400 with the reason why + completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": r"\w+@\w+\.com\n", + "stop": ["\n"], + "guided_decoding_backend": "xgrammar:no-fallback" + }, + ) + return completion.choices[0].message.content + except BadRequestError as e: + print("This error is expected:", e) + + +def main(): + client: OpenAI = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", + ) + + model = "Qwen/Qwen2.5-3B-Instruct" + + print("Guided Choice Completion:") + print(guided_choice_completion(client, model)) + + print("\nGuided Regex Completion:") + print(guided_regex_completion(client, model)) + + print("\nGuided JSON Completion:") + print(guided_json_completion(client, model)) + + print("\nGuided Grammar Completion:") + print(guided_grammar_completion(client, model)) + + print("\nExtra Backend Options Completion:") + print(extra_backend_options_completion(client, model)) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py new file mode 100644 index 0000000000000000000000000000000000000000..b807bc5405262790f35bea7f7c52acfa9b280bd2 --- /dev/null +++ b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import OpenAI + +# This example demonstrates the `structural_tag` response format. +# It can be used to specify a structured output format that occurs between +# specific tags in the response. 
This example shows how it could be used +# to enforce the format of a tool call response, but it could be used for +# any structured output within a subset of the response. + + +def main(): + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", + ) + + messages = [{ + "role": + "user", + "content": + """ +You have access to the following function to retrieve the weather in a city: + + { + "name": "get_weather", + "parameters": { + "city": { + "param_type": "string", + "description": "The city to get the weather for", + "required": True + } + } + } + +If a you choose to call a function ONLY reply in the following format: +<{start_tag}={function_name}>{parameters}{end_tag} +where + +start_tag => ` a JSON dict with the function argument name as key and function + argument value as value. +end_tag => `` + +Here is an example, +{"example_name": "example_value"} + +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line +- Always add your sources when using search results to answer the user query + +You are a helpful assistant. + +Given the previous instructions, what is the weather in New York City, Boston, +and San Francisco? 
+""" + }] + + response = client.chat.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=messages, + response_format={ + "type": + "structural_tag", + "structures": [{ + "begin": "", + "schema": { + "type": "object", + "properties": { + "city": { + "type": "string" + } + } + }, + "end": "" + }], + "triggers": [" requests.Response: return response -if __name__ == "__main__": +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3") + return parser.parse_args() + - args = parser.parse_args() +def main(args): api_url = f"http://{args.host}:{args.port}/score" model_name = args.model @@ -30,9 +32,9 @@ if __name__ == "__main__": text_2 = "The capital of Brazil is Brasilia." prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} score_response = post_http_request(prompt=prompt, api_url=api_url) - print("Prompt when text_1 and text_2 are both strings:") + print("\nPrompt when text_1 and text_2 are both strings:") pprint.pprint(prompt) - print("Score Response:") + print("\nScore Response:") pprint.pprint(score_response.json()) text_1 = "What is the capital of France?" 
@@ -41,9 +43,9 @@ if __name__ == "__main__": ] prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} score_response = post_http_request(prompt=prompt, api_url=api_url) - print("Prompt when text_1 is string and text_2 is a list:") + print("\nPrompt when text_1 is string and text_2 is a list:") pprint.pprint(prompt) - print("Score Response:") + print("\nScore Response:") pprint.pprint(score_response.json()) text_1 = [ @@ -54,7 +56,12 @@ if __name__ == "__main__": ] prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} score_response = post_http_request(prompt=prompt, api_url=api_url) - print("Prompt when text_1 and text_2 are both lists:") + print("\nPrompt when text_1 and text_2 are both lists:") pprint.pprint(prompt) - print("Score Response:") + print("\nScore Response:") pprint.pprint(score_response.json()) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_embedding_client.py b/examples/online_serving/openai_embedding_client.py index b7c5651e3bab28bef3864550836dc165481c4dca..bc217f7ca7a0ba4364a04fe6b8e5fbefb69f9df4 100644 --- a/examples/online_serving/openai_embedding_client.py +++ b/examples/online_serving/openai_embedding_client.py @@ -6,22 +6,29 @@ from openai import OpenAI openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - -responses = client.embeddings.create( - input=[ - "Hello my name is", - "The best thing about vLLM is that it supports many different models" - ], - model=model, -) - -for data in responses.data: - print(data.embedding) # List of float of len 4096 + +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + 
responses = client.embeddings.create( + # ruff: noqa: E501 + input=[ + "Hello my name is", + "The best thing about vLLM is that it supports many different models" + ], + model=model, + ) + + for data in responses.data: + print(data.embedding) # List of float of len 4096 + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_embedding_matryoshka_fy.py b/examples/online_serving/openai_embedding_matryoshka_fy.py new file mode 100644 index 0000000000000000000000000000000000000000..4544dcfb5ab09edbcf0a6ba416572f4c9c2142f3 --- /dev/null +++ b/examples/online_serving/openai_embedding_matryoshka_fy.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Example Python client for embedding API dimensions using vLLM API server +NOTE: + start a supported Matryoshka Embeddings model server with `vllm serve`, e.g. + vllm serve jinaai/jina-embeddings-v3 --trust-remote-code +""" + +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. 
+openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + + +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + responses = client.embeddings.create( + input=["Follow the white rabbit."], + model=model, + dimensions=32, + ) + + for data in responses.data: + print(data.embedding) # List of float of len 32 + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_pooling_client.py b/examples/online_serving/openai_pooling_client.py index e17f9c5efd65907ba0cf3c070daa0875d54f3059..abcfe27c276990949b1c3e559f8a8f16bd3213ab 100644 --- a/examples/online_serving/openai_pooling_client.py +++ b/examples/online_serving/openai_pooling_client.py @@ -17,7 +17,7 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response: return response -if __name__ == "__main__": +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) @@ -25,15 +25,20 @@ if __name__ == "__main__": type=str, default="jason9693/Qwen2.5-1.5B-apeach") - args = parser.parse_args() + return parser.parse_args() + + +def main(args): api_url = f"http://{args.host}:{args.port}/pooling" model_name = args.model # Input like Completions API prompt = {"model": model_name, "input": "vLLM is great!"} pooling_response = post_http_request(prompt=prompt, api_url=api_url) + print("-" * 50) print("Pooling Response:") pprint.pprint(pooling_response.json()) + print("-" * 50) # Input like Chat API prompt = { @@ -50,3 +55,9 @@ if __name__ == "__main__": pooling_response = post_http_request(prompt=prompt, api_url=api_url) print("Pooling Response:") pprint.pprint(pooling_response.json()) + print("-" * 50) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git 
a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py index 062868dd8adf07040e792e6fb0aa5ebeca3ff5a7..5fcb7c5264162e45705582c41b4e2b2f6042d771 100644 --- a/examples/online_serving/openai_transcription_client.py +++ b/examples/online_serving/openai_transcription_client.py @@ -26,7 +26,12 @@ def sync_openai(): model="openai/whisper-large-v3", language="en", response_format="json", - temperature=0.0) + temperature=0.0, + # Additional sampling params not provided by OpenAI API. + extra_body=dict( + seed=4419, + repetition_penalty=1.3, + )) print("transcription result:", transcription.text) diff --git a/examples/online_serving/ray_serve_deepseek.py b/examples/online_serving/ray_serve_deepseek.py new file mode 100644 index 0000000000000000000000000000000000000000..f9ef3e2da1a19fe3ffcf640d715f6b82d3e47e93 --- /dev/null +++ b/examples/online_serving/ray_serve_deepseek.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Example to deploy DeepSeek R1 or V3 with Ray Serve LLM. +See Ray Serve LLM documentation at: +https://docs.ray.io/en/latest/serve/llm/serving-llms.html + +Run `python3 ray_serve_deepseek.py` to deploy the model. +""" + +from ray import serve +from ray.serve.llm import LLMConfig, build_openai_app + +llm_config = LLMConfig( + model_loading_config={ + "model_id": "deepseek", + # Since DeepSeek model is huge, it is recommended to pre-download + # the model to local disk, say /path/to/the/model and specify: + # model_source="/path/to/the/model" + "model_source": "deepseek-ai/DeepSeek-R1", + }, + deployment_config={ + "autoscaling_config": { + "min_replicas": 1, + "max_replicas": 1, + } + }, + # Change to the accelerator type of the node + accelerator_type="H100", + runtime_env={"env_vars": { + "VLLM_USE_V1": "1" + }}, + # Customize engine arguments as needed (e.g. 
vLLM engine kwargs) + engine_kwargs={ + "tensor_parallel_size": 8, + "pipeline_parallel_size": 2, + "gpu_memory_utilization": 0.92, + "dtype": "auto", + "max_num_seqs": 40, + "max_model_len": 16384, + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + "trust_remote_code": True, + }, +) + +# Deploy the application +llm_app = build_openai_app({"llm_configs": [llm_config]}) +serve.run(llm_app) diff --git a/examples/tool_chat_template_llama4_json.jinja b/examples/tool_chat_template_llama4_json.jinja new file mode 100644 index 0000000000000000000000000000000000000000..759f16554436eefea10d9679443fbc30195702de --- /dev/null +++ b/examples/tool_chat_template_llama4_json.jinja @@ -0,0 +1,116 @@ +{%- macro is_array_of_type_objects(var) -%} + {%- if var is iterable and var is not string -%} + {%- set valid = true -%} + {%- for item in var -%} + {%- if 'type' not in item -%} + {%- set valid = false -%} + {%- break -%} + {%- endif -%} + {%- endfor -%} + {{ valid }} + {%- else -%} + {{ false }} + {%- endif -%} +{%- endmacro %} + +{%- macro render_message(message) %} + {%- if message['content'] is string %} + {{- message['content']|trim }} + {%- elif is_array_of_type_objects(data) == 'True' %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text']|trim }} + {%- endif %} + {%- endfor %} + {%- else %} + {{- message['content']|tojson }} + {%- endif %} +{%- endmacro %} + +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. 
#} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0] %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = ({ "content": "You are a helpful assistant with tool calling " + "capabilities. Only reply with a tool call if the function exists in the " + "library provided by the user. If it doesn't exist, just reply directly in " + "natural language. When you receive a tool call response, use the output to " + "format an answer to the original user question."}) %} +{%- endif %} + +{%- set tool_lib_preamble = 'Tools: You have access to the following tools. You might need to use one ' + 'or more function/tool calls to fulfill the task. \n' + 'If none are needed, then proceed to the response.\n\n' + 'Tool Call Syntax: You can call tools using the following syntax:\n' + '{"name": function name, "parameters": dictionary of argument name and its value}.\n' + 'Separate multiple function calls by "; ". Do not use variables.\n' + 'Do not include anything else when calling the tools with the syntax above.\n\n' + 'Here is a list of functions in JSON format that you can invoke.\n' %} + +{{- "<|header_start|>system<|header_end|>\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- tool_lib_preamble }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- render_message(system_message) }} +{{ "<|eot|>\n" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0] %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|header_start|>user<|header_end|>\n\n' }} + {{- tool_lib_preamble }} + {%- for t in tools %} + {{- t 
| tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- render_message(first_user_message) + "\n<|eot|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {{- render_message(message) }} + {{- "\n<|eot|>" }} + {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %} + {{- '\n<|header_start|>assistant<|header_end|>\n\n' -}} + {{- render_message(message) }} + {%- for tool_call in message.tool_calls %} + {{- '{"name": "' + tool_call.function.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.function.arguments | tojson }} + {{- "}" }} + {%- endfor %} + {{- "\n<|eot|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "\n<|header_start|>ipython<|header_end|>\n\n" }} + {{- render_message(message) }} + {{- "\n<|eom|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '\n<|header_start|>assistant<|header_end|>\n\n' }} +{%- endif %} diff --git a/pyproject.toml b/pyproject.toml index 167e975c70fdb6ddbc269f6d17305178f49585e1..b5f1039b44daccaf84389cb1d0ba766ac0eb9a71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ build-backend = "setuptools.build_meta" [project] name = "vllm" authors = [{name = "vLLM Team"}] -license = { "file"= "LICENSE" } +license = "Apache-2.0" +license-files = ["LICENSE"] readme = "README.md" description = "A high-throughput and memory-efficient inference and serving engine for LLMs" classifiers = [ @@ -23,7 +24,6 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "License :: OSI Approved :: Apache Software License", "Intended Audience :: Developers", "Intended Audience :: Information Technology", "Intended Audience :: Science/Research", @@ -46,8 +46,7 @@ vllm = "vllm.entrypoints.cli.main:main" 
[tool.setuptools.packages.find] where = ["."] -exclude = ["benchmarks", "csrc", "docs", "examples", "tests*"] -namespaces = false +include = ["vllm*"] [tool.yapfignore] ignore_patterns = [ @@ -59,7 +58,8 @@ ignore_patterns = [ line-length = 80 exclude = [ # External file, leaving license intact - "examples/other/fp8/quantizer/quantize.py" + "examples/other/fp8/quantizer/quantize.py", + "vllm/vllm_flash_attn/flash_attn_interface.pyi" ] [tool.ruff.lint.per-file-ignores] diff --git a/requirements/common.txt b/requirements/common.txt index ef24ab956bb4e7bfdcf87c9d5880c4c6c443b563..83e6a41b754bf02b43944f38c0d05987f1286c57 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -8,7 +8,7 @@ blake3 py-cpuinfo transformers >= 4.51.1 huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads. -tokenizers >= 0.19.1 # Required for Llama 3. +tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. 
aiohttp @@ -26,7 +26,7 @@ xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64 typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs -pyzmq +pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 importlib_metadata diff --git a/requirements/cpu.txt b/requirements/cpu.txt index d845fb201ceff7dced7ca0a896f7c2ede041b7ac..69f732c2417a1b4270296652b391e3a3c7c2f4ff 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -12,9 +12,9 @@ torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" torchaudio==2.6.0; platform_machine == "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch -torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" +torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" torchvision==0.21.0; platform_machine == "ppc64le" datasets # for benchmark scripts # cpu cannot use triton 3.3.0 -triton==3.2.0; platform_machine != "ppc64le" +triton==3.2.0; platform_machine == "x86_64" diff --git a/requirements/docs.txt b/requirements/docs.txt index 416ca503b36c0c7113e313d551869eceef3a1b8c..d84fd633ce108946a263d028a3bb7df76213b2d4 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -7,6 +7,7 @@ sphinx-togglebutton==0.3.2 myst-parser==3.0.1 msgspec cloudpickle +commonmark # Required by sphinx-argparse when using :markdownhelp: # packages to install to build the documentation cachetools @@ -18,6 +19,7 @@ transformers mistral_common >= 1.5.4 aiohttp starlette +scipy openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args diff --git a/requirements/hpu.txt 
b/requirements/hpu.txt index 830f6ef3f50cb9d8ba2b0b1a14f91fd72582f1f4..5ac58bc02892e7aa3c9ac006b7afb1b51440f52d 100644 --- a/requirements/hpu.txt +++ b/requirements/hpu.txt @@ -9,4 +9,4 @@ numpy==1.26.4 tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624 diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt new file mode 100644 index 0000000000000000000000000000000000000000..20372a9b2ef168128702ec70e6be05c52f18875f --- /dev/null +++ b/requirements/nightly_torch_test.txt @@ -0,0 +1,28 @@ +# Dependency that able to run entrypoints test +# pytest and its extensions +pytest +pytest-asyncio +pytest-forked +pytest-mock +pytest-rerunfailures +pytest-shard +pytest-timeout + + +librosa # required by audio tests in entrypoints/openai +sentence-transformers +numba == 0.61.2; python_version > '3.9' +# testing utils +awscli +boto3 +botocore +datasets +ray >= 2.10.0 +peft +runai-model-streamer==0.11.0 +runai-model-streamer-s3==0.11.0 +tensorizer>=2.9.0 +lm-eval==0.4.8 +buildkite-test-collector==0.1.9 + +lm-eval[api]==0.4.8 # required for model evaluation test diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index 29d5647807bb984fca42f6c3e3302a43a784ace1..05de4ff168453d055180605327561e266e685f18 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -6,6 +6,7 @@ torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 +triton==3.2 cmake>=3.26,<4 packaging setuptools>=61 diff --git a/requirements/test.in b/requirements/test.in index 95c94dcdbe999f7b496a734f581f3646d3904038..c5d2c4cd4c30f87243fb4cacebdde11529a96179 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -10,6 +10,7 @@ pytest-timeout # testing utils awscli backoff # required for phi4mm test +blobfile # required for kimi-vl test einops # required for MPT, qwen-vl and Mamba httpx 
librosa # required for audio tests @@ -26,14 +27,17 @@ torch==2.6.0 torchaudio==2.6.0 torchvision==0.21.0 transformers_stream_generator # required for qwen-vl test +mamba_ssm # required for plamo2 test matplotlib # required for qwen-vl test mistral_common[opencv] >= 1.5.4 # required for pixtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test -transformers==4.51.1 +transformers==4.51.3 +tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. +schemathesis>=3.39.15 # Required for openai schema test. # quantization bitsandbytes>=0.45.3 buildkite-test-collector==0.1.9 diff --git a/requirements/test.txt b/requirements/test.txt index 476b4a2cc0ec23aa4fa42f399715bb62828a92af..9642a5bfe68d421fdcd4932ebb0d790f15e9f56a 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -20,25 +20,35 @@ aiosignal==1.3.1 annotated-types==0.7.0 # via pydantic anyio==4.6.2.post1 - # via httpx + # via + # httpx + # starlette argcomplete==3.5.1 # via datamodel-code-generator +arrow==1.3.0 + # via isoduration attrs==24.2.0 # via # aiohttp + # hypothesis # jsonlines # jsonschema + # pytest-subtests # referencing audioread==3.0.1 # via librosa awscli==1.35.23 # via -r requirements/test.in backoff==2.2.1 - # via -r requirements/test.in + # via + # -r requirements/test.in + # schemathesis bitsandbytes==0.45.3 # via -r requirements/test.in black==24.10.0 # via datamodel-code-generator +blobfile==3.0.0 + # via -r requirements/test.in boto3==1.35.57 # via tensorizer botocore==1.35.57 @@ -67,11 +77,13 @@ click==8.1.7 # jiwer # nltk # ray + # schemathesis # typer colorama==0.4.6 # via # awscli # sacrebleu + # schemathesis # tqdm-multiprocess contourpy==1.3.0 # via matplotlib @@ -109,6 +121,7 @@ einops==0.8.0 # via # -r requirements/test.in # encodec + # mamba-ssm # vector-quantize-pytorch # vocos 
einx==0.3.0 @@ -127,6 +140,7 @@ fastsafetensors==0.1.10 # via -r requirements/test.in filelock==3.16.1 # via + # blobfile # datasets # huggingface-hub # ray @@ -134,6 +148,8 @@ filelock==3.16.1 # transformers fonttools==4.54.1 # via matplotlib +fqdn==1.5.1 + # via jsonschema frozendict==2.4.6 # via einx frozenlist==1.5.0 @@ -152,8 +168,12 @@ genai-perf==0.0.8 # via -r requirements/test.in genson==1.3.0 # via datamodel-code-generator +graphql-core==3.2.6 + # via hypothesis-graphql h11==0.14.0 # via httpcore +harfile==0.3.0 + # via schemathesis hf-xet==0.1.4 # via huggingface-hub hiredis==3.0.0 @@ -161,7 +181,9 @@ hiredis==3.0.0 httpcore==1.0.6 # via httpx httpx==0.27.2 - # via -r requirements/test.in + # via + # -r requirements/test.in + # schemathesis huggingface-hub==0.30.1 # via # -r requirements/test.in @@ -176,17 +198,29 @@ huggingface-hub==0.30.1 # vocos humanize==4.11.0 # via runai-model-streamer +hypothesis==6.131.0 + # via + # hypothesis-graphql + # hypothesis-jsonschema + # schemathesis +hypothesis-graphql==0.11.1 + # via schemathesis +hypothesis-jsonschema==0.23.1 + # via schemathesis idna==3.10 # via # anyio # email-validator # httpx + # jsonschema # requests # yarl inflect==5.6.2 # via datamodel-code-generator iniconfig==2.0.0 # via pytest +isoduration==20.11.0 + # via jsonschema isort==5.13.2 # via datamodel-code-generator jinja2==3.1.6 @@ -206,12 +240,18 @@ joblib==1.4.2 # scikit-learn jsonlines==4.0.0 # via lm-eval +jsonpointer==3.0.0 + # via jsonschema jsonschema==4.23.0 # via + # hypothesis-jsonschema # mistral-common # ray + # schemathesis jsonschema-specifications==2024.10.1 # via jsonschema +junit-xml==1.9 + # via schemathesis kaleido==0.2.1 # via genai-perf kiwisolver==1.4.7 @@ -227,11 +267,17 @@ llvmlite==0.44.0 lm-eval==0.4.8 # via -r requirements/test.in lxml==5.3.0 - # via sacrebleu + # via + # blobfile + # sacrebleu +mamba-ssm==2.2.4 + # via -r requirements/test.in markdown-it-py==3.0.0 # via rich markupsafe==3.0.2 - # via jinja2 + # via + 
# jinja2 + # werkzeug matplotlib==3.9.2 # via -r requirements/test.in mbstrdecoder==1.1.3 @@ -263,6 +309,8 @@ mypy-extensions==1.0.0 # via black networkx==3.2.1 # via torch +ninja==1.11.1.3 + # via mamba-ssm nltk==3.9.1 # via rouge-score num2words==0.5.14 @@ -355,6 +403,7 @@ packaging==24.1 # fastparquet # huggingface-hub # lazy-loader + # mamba-ssm # matplotlib # peft # plotly @@ -426,6 +475,8 @@ pybind11==2.13.6 # via lm-eval pycparser==2.22 # via cffi +pycryptodomex==3.22.0 + # via blobfile pydantic==2.9.2 # via # datamodel-code-generator @@ -436,6 +487,8 @@ pygments==2.18.0 # via rich pyparsing==3.2.0 # via matplotlib +pyrate-limiter==3.7.0 + # via schemathesis pytablewriter==1.2.0 # via lm-eval pytest==8.3.3 @@ -448,7 +501,9 @@ pytest==8.3.3 # pytest-mock # pytest-rerunfailures # pytest-shard + # pytest-subtests # pytest-timeout + # schemathesis pytest-asyncio==0.24.0 # via -r requirements/test.in pytest-forked==1.6.0 @@ -459,10 +514,13 @@ pytest-rerunfailures==14.0 # via -r requirements/test.in pytest-shard==0.1.2 # via -r requirements/test.in +pytest-subtests==0.14.1 + # via schemathesis pytest-timeout==2.3.1 # via -r requirements/test.in python-dateutil==2.9.0.post0 # via + # arrow # botocore # matplotlib # pandas @@ -484,6 +542,7 @@ pyyaml==6.0.2 # peft # ray # responses + # schemathesis # timm # transformers # vocos @@ -514,10 +573,16 @@ requests==2.32.3 # pooch # ray # responses + # schemathesis + # starlette-testclient # tiktoken # transformers responses==0.25.3 # via genai-perf +rfc3339-validator==0.1.4 + # via jsonschema +rfc3987==1.3.8 + # via jsonschema rich==13.9.4 # via # genai-perf @@ -546,6 +611,8 @@ safetensors==0.4.5 # peft # timm # transformers +schemathesis==3.39.15 + # via -r requirements/test.in scikit-learn==1.5.2 # via # librosa @@ -564,18 +631,23 @@ sentencepiece==0.2.0 # via mistral-common setuptools==75.8.0 # via + # mamba-ssm # pytablewriter # torch shellingham==1.5.4 # via typer six==1.16.0 # via + # junit-xml # python-dateutil + # 
rfc3339-validator # rouge-score sniffio==1.3.1 # via # anyio # httpx +sortedcontainers==2.4.0 + # via hypothesis soundfile==0.12.1 # via # -r requirements/test.in @@ -584,6 +656,12 @@ soxr==0.5.0.post1 # via librosa sqlitedict==2.1.0 # via lm-eval +starlette==0.46.2 + # via + # schemathesis + # starlette-testclient +starlette-testclient==0.4.1 + # via schemathesis statsmodels==0.14.4 # via genai-perf sympy==1.13.1 @@ -610,8 +688,14 @@ tiktoken==0.7.0 # mistral-common timm==1.0.11 # via -r requirements/test.in -tokenizers==0.21.0 - # via transformers +tokenizers==0.21.1 + # via + # -r requirements/test.in + # transformers +tomli==2.2.1 + # via schemathesis +tomli-w==1.2.0 + # via schemathesis torch==2.6.0 # via # -r requirements/test.in @@ -620,6 +704,7 @@ torch==2.6.0 # encodec # fastsafetensors # lm-eval + # mamba-ssm # peft # runai-model-streamer # sentence-transformers @@ -652,11 +737,12 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.51.1 +transformers==4.51.3 # via # -r requirements/test.in # genai-perf # lm-eval + # mamba-ssm # peft # sentence-transformers # transformers-stream-generator @@ -675,6 +761,8 @@ typepy==1.3.2 # tabledata typer==0.15.2 # via fastsafetensors +types-python-dateutil==2.9.0.20241206 + # via arrow typing-extensions==4.12.2 # via # huggingface-hub @@ -687,8 +775,11 @@ typing-extensions==4.12.2 # typer tzdata==2024.2 # via pandas +uri-template==1.3.0 + # via jsonschema urllib3==2.2.3 # via + # blobfile # botocore # requests # responses @@ -697,6 +788,10 @@ vector-quantize-pytorch==1.21.2 # via -r requirements/test.in vocos==0.1.0 # via -r requirements/test.in +webcolors==24.11.1 + # via jsonschema +werkzeug==3.1.3 + # via schemathesis word2number==1.1 # via lm-eval xxhash==3.5.0 @@ -704,6 +799,8 @@ xxhash==3.5.0 # datasets # evaluate yarl==1.17.1 - # via aiohttp + # via + # aiohttp + # schemathesis zstandard==0.23.0 # via lm-eval diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 
75ebbc4ed94036f8a31c1a20ef0c5d5c91c4fa9d..b63993ba1ee453776333bfde58328aba53722293 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -17,9 +17,8 @@ ray[data] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.8.0.dev20250408 +torchvision==0.22.0.dev20250408 torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/setup.py b/setup.py index c5f2e55ab250097150d2ff0a923174599b4578ff..d39068590b439db6d37141197b50b118fda1dc21 100644 --- a/setup.py +++ b/setup.py @@ -276,15 +276,17 @@ class cmake_build_ext(build_ext): # First, run the standard build_ext command to compile the extensions super().run() - # copy vllm/vllm_flash_attn/*.py from self.build_lib to current + # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current # directory so that they can be included in the editable build import glob - files = glob.glob( - 
os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py")) + files = glob.glob(os.path.join(self.build_lib, "vllm", + "vllm_flash_attn", "**", "*.py"), + recursive=True) for file in files: dst_file = os.path.join("vllm/vllm_flash_attn", - os.path.basename(file)) + file.split("vllm/vllm_flash_attn/")[-1]) print(f"Copying {file} to {dst_file}") + os.makedirs(os.path.dirname(dst_file), exist_ok=True) self.copy_file(file, dst_file) @@ -384,13 +386,22 @@ class repackage_wheel(build_ext): "vllm/_flashmla_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", - "vllm/vllm_flash_attn/flash_attn_interface.py", - "vllm/vllm_flash_attn/__init__.py", "vllm/cumem_allocator.abi3.so", # "vllm/_version.py", # not available in nightly wheels yet ] - file_members = filter(lambda x: x.filename in files_to_copy, - wheel.filelist) + + file_members = list( + filter(lambda x: x.filename in files_to_copy, wheel.filelist)) + + # vllm_flash_attn python code: + # Regex from + # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)` + import re + compiled_regex = re.compile( + r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") + file_members += list( + filter(lambda x: compiled_regex.match(x.filename), + wheel.filelist)) for file in file_members: print(f"Extracting and including {file.filename} " @@ -563,9 +574,9 @@ def get_version_add(sha: Optional[str] = None) -> str: new_version_content = f""" try: - __version__ = "0.8.4" - __version_tuple__ = (0, 8, 4) - __hcu_version__ = f'0.8.4+{version}' + __version__ = "0.8.5" + __version_tuple__ = (0, 8, 5) + __hcu_version__ = f'0.8.5+{version}' from vllm.version import __version__, __version_tuple__, __hcu_version__ except Exception as e: diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/benchmarks/test_latency_cli.py 
b/tests/benchmarks/test_latency_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..8537459b9f94dea5417ae2e84e14b5b967c70765 --- /dev/null +++ b/tests/benchmarks/test_latency_cli.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess + +import pytest + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + + +@pytest.mark.benchmark +def test_bench_latency(): + command = [ + "vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32", + "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..b746d6b7853c9dbcff8cbb29886b291bcbeeb0c4 --- /dev/null +++ b/tests/benchmarks/test_serve_cli.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess + +import pytest + +from ..utils import RemoteOpenAIServer + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.benchmark +def test_bench_serve(server): + command = [ + "vllm", + "bench", + "serve", + "--model", + MODEL_NAME, + "--host", + server.host, + "--port", + str(server.port), + "--random-input-len", + "32", + "--random-output-len", + "4", + "--num-prompts", + "5", + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/benchmarks/test_throughput_cli.py b/tests/benchmarks/test_throughput_cli.py new file mode 100644 
index 0000000000000000000000000000000000000000..2045b362935658d381c5f65ab1bfc27b738108af --- /dev/null +++ b/tests/benchmarks/test_throughput_cli.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess + +import pytest + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + + +@pytest.mark.benchmark +def test_bench_throughput(): + command = [ + "vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len", + "32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 579133ec0c3f68e9b67c490363c7411d9c5ed67e..c094063859876ff715247de8b344021f8db94142 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -20,15 +20,11 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): ("facebook/opt-125m", {}), ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { "dtype": torch.float16, - "quantization": "compressed-tensors" }), ("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", { "dtype": torch.float16, - "quantization": "compressed-tensors" - }), - ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", { - "quantization": "compressed-tensors" }), + ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index cb41987903173a544803151dfadae316f0d658ef..30eb0288ab1db6ce11755fb273c2663bda133801 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -11,7 +11,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from vllm.compilation.fx_utils import find_auto_fn, 
find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, VllmConfig from .backend import TestBackend from ..utils import models_path_prefix @@ -51,13 +51,15 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, do_fusion: bool): torch.set_default_device("cuda") - config = CompilationConfig.PassConfig(enable_fusion=do_fusion, - enable_noop=True) - noop_pass = NoOpEliminationPass(config) - fusion_pass = FusionPass.instance(config) + vllm_config = VllmConfig() + vllm_config.compilation_config = CompilationConfig(pass_config= \ + CompilationConfig.PassConfig(enable_fusion=do_fusion, + enable_noop=True)) + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = FusionPass.instance(vllm_config) passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass] - func_pass = FixFunctionalizationPass(config) + func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index efebf05b6b04779ef7689f90660e4c47621cba63..6a696fe0226b1a4a8dd61797aa5842154226ecda 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -77,12 +77,13 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) + vllm_config.compilation_config.pass_config = \ + CompilationConfig.PassConfig(enable_fusion=True, + enable_noop=True) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work - config = CompilationConfig.PassConfig(enable_fusion=True, - enable_noop=True) - noop_pass = NoOpEliminationPass(config) - fusion_pass = FusionPass.instance(config) + noop_pass = NoOpEliminationPass(vllm_config) + 
fusion_pass = FusionPass.instance(vllm_config) backend = TestBackend(noop_pass, fusion_pass) model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 2c1ee4dc74806c8fe910ba8189f34f73a62109f3..673ebe8b6fdc0c800f4b2d832bcd07db94f97061 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -6,7 +6,7 @@ import torch from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.pass_manager import PostGradPassManager -from vllm.config import CompilationConfig +from vllm.config import VllmConfig # dummy custom pass that doesn't inherit @@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph): # Should fail to add directly to the pass manager def test_bad_callable(): - config = CompilationConfig().pass_config + config = VllmConfig() pass_manager = PostGradPassManager() pass_manager.configure(config) @@ -43,7 +43,7 @@ class ProperPass(InductorPass): ], ) def test_pass_manager_uuid(callable): - config = CompilationConfig().pass_config + config = VllmConfig() pass_manager = PostGradPassManager() pass_manager.configure(config) @@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable): # UUID should be different due to config change config2 = copy.deepcopy(config) - config2.enable_fusion = not config2.enable_fusion + config2.compilation_config.pass_config.enable_fusion = not \ + config2.compilation_config.pass_config.enable_fusion pass_manager3 = PostGradPassManager() pass_manager3.configure(config2) pass_manager3.add(callable) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py new file mode 100644 index 0000000000000000000000000000000000000000..79f5486dadcdd6d06c683b6dd2913541a7dad227 --- /dev/null +++ b/tests/compile/test_sequence_parallelism.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import vllm.envs as 
envs +from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe, + find_specified_fn, + find_specified_fn_maybe, is_func) +from vllm.compilation.sequence_parallelism import SequenceParallelismPass +from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, + VllmConfig) +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +from ..utils import multi_gpu_test +from .backend import TestBackend + +OPS_IN_MODEL_BEFORE = [ + torch.ops.vllm.all_reduce.default, +] + +OPS_IN_MODEL_AFTER = [ + torch.ops.vllm.reduce_scatter.default, + torch.ops.vllm.all_gather.default, +] + +OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default] + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +class TestModel(torch.nn.Module): + + def __init__(self, hidden_size=16, intermediate_size=32): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = torch.nn.Parameter( + torch.empty((intermediate_size, hidden_size))) + self.norm = RMSNorm(hidden_size, 1e-05) + # Initialize weights + torch.nn.init.normal_(self.gate_proj, std=0.02) + + def forward(self, hidden_states, residual): + """ + Forward pass implementing the operations in the FX graph + + Args: + hidden_states: Input tensor + residual: Residual tensor from previous layer + + Returns: + Tuple containing the output tensor + """ + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + #matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + + # Tensor parallel 
all-reduce + all_reduce = tensor_model_parallel_all_reduce(mm) + + # layer normalization + norm_output, residual_output = self.norm(all_reduce, residual) + + return norm_output, residual_output + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [16]) +@pytest.mark.parametrize("hidden_size", [16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_sequence_parallelism_pass(batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype): + num_processes = 2 + + def run_torch_spawn(fn, nprocs): + # need to use torch.mp.spawn otherwise will have problems with + # torch.distributed and cuda + torch.multiprocessing.spawn(fn, + args=(num_processes, batch_size, seq_len, + hidden_size, dtype), + nprocs=nprocs) + + run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) + + +def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, + batch_size: int, seq_len: int, + hidden_size: int, + dtype: torch.dtype): + current_platform.seed_everything(0) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # configure vllm config for SequenceParallelismPass + vllm_config = VllmConfig() + vllm_config.compilation_config = CompilationConfig( + pass_config=CompilationConfig.PassConfig( + enable_sequence_parallelism=True, ), ) + vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + + # this is a fake model name to construct the model config + 
# in the vllm_config, it's not really used. + model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=42) + + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) + backend_no_func = TestBackend(sequence_parallelism_pass) + func_pass = FixFunctionalizationPass(vllm_config) + backend_func = TestBackend(sequence_parallelism_pass, func_pass) + + model = TestModel(hidden_size, hidden_size * 2) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), + dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + + compiled_model_no_func = torch.compile(model, backend=backend_no_func) + compiled_model_no_func(hidden_states, residual) + compiled_model_func = torch.compile(model, backend=backend_func) + compiled_model_func(hidden_states, residual) + + # Check substitution worked + pre_nodes = backend_no_func.graph_pre_pass.nodes + post_nodes = backend_no_func.graph_post_pass.nodes + + # In pre-nodes, all reduce should be there, + # reduce scatter and all gather should not + for op in OPS_IN_MODEL_BEFORE: + find_specified_fn(pre_nodes, op) + for op in OPS_IN_MODEL_AFTER: + assert find_specified_fn_maybe(pre_nodes, op) is None + + # In post-nodes, reduce scatter and all gather should be there, + # all reduce should not + for op in OPS_IN_MODEL_AFTER: + find_specified_fn(post_nodes, op) + for op in OPS_IN_MODEL_BEFORE: + assert find_specified_fn_maybe(post_nodes, op) is None + + # check if the functionalization pass is applied + for op in OPS_IN_MODEL: + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, + op) is None # noqa: E501 + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in OPS_IN_MODEL: + if is_func(node, op): + 
found[op] = True + assert all(found[op] for op in OPS_IN_MODEL) diff --git a/tests/conftest.py b/tests/conftest.py index 26ed163a6c595b32172d6e9f83b37921115ff6d9..2509b51261ee9580f3198642ef215dc8a239ab25 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,23 +24,24 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass from tests.models.utils import (TokensTextLogprobs, TokensTextLogprobsPromptLogprobs) from vllm import LLM, SamplingParams +from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype +from vllm.config import TaskOption, _get_and_verify_dtype from vllm.connections import global_http_connection from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - TokensPrompt, to_enc_dec_tuple_list, - zip_enc_dec_prompts) + to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams -from vllm.utils import cuda_device_count_stateless, is_list_of +from vllm.utils import cuda_device_count_stateless from .utils import models_path_prefix + logger = init_logger(__name__) _TEST_DIR = os.path.dirname(__file__) @@ -109,10 +110,25 @@ class _VideoAssets(_VideoAssetsBase): return [prompts["sample_demo_1"]] +class _AudioAssetsBase(UserList[AudioAsset]): + pass + + +class _AudioAssets(_AudioAssetsBase): + + def __init__(self) -> None: + super().__init__([ + AudioAsset("mary_had_lamb"), + AudioAsset("winning_call"), + ]) + + IMAGE_ASSETS = _ImageAssets() """Singleton instance of :class:`_ImageAssets`.""" VIDEO_ASSETS = _VideoAssets() """Singleton instance of :class:`_VideoAssets`.""" +AUDIO_ASSETS = _AudioAssets() +"""Singleton instance of :class:`_AudioAssets`.""" 
@pytest.fixture(scope="function", autouse=True) @@ -269,6 +285,11 @@ def video_assets() -> _VideoAssets: return VIDEO_ASSETS +@pytest.fixture(scope="session") +def audio_assets() -> _AudioAssets: + return AUDIO_ASSETS + + _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict) _R = TypeVar("_R") @@ -396,10 +417,15 @@ class HfRunner: processor_kwargs["images"] = image if videos is not None and (video := videos[i]) is not None: processor_kwargs["videos"] = video - if audios is not None and (audio_tuple := audios[i]) is not None: - audio, sr = audio_tuple - processor_kwargs["audio"] = audio - processor_kwargs["sampling_rate"] = sr + if audios is not None and (audio_inputs := audios[i]) is not None: + # HACK - not all processors take sampling_rate; we should + # clean this up in the future. + if len(audio_inputs) == 2: + audio, sr = audio_inputs + processor_kwargs["audio"] = audio + processor_kwargs["sampling_rate"] = sr + else: + processor_kwargs["audio"] = audio_inputs inputs = self.processor(**processor_kwargs) if isinstance(inputs, BatchFeature): @@ -474,12 +500,19 @@ class HfRunner: prompts: list[str], beam_width: int, max_tokens: int, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, ) -> list[tuple[list[list[int]], list[str]]]: outputs = self.generate(prompts, do_sample=False, max_new_tokens=max_tokens, num_beams=beam_width, - num_return_sequences=beam_width) + num_return_sequences=beam_width, + images=images, + videos=videos, + audios=audios) + for i in range(len(outputs)): output_ids, output_str = outputs[i] for j in range(len(output_ids)): @@ -530,7 +563,10 @@ class HfRunner: for _, hidden_state in enumerate(hidden_states): last_hidden_states = hidden_state[-1][0] logits = torch.matmul( - last_hidden_states.to(output_embeddings.weight.device), + last_hidden_states.to( + device=output_embeddings.weight.device, + dtype=output_embeddings.weight.dtype, + ), 
output_embeddings.weight.t(), ) if getattr(output_embeddings, "bias", None) is not None: @@ -924,6 +960,7 @@ class VllmRunner: max_tokens: int, num_logprobs: int, num_prompt_logprobs: Optional[int] = None, + skip_special_tokens: bool = True, ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( @@ -931,6 +968,7 @@ class VllmRunner: max_tokens=max_tokens, logprobs=num_logprobs, prompt_logprobs=(num_prompt_logprobs), + skip_special_tokens=skip_special_tokens, ) ''' Greedy logprobs generation for vLLM encoder/decoder models @@ -941,18 +979,20 @@ class VllmRunner: def generate_beam_search( self, - prompts: Union[list[str], list[list[int]]], + prompts: list[str], beam_width: int, max_tokens: int, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, ) -> list[tuple[list[list[int]], list[str]]]: - if is_list_of(prompts, str, check="all"): - prompts = [TextPrompt(prompt=prompt) for prompt in prompts] - else: - prompts = [ - TokensPrompt(prompt_token_ids=tokens) for tokens in prompts - ] + inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + outputs = self.model.beam_search( - prompts, + inputs, BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) returned_outputs = [] for output in outputs: @@ -1005,20 +1045,6 @@ def vllm_runner(): return VllmRunner -def get_tokenizer_pool_config(tokenizer_group_type): - if tokenizer_group_type is None: - return None - if tokenizer_group_type == "ray": - return TokenizerPoolConfig(pool_size=1, - pool_type="ray", - extra_config={}) - if isinstance(tokenizer_group_type, type): - return TokenizerPoolConfig(pool_size=1, - pool_type=tokenizer_group_type, - extra_config={}) - raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") - - @pytest.fixture() def temporary_enable_log_propagate(): import logging diff --git 
a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index d44bc5617bc7e69a95e6116f8b63cdfacc9e929c..2a3d5a16030536e0fc11d691d499db3f7e127cba 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -197,15 +197,15 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, ]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{ - "block_size": 8, + "block_size": 16, "max_num_batched_tokens": 2, "max_num_seqs": 2, }, { - "block_size": 8, + "block_size": 16, "max_num_batched_tokens": 3, "max_num_seqs": 2, }, { - "block_size": 8, + "block_size": 16, "max_num_batched_tokens": 256, "max_num_seqs": 10, }]) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index ac6d6aae300632b5536d9ddd836289e521569f4e..8f4c3537e1586d098c8c0d872d2e09d16de0a3a5 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -14,7 +14,8 @@ import torch from vllm.distributed import (broadcast_tensor_dict, get_pp_group, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) from ..utils import init_test_distributed_environment, multi_process_parallel @@ -47,6 +48,34 @@ def all_reduce_test_worker( torch.testing.assert_close(t, expected) +@ray.remote(num_gpus=1, max_calls=1) +def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int, + pp_size: int, rank: int, + distributed_init_port: str): + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU + monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + num_elements = 8 + all_tensors = [ + 
torch.arange(num_elements, dtype=torch.float32, device="cuda") * + (r + 1) for r in range(tp_size) + ] + + index = rank % tp_size + partition_size = num_elements // tp_size + all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0) + expected = all_reduce[index * partition_size:(index + 1) * partition_size] + t = all_tensors[index] + t = tensor_model_parallel_reduce_scatter(t, 0) + torch.testing.assert_close(t, expected) + + @ray.remote(num_gpus=1, max_calls=1) def all_gather_test_worker( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index efd66c255f8de790a0c5c08814880b5a8a287193..395213a3de79abca38e9520bfbf540b33daec4c4 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -161,12 +161,12 @@ TEXT_GENERATION_MODELS = { os.path.join(models_path_prefix, "deepseek-ai/DeepSeek-V2-Lite-Chat"): PPTestSettings.fast(), os.path.join(models_path_prefix, "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"): PPTestSettings.fast(), os.path.join(models_path_prefix, "tiiuae/falcon-7b"): PPTestSettings.fast(), - os.path.join(models_path_prefix, "google/gemma-2b"): PPTestSettings.fast(), + os.path.join(models_path_prefix, "google/gemma-1.1-2b-it"): PPTestSettings.fast(), os.path.join(models_path_prefix, "google/gemma-2-9b"): PPTestSettings.fast(), os.path.join(models_path_prefix, "gpt2"): PPTestSettings.fast(), os.path.join(models_path_prefix, "bigcode/starcoder"): PPTestSettings.fast(), os.path.join(models_path_prefix, "EleutherAI/gpt-j-6b"): PPTestSettings.fast(), - os.path.join(models_path_prefix, "EleutherAI/pythia-12b"): PPTestSettings.fast(), + os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b"): PPTestSettings.fast(), os.path.join(models_path_prefix, "ibm/PowerLM-3b"): PPTestSettings.fast(), os.path.join(models_path_prefix, "ibm/PowerMoE-3b"): PPTestSettings.fast(), # Uses Llama @@ -195,7 +195,7 @@ TEXT_GENERATION_MODELS = { 
os.path.join(models_path_prefix, "microsoft/Phi-3-small-8k-instruct"): PPTestSettings.fast(), os.path.join(models_path_prefix, "microsoft/Phi-3.5-MoE-instruct"): PPTestSettings.detailed(multi_node_only=True, load_format="dummy"), # noqa: E501 os.path.join(models_path_prefix, "Qwen/Qwen-7B-Chat"): PPTestSettings.fast(), - os.path.join(models_path_prefix, "Qwen/Qwen2-7B-Instruct"): PPTestSettings.fast(), + os.path.join(models_path_prefix, "Qwen/Qwen2.5-0.5B-Instruct"): PPTestSettings.fast(), os.path.join(models_path_prefix, "Qwen/Qwen1.5-MoE-A2.7B-Chat"): PPTestSettings.fast(), os.path.join(models_path_prefix, "stabilityai/stablelm-3b-4e1t"): PPTestSettings.fast(), os.path.join(models_path_prefix, "bigcode/starcoder2-3b"): PPTestSettings.fast(), diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..19497ad9c14090aec82d499fafd4fbbd317fa17e --- /dev/null +++ b/tests/distributed/test_sequence_parallel.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +WARNING: This test runs in both single-node (4 GPUs) and multi-node + (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is + important to set the distributed backend to "mp" to avoid Ray scheduling + all workers in a node other than the head node, which can cause the test + to fail. 
class ParallelSetup(NamedTuple):
    """One parallel configuration exercised by the sequence-parallel tests."""

    # Passed as --tensor-parallel-size.
    tp_size: int
    # Value placed in the compilation pass_config for sequence parallelism.
    sp_enabled: bool
    # When true, --enforce-eager is added to the server arguments.
    eager_mode: bool
    # When true, --enable-chunked-prefill is added to the server arguments.
    chunked_prefill: bool


class SPTestOptions(NamedTuple):
    """Per-model switches that are independent of the parallel setup."""

    # Skip the test unless VLLM_MULTI_NODE is enabled.
    multi_node_only: bool
    # Forwarded as --load-format when set; "dummy" also shrinks the model.
    load_format: Optional[str] = None
def _compare_sp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: SPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Run one model with and without the sequence-parallel compile pass.

    Two argument/environment sets are built and handed to
    ``compare_two_settings``:

    * TP + SP: tensor parallelism on ``distributed_backend`` plus a
      ``--compilation_config`` whose ``pass_config`` carries ``sp_enabled``.
    * TP only: the same common arguments on the "mp" backend, without a
      compilation config.

    Args:
        model_id: Model name looked up in ``HF_EXAMPLE_MODELS``.
        parallel_setup: TP size and SP / eager / chunked-prefill switches.
        distributed_backend: Executor backend for the TP + SP run.
        vllm_major_version: Value exported as ``VLLM_USE_V1`` for both runs.
        task: Forwarded as ``--task`` unless it is "auto".
        test_options: Multi-node restriction and optional load format.
        num_gpus_available: GPU count; the test is skipped if it is too low.
        method: Comparison method forwarded to ``compare_two_settings``.
        is_multimodal: Whether dummy-weight overrides go under
            ``text_config`` instead of the top level.
    """
    tp_size, sp_enabled, eager_mode, chunked_prefill = parallel_setup
    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    hf_overrides = model_info.hf_overrides

    if load_format == "dummy":
        # Avoid OOM
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        # Build a fresh dict rather than update()-ing in place, so the
        # overrides stored on the shared registry entry are not modified.
        if is_multimodal:
            hf_overrides = {**hf_overrides, "text_config": text_overrides}
        else:
            hf_overrides = {**hf_overrides, **text_overrides}
    else:
        model_info.check_available_online(on_fail="skip")

    # Sequence parallelism is only combined with TP here; no pipeline stages.
    pp_size = 1
    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node sequence parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    compilation_config = {
        'level': 3,
        'custom_ops': ["+rms_norm"],
        'compile_sizes': [4, 8],
        'splitting_ops': [],
        'pass_config': {
            # NOTE(review): the key is spelled "parallism"; confirm it matches
            # the pass-config field name before renaming it.
            'enable_sequence_parallism': sp_enabled,
            'enable_noop': True,
            'enable_fusion': True,
        },
    }

    # Both runs use the same engine version but get independent env dicts.
    tp_sp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }
    tp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }

    tp_sp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        str(compilation_config),
    ]
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_id,
                             tp_sp_args,
                             tp_args,
                             tp_sp_env,
                             tp_env,
                             method=method)
    except Exception:
        # Failures on the V0 engine are logged instead of failing the test.
        # Every configuration generated in this file currently uses "1", so
        # this branch only matters if V0 entries are added.
        if vllm_major_version != "0":
            raise
        logger.exception("Sequence parallel comparison failed on V0")
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
    (int, int, True),
    (int, float, False),
    (list[int], list, True),
    (list[int], tuple, False),
    (Literal[0, 1], Literal, True),
])
def test_is_type(type_hint, type, expected):
    """``is_type`` reports whether a hint is, or is parameterized from, a type."""
    outcome = is_type(type_hint, type)
    assert outcome == expected
def test_get_kwargs():
    """``get_kwargs`` derives argparse keyword arguments from config fields."""
    kwargs = get_kwargs(DummyConfigClass)
    print(kwargs)

    # bools should not have their type set
    for bool_field in ("regular_bool", "optional_bool"):
        assert kwargs[bool_field].get("type") is None

    # optional literals should have None as a choice
    literal_kwargs = kwargs["optional_literal"]
    assert literal_kwargs["choices"] == ["x", "y", "None"]

    # tuples should have the correct nargs
    assert kwargs["tuple_n"]["nargs"] == "+"
    assert kwargs["tuple_2"]["nargs"] == 2

    # lists should work
    list_kwargs = kwargs["list_n"]
    assert list_kwargs["type"] is int
    assert list_kwargs["nargs"] == "+"
def test_llm_chat_tokenization_no_double_bos():
    """
    LLM.chat() should not add special tokens when using chat templates.
    Check we get a single BOS token for llama chat.
    """
    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)

    system_turn = {"role": "system", "content": "You are a helpful assistant"}
    user_turn = {"role": "user", "content": "Hello!"}
    outputs = llm.chat([system_turn, user_turn])
    assert len(outputs) == 1

    token_ids = getattr(outputs[0], "prompt_token_ids", None)
    assert token_ids is not None

    bos_id = llm.get_tokenizer().bos_token_id

    # Exactly one BOS: the first token is BOS and the second is not.
    first_id, second_id = token_ids[0], token_ids[1]
    assert first_id == bos_id
    assert second_id != bos_id, "Double BOS"
@pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties(llm):
    """The guidance ``no-additional-properties`` option suppresses extra keys.

    The prompt asks for keys a1..a20 while the schema only declares a1..a3.
    Without the option the extra keys appear; with it they must not.
    """
    declared_keys = ['a1', 'a2', 'a3']
    extra_keys = ['a4', 'a5', 'a6']
    schema = {
        'type': 'object',
        'properties': {
            key: {
                'type': 'string'
            }
            for key in declared_keys
        },
        'required': ['a1', 'a2', 'a3'],
    }

    prompt = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
        "<|im_end|>\n<|im_start|>assistant\n")

    def generate_with_backend(backend):
        # Generate once with the given backend string and return parsed JSON.
        params = SamplingParams(
            temperature=0,
            max_tokens=256,
            guided_decoding=GuidedDecodingParams(json=schema,
                                                 backend=backend),
        )
        outputs = llm.generate(prompts=prompt, sampling_params=params)
        assert outputs is not None
        text = outputs[0].outputs[0].text
        assert text is not None
        document = json.loads(text)
        assert isinstance(document, dict)
        jsonschema.validate(instance=document, schema=schema)
        return document

    base_generated = generate_with_backend('guidance:disable-any-whitespace')
    for key in declared_keys:
        assert key in base_generated
    # by default additional keys are generated
    for key in extra_keys:
        assert key in base_generated

    generated = generate_with_backend(
        'guidance:no-additional-properties,disable-any-whitespace')
    for key in declared_keys:
        assert key in generated
    for key in extra_keys:
        assert key not in generated
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI,
                                               model_name: str,
                                               audio_url: str):
    """A bare-string ``audio_url`` must be rejected with a 400 error."""
    # audio_url should be a dict {"url": "some url"}, not directly a string
    malformed_audio_part = {"type": "audio_url", "audio_url": audio_url}
    text_part = {"type": "text", "text": "What's happening in this audio?"}
    messages = [{
        "role": "user",
        "content": [malformed_audio_part, text_part],
    }]

    with pytest.raises(openai.BadRequestError):
        _ = await client.chat.completions.create(model=model_name,
                                                 messages=messages,
                                                 max_completion_tokens=10,
                                                 temperature=0.0)
prepared a delicious meal.", ] @@ -67,6 +77,9 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): assert embeddings.usage.prompt_tokens == 11 assert embeddings.usage.total_tokens == 11 + vllm_outputs = [d.embedding for d in embeddings.data] + correctness_test(hf_model, input_texts, vllm_outputs) + # test using token IDs input_tokens = [1, 1, 1, 1, 1] embedding_response = await client.embeddings.create( @@ -87,7 +100,8 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): +async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, + model_name: str): # test list[str] input_texts = [ "The cat sat on the mat.", "A feline was resting on a rug.", @@ -108,6 +122,9 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): assert embeddings.usage.prompt_tokens == 33 assert embeddings.usage.total_tokens == 33 + vllm_outputs = [d.embedding for d in embeddings.data] + correctness_test(hf_model, input_texts, vllm_outputs) + # test list[list[int]] input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], [25, 32, 64, 77]] @@ -182,7 +199,7 @@ async def test_conversation_embedding(server: RemoteOpenAIServer, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_base64_embedding(client: openai.AsyncOpenAI, +async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): input_texts = [ "Hello my name is", @@ -193,6 +210,7 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI, model=model_name, encoding_format="float") float_data = [d.embedding for d in responses_float.data] + correctness_test(hf_model, input_texts, float_data) responses_base64 = await client.embeddings.create(input=input_texts, model=model_name, @@ -203,24 +221,13 @@ async def 
test_batch_base64_embedding(client: openai.AsyncOpenAI, np.frombuffer(base64.b64decode(data.embedding), dtype="float32").tolist()) - check_embeddings_close( - embeddings_0_lst=float_data, - embeddings_1_lst=base64_data, - name_0="float", - name_1="base64", - ) + correctness_test(hf_model, input_texts, base64_data) # Default response is float32 decoded from base64 by OpenAI Client responses_default = await client.embeddings.create(input=input_texts, model=model_name) default_data = [d.embedding for d in responses_default.data] - - check_embeddings_close( - embeddings_0_lst=float_data, - embeddings_1_lst=default_data, - name_0="float", - name_1="default", - ) + correctness_test(hf_model, input_texts, default_data) @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_embedding_dimensions.py b/tests/entrypoints/openai/test_embedding_dimensions.py index 79d43a2231f8254c1c4964853ae9e22a25857644..9f5a8c6839bc550eae827149342def6bb3ea74e1 100644 --- a/tests/entrypoints/openai/test_embedding_dimensions.py +++ b/tests/entrypoints/openai/test_embedding_dimensions.py @@ -3,80 +3,121 @@ Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`. 
""" -from typing import NamedTuple +from typing import Optional import openai import pytest from vllm.entrypoints.openai.protocol import EmbeddingResponse +from ...conftest import HfRunner +from ...models.embedding.utils import EmbedModelInfo, correctness_test from ...utils import RemoteOpenAIServer - -class ModelInfo(NamedTuple): - name: str - is_matryoshka: bool - - MODELS = [ - ModelInfo(name="BAAI/bge-m3", is_matryoshka=False), - ModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True), + EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256]), ] input_texts = [ "The chef prepared a delicious meal.", -] * 3 +] -@pytest.mark.asyncio -@pytest.mark.parametrize("model", MODELS) -async def test_validating_dimensions(model: ModelInfo): +@pytest.fixture(scope="module", params=MODELS) +def model_info(request): + return request.param + + +@pytest.fixture(scope="module", params=["bfloat16"]) +def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def server(model_info, dtype: str): args = [ "--task", "embed", # use half precision for speed and memory savings in CI environment "--dtype", - "bfloat16", + dtype, "--enforce-eager", "--max-model-len", - "512", - "--trust_remote_code" + "512" ] - with RemoteOpenAIServer(model.name, args) as remote_server: - client = remote_server.get_async_client() - - async def make_request(dimensions): - embedding_response = await client.embeddings.create( - model=model.name, - input=input_texts, - dimensions=dimensions, - encoding_format="float", - ) - embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) - - assert embeddings.id is not None - assert len(embeddings.data) == 3 - assert len(embeddings.data[0].embedding) > 0 - assert embeddings.usage.completion_tokens == 0 - assert embeddings.usage.prompt_tokens > 0 - assert 
@pytest.mark.asyncio
async def test_matryoshka(model_info: EmbedModelInfo,
                          server: RemoteOpenAIServer, hf_model: HfRunner):
    """Exercise the ``dimensions`` request parameter of the embeddings API.

    Matryoshka models must accept the default and their listed dimensions
    and reject others; non-Matryoshka models must accept only the default.
    Every accepted response is also compared against the HF reference model.
    """
    client = server.get_async_client()

    async def make_request_and_correctness_test(dimensions):
        # Send three copies of the prompt, sanity-check the response, then
        # compare the returned vectors with the HF model's output.
        prompts = input_texts * 3

        raw_response = await client.embeddings.create(
            model=model_info.name,
            input=prompts,
            dimensions=dimensions,
            encoding_format="float",
        )
        embeddings = EmbeddingResponse.model_validate(
            raw_response.model_dump(mode="json"))

        assert embeddings.id is not None
        assert len(embeddings.data) == 3
        assert len(embeddings.data[0].embedding) > 0
        assert embeddings.usage.completion_tokens == 0
        assert embeddings.usage.prompt_tokens > 0
        assert embeddings.usage.total_tokens > 0

        if dimensions is not None:
            assert len(embeddings.data[0].embedding) == dimensions

        vllm_outputs = [item.embedding for item in embeddings.data]
        correctness_test(hf_model, prompts, vllm_outputs, dimensions)

    async def expect_accepted(candidates):
        for dimensions in candidates:
            await make_request_and_correctness_test(dimensions)

    async def expect_rejected(candidates):
        for dimensions in candidates:
            with pytest.raises(openai.BadRequestError):
                await make_request_and_correctness_test(dimensions)

    if not model_info.is_matryoshka:
        # Only the native size is allowed; any explicit size is an error.
        await expect_accepted([None])
        await expect_rejected([-1, 16])
        return

    supported = model_info.matryoshka_dimensions

    valid_dimensions: list[Optional[int]] = [None]
    if supported is not None:
        valid_dimensions.extend(supported[:2])
    await expect_accepted(valid_dimensions)

    invalid_dimensions: list[Optional[int]] = [-1]
    if supported is not None:
        # 5 is used as a size the model does not list as supported.
        assert 5 not in supported
        invalid_dimensions.append(5)
    await expect_rejected(invalid_dimensions)
class MockLoRAResolver(LoRAResolver):
    """Resolver stub that knows exactly two adapter names."""

    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        """Return a fake request for a known adapter name, else ``None``."""
        # adapter name -> (integer id, fake local path)
        known_adapters = {
            "test-lora": (1, "/fake/path/test-lora"),
            "invalid-lora": (2, "/fake/path/invalid-lora"),
        }
        for known_name, (int_id, local_path) in known_adapters.items():
            if lora_name == known_name:
                return LoRARequest(lora_name=known_name,
                                   lora_int_id=int_id,
                                   lora_local_path=local_path)
        return None
return + elif lora_request.lora_name == "invalid-lora": + # Simulate failure during addition (e.g. invalid format) + raise ValueError(f"Simulated failure adding LoRA: " + f"{lora_request.lora_name}") + + mock_engine.add_lora.side_effect = mock_add_lora_side_effect + mock_engine.generate.reset_mock() + mock_engine.add_lora.reset_mock() + + mock_model_config = MockModelConfig() + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) + + serving_completion = OpenAIServingCompletion(mock_engine, + mock_model_config, + models, + request_logger=None) + + return mock_engine, serving_completion + + +@pytest.mark.asyncio +async def test_serving_completion_with_lora_resolver(mock_serving_setup, + monkeypatch): + monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") + + mock_engine, serving_completion = mock_serving_setup + + lora_model_name = "test-lora" + req_found = CompletionRequest( + model=lora_model_name, + prompt="Generate with LoRA", + ) + + # Suppress potential errors during the mocked generate call, + # as we are primarily checking for add_lora and generate calls + with suppress(Exception): + await serving_completion.create_completion(req_found) + + mock_engine.add_lora.assert_called_once() + called_lora_request = mock_engine.add_lora.call_args[0][0] + assert isinstance(called_lora_request, LoRARequest) + assert called_lora_request.lora_name == lora_model_name + + mock_engine.generate.assert_called_once() + called_lora_request = mock_engine.generate.call_args[1]['lora_request'] + assert isinstance(called_lora_request, LoRARequest) + assert called_lora_request.lora_name == lora_model_name + + +@pytest.mark.asyncio +async def test_serving_completion_resolver_not_found(mock_serving_setup, + monkeypatch): + monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") + + mock_engine, serving_completion = mock_serving_setup + + non_existent_model = "non-existent-lora-adapter" + req = 
CompletionRequest( + model=non_existent_model, + prompt="what is 1+1?", + ) + + response = await serving_completion.create_completion(req) + + mock_engine.add_lora.assert_not_called() + mock_engine.generate.assert_not_called() + + assert isinstance(response, ErrorResponse) + assert response.code == HTTPStatus.NOT_FOUND.value + assert non_existent_model in response.message + + +@pytest.mark.asyncio +async def test_serving_completion_resolver_add_lora_fails( + mock_serving_setup, monkeypatch): + monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") + + mock_engine, serving_completion = mock_serving_setup + + invalid_model = "invalid-lora" + req = CompletionRequest( + model=invalid_model, + prompt="what is 1+1?", + ) + + response = await serving_completion.create_completion(req) + + # Assert add_lora was called before the failure + mock_engine.add_lora.assert_called_once() + called_lora_request = mock_engine.add_lora.call_args[0][0] + assert isinstance(called_lora_request, LoRARequest) + assert called_lora_request.lora_name == invalid_model + + # Assert generate was *not* called due to the failure + mock_engine.generate.assert_not_called() + + # Assert the correct error response + assert isinstance(response, ErrorResponse) + assert response.code == HTTPStatus.BAD_REQUEST.value + assert invalid_model in response.message + + +@pytest.mark.asyncio +async def test_serving_completion_flag_not_set(mock_serving_setup): + mock_engine, serving_completion = mock_serving_setup + + lora_model_name = "test-lora" + req_found = CompletionRequest( + model=lora_model_name, + prompt="Generate with LoRA", + ) + + await serving_completion.create_completion(req_found) + + mock_engine.add_lora.assert_not_called() + mock_engine.generate.assert_not_called() diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..1ccb803a328d608bf1fbf0cd7731c4dd935ef4d7 --- 
/dev/null +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import schemathesis +from schemathesis import GenerationConfig + +from ...utils import RemoteOpenAIServer + +schemathesis.experimental.OPEN_API_3_1.enable() + +MODEL_NAME = "HuggingFaceTB/SmolVLM-256M-Instruct" +MAXIMUM_IMAGES = 2 + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", + "generate", + "--max-model-len", + "2048", + "--max-num-seqs", + "5", + "--enforce-eager", + "--trust-remote-code", + "--limit-mm-per-prompt", + f"image={MAXIMUM_IMAGES}", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def get_schema(server): + # avoid generating null (\x00) bytes in strings during test case generation + return schemathesis.openapi.from_uri( + f"{server.url_root}/openapi.json", + generation_config=GenerationConfig(allow_x00=False), + ) + + +schema = schemathesis.from_pytest_fixture("get_schema") + + +@schema.parametrize() +@schema.override(headers={"Content-Type": "application/json"}) +async def test_openapi_stateless(case): + #No need to verify SSL certificate for localhost + await case.call_and_validate(verify=False) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 29571bcd7649b15e447df8da23c1ca32ba74de92..5c48df3cebbc254dfb85b4de620a20d2b31d6176 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -192,3 +192,36 @@ async def test_stream_options(winning_call): else: continuous = continuous and hasattr(chunk, 'usage') assert final and continuous + + +@pytest.mark.asyncio +async def test_sampling_params(mary_had_lamb): + """ + Compare sampling with params and greedy sampling to assert results + are different when extreme sampling parameters values are picked. 
+ """ + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + temperature=0.8, + extra_body=dict(seed=42, + repetition_penalty=1.9, + top_k=12, + top_p=0.4, + min_p=0.5, + frequency_penalty=1.8, + presence_penalty=2.0)) + + greedy_transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + temperature=0.0, + extra_body=dict(seed=42)) + + assert greedy_transcription.text != transcription.text diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index 79801e0f4cebae60a3defaafdd2dec99e0fcd2e6..0d1a88d568f601aba0ed458bdffe716de84e8fc5 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import os +import json + import openai import pytest import pytest_asyncio @@ -39,7 +41,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"video={MAXIMUM_VIDEOS}", + json.dumps({"video": MAXIMUM_VIDEOS}), ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -114,6 +116,35 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) +async def test_error_on_invalid_video_url_type(client: openai.AsyncOpenAI, + model_name: str, + video_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "video_url", + "video_url": video_url + }, + { + "type": "text", + "text": "What's in this video?" 
+ }, + ], + }] + + # video_url should be a dict {"url": "some url"}, not directly a string + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create(model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0) + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 4cf232f8e80987721c88fa875b478fddb94a043f..713b0150ae5f7d14a07e387cc913eb2b756a001f 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import json + import openai import pytest import os @@ -44,7 +46,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"image={MAXIMUM_IMAGES}", + json.dumps({"image": MAXIMUM_IMAGES}), ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -144,6 +146,36 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI, + model_name: str, + image_url: str): + content_text = "What's in this image?" 
+ messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": image_url + }, + { + "type": "text", + "text": content_text + }, + ], + }] + + # image_url should be a dict {"url": "some url"}, not directly a string + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create(model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0) + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 53a4f884e2a23f65c4a495fa9312b42625b8939b..bc9cb9e614ecd122a614f05e395fef81dbefb526 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import os +import json + import pytest import requests from PIL import Image @@ -45,7 +47,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"image={MAXIMUM_IMAGES}", + json.dumps({"image": MAXIMUM_IMAGES}), "--chat-template", str(vlm2vec_jinja_path), ] diff --git a/tests/kernels/conftest.py b/tests/kernels/attention/conftest.py similarity index 100% rename from tests/kernels/conftest.py rename to tests/kernels/attention/conftest.py diff --git a/tests/kernels/test_attention.py b/tests/kernels/attention/test_attention.py similarity index 99% rename from tests/kernels/test_attention.py rename to tests/kernels/attention/test_attention.py index 763de25fbec3f5a26bee5733528340a67a84ba8f..0cea5abbeddbf623d2e416e7e52f348f4fdb9378 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -6,13 +6,12 @@ from typing import Optional import pytest import torch +from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck from vllm 
import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils import get_max_shared_memory_bytes -from .allclose_default import get_default_atol, get_default_rtol - if not current_platform.is_rocm(): from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/attention/test_blocksparse_attention.py similarity index 99% rename from tests/kernels/test_blocksparse_attention.py rename to tests/kernels/attention/test_blocksparse_attention.py index 2aa86a4aef1895f7a47f1f4af9e5a691580a28e0..376df68aea3b9f3631190b3f4dd86750b9ff0167 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/attention/test_blocksparse_attention.py @@ -6,14 +6,13 @@ from typing import Optional import pytest import torch +from tests.kernels.allclose_default import get_default_atol, get_default_rtol from vllm import _custom_ops as ops from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn) from vllm.platforms import current_platform from vllm.utils import get_max_shared_memory_bytes -from .allclose_default import get_default_atol, get_default_rtol - FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer diff --git a/tests/kernels/test_cache.py b/tests/kernels/attention/test_cache.py similarity index 93% rename from tests/kernels/test_cache.py rename to tests/kernels/attention/test_cache.py index 2a1fd243c03f2eec6c7f93a06d96e6eb10c1a457..1d2e5c33ea2a6f1ce5103bb5697f0a5e46371639 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -16,6 +16,7 @@ NUM_LAYERS = [1] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing HEAD_SIZES = [64, 80, 120, 256] BLOCK_SIZES = [8, 16, 32] +CACHE_LAYOUTS = ["NHD", "HND"] # Parameters for MLA tests. 
KV_LORA_RANKS = [512] @@ -221,6 +222,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS) @torch.inference_mode() def test_reshape_and_cache_flash( kv_cache_factory_flashinfer, @@ -233,17 +235,21 @@ def test_reshape_and_cache_flash( seed: int, device: str, kv_cache_dtype: str, + kv_cache_layout: str, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) + # fp8 conversion requires contiguous memory buffer. Reduce the number of + # blocks and tokens to consume less memory. + num_tokens = num_tokens // 2 + num_blocks = num_blocks // 2 # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping_lst = random.sample(range(num_slots), num_tokens) slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) - qkv = torch.randn(num_tokens, 3, num_heads, @@ -262,27 +268,35 @@ def test_reshape_and_cache_flash( kv_cache_dtype, dtype, device=device, + cache_layout=kv_cache_layout, ) - key_cache, value_cache = key_caches[0].contiguous( - ), value_caches[0].contiguous() + key_cache, value_cache = key_caches[0], value_caches[0] del key_caches del value_caches k_scale = (key.amax() / 64.0).to(torch.float32) v_scale = (value.amax() / 64.0).to(torch.float32) + def permute_and_compact(x): + y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3) + return y.contiguous() + + key_cache_compact = permute_and_compact(key_cache) + value_cache_compact = permute_and_compact(value_cache) + # Clone the KV caches. 
if kv_cache_dtype == "fp8": - cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(), - kv_cache_dtype) - cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(), + cloned_key_cache = torch.empty_like(key_cache_compact, + dtype=torch.float16) + ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype) + cloned_value_cache = torch.empty_like(value_cache_compact, + dtype=torch.float16) + ops.convert_fp8(cloned_value_cache, value_cache_compact, + v_scale.item(), kv_cache_dtype) else: - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() - + cloned_key_cache = key_cache_compact.clone() + cloned_value_cache = value_cache_compact.clone() # Call the reshape_and_cache kernel. opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, @@ -290,16 +304,20 @@ def test_reshape_and_cache_flash( cond=(head_size == HEAD_SIZES[0])) ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale, v_scale) + key_cache_compact = permute_and_compact(key_cache) + value_cache_compact = permute_and_compact(value_cache) if kv_cache_dtype == "fp8": - result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + result_key_cache = torch.empty_like(key_cache_compact, + dtype=torch.float16) ops.convert_fp8(result_key_cache, - key_cache, + key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype) - result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + result_value_cache = torch.empty_like(value_cache_compact, + dtype=torch.float16) ops.convert_fp8(result_value_cache, - value_cache, + value_cache_compact, v_scale.item(), kv_dtype=kv_cache_dtype) @@ -311,8 +329,12 @@ def test_reshape_and_cache_flash( for i in range(num_tokens): block_idx = 
block_indicies_lst[i] block_offset = block_offsets_lst[i] - cloned_key_cache[block_idx, block_offset, :, :] = key[i] - cloned_value_cache[block_idx, block_offset, :, :] = value[i] + if kv_cache_layout == "NHD": + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + else: + cloned_key_cache[block_idx, :, block_offset, :] = key[i] + cloned_value_cache[block_idx, :, block_offset, :] = value[i] if kv_cache_dtype == "fp8": torch.testing.assert_close(result_key_cache, @@ -324,9 +346,9 @@ def test_reshape_and_cache_flash( atol=0.001, rtol=0.1) else: - torch.testing.assert_close(key_cache, cloned_key_cache) - torch.testing.assert_close(value_cache, cloned_value_cache) - + torch.testing.assert_close(key_cache_compact, cloned_key_cache) + torch.testing.assert_close(value_cache_compact, cloned_value_cache) + @pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) diff --git a/tests/kernels/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py similarity index 100% rename from tests/kernels/test_cascade_flash_attn.py rename to tests/kernels/attention/test_cascade_flash_attn.py diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/attention/test_encoder_decoder_attn.py similarity index 100% rename from tests/kernels/test_encoder_decoder_attn.py rename to tests/kernels/attention/test_encoder_decoder_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py similarity index 99% rename from tests/kernels/test_flash_attn.py rename to tests/kernels/attention/test_flash_attn.py index bc72ef32270a3160de735fefc813073bf2402d17..68eb4d4be6cd7e82b8ae4ee8244525c1f074dd9a 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -151,7 +151,7 @@ def test_flash_attn_with_paged_kv( v_descale = None if q_dtype is not None: # QKV are drawn from N(0, 
1): no need for a fp8 scaling factor - maybe_quantized_query = query.to(q_dtype) + maybe_quantized_query = q.to(q_dtype) maybe_quantized_key_cache = key_cache.to(q_dtype) maybe_quantized_value_cache = value_cache.to(q_dtype) diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/attention/test_flashmla.py similarity index 100% rename from tests/kernels/test_flashmla.py rename to tests/kernels/attention/test_flashmla.py diff --git a/tests/kernels/test_lightning_attn.py b/tests/kernels/attention/test_lightning_attn.py similarity index 100% rename from tests/kernels/test_lightning_attn.py rename to tests/kernels/attention/test_lightning_attn.py diff --git a/tests/kernels/test_merge_attn_states.py b/tests/kernels/attention/test_merge_attn_states.py similarity index 100% rename from tests/kernels/test_merge_attn_states.py rename to tests/kernels/attention/test_merge_attn_states.py diff --git a/tests/kernels/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py similarity index 100% rename from tests/kernels/test_mha_attn.py rename to tests/kernels/attention/test_mha_attn.py diff --git a/tests/kernels/test_mla_decode_cpu.py b/tests/kernels/attention/test_mla_decode_cpu.py similarity index 100% rename from tests/kernels/test_mla_decode_cpu.py rename to tests/kernels/attention/test_mla_decode_cpu.py diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py similarity index 100% rename from tests/kernels/test_prefix_prefill.py rename to tests/kernels/attention/test_prefix_prefill.py diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf7bcb01d4d7641d9fe12a3ad53c76a1674689a --- /dev/null +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from vllm.attention.selector import 
_cached_get_attn_backend, get_attn_backend +from vllm.platforms.rocm import RocmPlatform +from vllm.utils import STR_BACKEND_ENV_VAR + + +@pytest.fixture(autouse=True) +def clear_cache(): + """Clear lru cache to ensure each test case runs without caching. + """ + _cached_get_attn_backend.cache_clear() + + +def test_selector(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH") + + # Set the current platform to ROCm using monkeypatch + monkeypatch.setattr("vllm.attention.selector.current_platform", + RocmPlatform()) + + # Test standard ROCm attention + backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) + assert (backend.get_name() == "ROCM_FLASH" + or backend.get_name() == "TRITON_ATTN_VLLM_V1") + + # MLA test for deepseek related + + # change the attention backend to triton MLA + m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") + backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, + False, True) + assert backend.get_name() == "TRITON_MLA" + + # If attention backend is None + # If use_mla is true + # The selected backend is triton MLA + m.setenv(STR_BACKEND_ENV_VAR, None) + backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, + False, True) + assert backend.get_name() == "TRITON_MLA" + + # change the attention backend to AITER MLA + m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") + backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, + False, True) + assert backend.get_name() == "ROCM_AITER_MLA" + + # If attention backend is None + # If use_mla is true + # If VLLM_ROCM_USE_AITER is enabled + # The selected backend is ROCM_AITER_MLA + m.setenv(STR_BACKEND_ENV_VAR, None) + m.setenv("VLLM_ROCM_USE_AITER", "1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, + False, True) + assert backend.get_name() == "ROCM_AITER_MLA" diff --git a/tests/kernels/test_triton_decode_attention.py 
b/tests/kernels/attention/test_triton_decode_attention.py similarity index 100% rename from tests/kernels/test_triton_decode_attention.py rename to tests/kernels/attention/test_triton_decode_attention.py diff --git a/tests/kernels/attention/untest_attention_selector.py b/tests/kernels/attention/untest_attention_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..b0414244c2151c8139b8d6faf0416b1507997327 --- /dev/null +++ b/tests/kernels/attention/untest_attention_selector.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import patch + +import pytest +import torch + +from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend +from vllm.platforms.cpu import CpuPlatform +from vllm.platforms.cuda import CudaPlatform +from vllm.platforms.rocm import RocmPlatform +from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL + + +@pytest.fixture(autouse=True) +def clear_cache(): + """Clear lru cache to ensure each test case runs without caching. 
+ """ + _cached_get_attn_backend.cache_clear() + + +# Define MLA and non-MLA backends separately +DEVICE_MLA_BACKENDS = { + "cuda": ["TRITON_MLA", "FLASHMLA"], + "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], + "cpu": [], +} + +DEVICE_REGULAR_ATTN_BACKENDS = { + "cuda": ["XFORMERS", "FLASHINFER"], + "hip": ["ROCM_FLASH"], + "cpu": ["TORCH_SDPA"], +} + +DEVICE_MLA_BLOCK_SIZES = { + "cuda": [16, 64], # CUDA supports both standard and extended block sizes + "hip": [16, 1], # HIP requires special handling for block_size=1 + "cpu": [16] # CPU uses fixed block size from test cases +} + + +def generate_params(): + params = [] + for use_mla in [True, False]: + for device in ["cuda", "hip", "cpu"]: + backends = DEVICE_MLA_BACKENDS[ + device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device] + for name in backends: + block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [ + 16 + ] + for block_size in block_sizes: + params.append( + pytest.param( + device, + name, + use_mla, + block_size, + id= + f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}" + )) + return params + + +@pytest.mark.parametrize("device, name, use_mla, block_size", + generate_params()) +@pytest.mark.parametrize("use_v1", [True, False]) +def test_env( + device: str, + name: str, + use_mla: bool, + block_size: int, + use_v1: bool, + monkeypatch: pytest.MonkeyPatch, +): + """Test attention backend selection with valid device-backend pairs.""" + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + m.setenv(STR_BACKEND_ENV_VAR, name) + m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") + + if device == "cpu": + with patch("vllm.attention.selector.current_platform", + CpuPlatform()): + backend = get_attn_backend(16, torch.float16, torch.float16, + block_size, False) + assert backend.get_name() == "TORCH_SDPA" + + elif device == "hip": + with patch("vllm.attention.selector.current_platform", + RocmPlatform()): + if use_mla: + # Validate HIP MLA backend-block_size 
combinations + valid_combination = ( + (name == "TRITON_MLA" and block_size != 1) + or (name == "ROCM_AITER_MLA" and block_size == 1)) + + if valid_combination: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert backend.get_name() == name + else: + with pytest.raises(ValueError) as exc_info: + get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert f"The selected backend, {name}" in str( + exc_info.value) + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" + assert backend.get_name() == expected + + elif device == "cuda": + with patch("vllm.attention.selector.current_platform", + CudaPlatform()): + if use_mla: + if name == "FLASHMLA" and block_size == 64: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + + # only on cuda platforms with specific capability. + is_supported, _ = is_flashmla_supported() + + if not is_supported: + # if platform is not supported then skip this case. 
+ pytest.skip() + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = ("TRITON_MLA_VLLM_V1" + if use_v1 else "TRITON_MLA") + assert backend.get_name() == expected + elif name == "FLASHINFER": + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "FLASHINFER_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + + +def test_flash_attn(monkeypatch: pytest.MonkeyPatch): + """Test FlashAttn validation.""" + # TODO: When testing for v1, pipe in `use_v1` as an argument to + # get_attn_backend + + with monkeypatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) + + # Unsupported CUDA arch + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda: + (7, 5)) + backend = get_attn_backend(16, torch.float16, None, 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Reset the monkeypatch for subsequent tests + monkeypatch.undo() + + # Unsupported data type + backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Unsupported kv cache data type + backend = get_attn_backend(16, torch.float16, "fp8", 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Unsupported block size + backend = get_attn_backend(16, torch.float16, None, 8, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # flash-attn is not installed + import sys + original_module = sys.modules.get('vllm_flash_attn') + 
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None) + backend = get_attn_backend(16, torch.float16, None, 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Restore the original module if it existed + if original_module is not None: + monkeypatch.setitem(sys.modules, 'vllm_flash_attn', + original_module) + else: + monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False) + + # Unsupported head size + backend = get_attn_backend(17, torch.float16, None, 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Attention-free models should bypass env and use PlaceholderAttention + backend = get_attn_backend(16, torch.float16, torch.float16, 16, True) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + +@pytest.mark.parametrize("use_v1", [True, False]) +def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch): + + with monkeypatch.context() as m, patch( + "vllm.attention.selector.current_platform", CudaPlatform()): + m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) + + # Test with head size 32 + backend = get_attn_backend(32, torch.float16, None, 16, False) + EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN" + assert backend.get_name() == EXPECTED + + # when block size == 16, backend will fall back to XFORMERS + # this behavior is not yet supported on V1. + if use_v1: + # TODO: support fallback on V1! 
+ # https://github.com/vllm-project/vllm/issues/14524 + pass + else: + backend = get_attn_backend(16, torch.float16, None, 16, False) + assert backend.get_name() == "XFORMERS" diff --git a/tests/kernels/untest_flashinfer.py b/tests/kernels/attention/untest_flashinfer.py similarity index 100% rename from tests/kernels/untest_flashinfer.py rename to tests/kernels/attention/untest_flashinfer.py diff --git a/tests/kernels/test_activation.py b/tests/kernels/core/test_activation.py similarity index 97% rename from tests/kernels/test_activation.py rename to tests/kernels/core/test_activation.py index cf0f21ce06514fbdb3286ab51e222571fe5dbf09..79f838a954e70f8fe26488883358808ac400ad0d 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -5,6 +5,7 @@ import random import pytest import torch +from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, GeluAndMul, MulAndSilu, @@ -12,8 +13,6 @@ from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, SiluAndMul) from vllm.platforms import current_platform -from .allclose_default import get_default_atol, get_default_rtol - DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 13824] # Arbitrary values for testing diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py similarity index 100% rename from tests/kernels/test_fused_quant_layernorm.py rename to tests/kernels/core/test_fused_quant_layernorm.py diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/core/test_layernorm.py similarity index 100% rename from tests/kernels/test_layernorm.py rename to tests/kernels/core/test_layernorm.py diff --git a/tests/kernels/core/test_opcheck.py b/tests/kernels/core/test_opcheck.py new file mode 100644 index 
0000000000000000000000000000000000000000..c9a9679c5d80f4dfa67a6589d481063706711bfb --- /dev/null +++ b/tests/kernels/core/test_opcheck.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for miscellaneous utilities +""" + +import torch + +from tests.kernels.utils import opcheck + + +def test_convert_fp8_opcheck(): + data = torch.randn((256, 256), dtype=torch.float32, device="cuda") + result = torch.empty_like(data, dtype=torch.float8_e4m3fn) + opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8")) + + +# TODO: Add this back, currently fails with +# csrc/cuda_utils_kernels.cu:15 'invalid argument' +# @pytest.mark.skipif(not current_platform.is_cuda(), +# reason="Only supported for CUDA") +# def test_cuda_utils_opcheck(): +# opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0)) +# opcheck( +# torch.ops._C_cuda_utils. +# get_max_shared_memory_per_block_device_attribute, (0, )) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py similarity index 99% rename from tests/kernels/test_pos_encoding.py rename to tests/kernels/core/test_pos_encoding.py index 44fad27e4005bf4def079fd67c9e3c95d5d448ac..648263c6b210ab308a7f1f1a5869cf51df49e1ec 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -6,11 +6,10 @@ from typing import Callable, Optional import pytest import torch +from tests.kernels.allclose_default import get_default_atol, get_default_rtol from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform -from .allclose_default import get_default_atol, get_default_rtol - IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] HEAD_SIZES = [64, 80, 112, 120, 256] diff --git a/tests/kernels/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py similarity index 100% rename from tests/kernels/test_rotary_embedding.py rename to 
tests/kernels/core/test_rotary_embedding.py diff --git a/tests/kernels/test_uva.py b/tests/kernels/core/test_uva.py similarity index 100% rename from tests/kernels/test_uva.py rename to tests/kernels/core/test_uva.py diff --git a/tests/kernels/untest_permute_cols.py b/tests/kernels/core/untest_permute_cols.py similarity index 100% rename from tests/kernels/untest_permute_cols.py rename to tests/kernels/core/untest_permute_cols.py diff --git a/tests/kernels/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py similarity index 100% rename from tests/kernels/test_mamba_mixer2.py rename to tests/kernels/mamba/test_mamba_mixer2.py diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py similarity index 95% rename from tests/kernels/test_mamba_ssm_ssd.py rename to tests/kernels/mamba/test_mamba_ssm_ssd.py index 8f23a9b216e98a2eed32876d6f8339e3e14c8061..ee908105f557fa961a18864e3e7b1e6989758c0d 100644 --- a/tests/kernels/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -5,6 +5,8 @@ import torch import torch.nn.functional as F from einops import rearrange, repeat +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + _seq_idx_to_chunk_indices_offsets) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) from vllm.platforms import current_platform @@ -160,14 +162,14 @@ def generate_continous_batched_examples(example_lens_by_batch, # get the metadata cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) - sed_idx = torch.zeros(cu_seqlens[-1], + seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device) for i, (srt, end) in enumerate(zip( cu_seqlens, cu_seqlens[1:], )): - sed_idx[srt:end] = i + seq_idx[srt:end] = i # for cont batch if IND_E is None: @@ -177,7 +179,7 @@ def generate_continous_batched_examples(example_lens_by_batch, IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] yield ([Y_min[s, 
IND_S[s]:IND_E[s]] for s in range(num_examples)], - cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) @pytest.mark.parametrize("itype", @@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, exhausted: dict = {} # map: eg -> boolean indicating example is exhausted states = None - for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, + for Y_min, cu_seqlens, seq_idx, (A, dt, X, B, C) in generate_continous_batched_examples( cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype): + chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( + seq_idx, chunk_size) + Y, new_states = mamba_chunk_scan_combined( X, dt, @@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, chunk_size, D=None, cu_seqlens=cu_seqlens, - seq_idx=sed_idx, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, return_varlen_states=True, initial_states=states, ) diff --git a/tests/kernels/untest_causal_conv1d.py b/tests/kernels/mamba/untest_causal_conv1d.py similarity index 100% rename from tests/kernels/untest_causal_conv1d.py rename to tests/kernels/mamba/untest_causal_conv1d.py diff --git a/tests/kernels/untest_mamba_ssm.py b/tests/kernels/mamba/untest_mamba_ssm.py similarity index 100% rename from tests/kernels/untest_mamba_ssm.py rename to tests/kernels/mamba/untest_mamba_ssm.py diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..975cd418a171f28d01fc36f7fa2bba16d3707e58 --- /dev/null +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -0,0 +1,364 @@ +# SPDX-License-Identifier: Apache-2.0 +import dataclasses +from typing import Optional + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from 
vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk) +from vllm.platforms import current_platform + +NUM_EXPERTS = [40, 64] +TOP_KS = [6, 8] + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 3072, 1536), + (224, 1024, 1024), + (224, 1024, 1536), + (224, 3072, 1024), + (224, 3072, 1536), +] + + +@dataclasses.dataclass +class MOETensors: + a: torch.Tensor + w1: torch.Tensor + w2: torch.Tensor + ab_strides1: torch.Tensor + c_strides1: torch.Tensor + ab_strides2: torch.Tensor + c_strides2: torch.Tensor + + @staticmethod + def make_moe_tensors(m: int, k: int, n: int, e: int, + dtype: torch.dtype) -> "MOETensors": + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + return MOETensors(a=a, + w1=w1, + w2=w2, + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2) + + +@dataclasses.dataclass +class MOETensors8Bit(MOETensors): + # quantized + a_q: Optional[torch.Tensor] = None # a -> a_q + w1_q: Optional[torch.Tensor] = None # w1 -> w1_q + w2_q: Optional[torch.Tensor] = None # w2 -> w2_q + a_scale: Optional[torch.Tensor] = None + w1_scale: Optional[torch.Tensor] = None + w2_scale: Optional[torch.Tensor] = None + # dequantized + a_d: Optional[torch.Tensor] = None # a -> a_q -> a_d + w1_d: Optional[torch.Tensor] = None # w1 -> w1_q -> w1_d + w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d 
+ + @staticmethod + def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, + per_act_token: bool, + per_out_channel: bool) -> "MOETensors8Bit": + dtype = torch.half + q_dtype = torch.float8_e4m3fn + + moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype) + + # a -> a_q, w1 -> w1_q, w2 -> w2_q + n_b_scales = 2 * n if per_out_channel else 1 + k_b_scales = k if per_out_channel else 1 + # Get the right scale for tests. + _, a_scale = ops.scaled_fp8_quant( + moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a, + a_scale, + use_per_token_if_dynamic=per_act_token) + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) + w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) + + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + moe_tensors_fp16.w1[expert], + use_per_token_if_dynamic=per_out_channel) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + moe_tensors_fp16.w2[expert], + use_per_token_if_dynamic=per_out_channel) + + # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d + a_d = a_q.float().mul(a_scale).to(dtype) + w1_d = torch.empty_like(moe_tensors_fp16.w1) + w2_d = torch.empty_like(moe_tensors_fp16.w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half() + + return MOETensors8Bit(a=moe_tensors_fp16.a, + w1=moe_tensors_fp16.w1, + w2=moe_tensors_fp16.w2, + ab_strides1=moe_tensors_fp16.ab_strides1, + c_strides1=moe_tensors_fp16.c_strides1, + ab_strides2=moe_tensors_fp16.ab_strides2, + c_strides2=moe_tensors_fp16.c_strides2, + a_q=a_q, + w1_q=w1_q, + w2_q=w2_q, + a_scale=a_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + a_d=a_d, + w1_d=w1_d, + w2_d=w2_d) + + +def 
run_with_expert_maps(num_experts: int, num_local_experts: int, + **cutlass_moe_kwargs): + + def slice_experts(): + slice_params = [ + "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1", + "c_strides2", "w1_scale", "w2_scale" + ] + full_tensors = { + k: v + for k, v in cutlass_moe_kwargs.items() + if k in slice_params and k in cutlass_moe_kwargs + } + + for i in range(0, num_experts, num_local_experts): + s, e = i, i + num_local_experts + + # make expert map + expert_map = [-1] * num_experts + expert_map[s:e] = list(range(num_local_experts)) + expert_map = torch.tensor(expert_map, + dtype=torch.int32, + device="cuda") + + # update cutlass moe arg with expert_map + cutlass_moe_kwargs["expert_map"] = expert_map + # update cutlass moe arg tensors + for k, t in full_tensors.items(): + cutlass_moe_kwargs[k] = t[s:e] + + yield cutlass_moe_kwargs + + out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"]) + for kwargs in slice_experts(): + out_tensor = out_tensor + cutlass_moe_fp8(**kwargs) + + return out_tensor + + +def run_8_bit(moe_tensors: MOETensors8Bit, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_local_experts: Optional[int] = None) -> torch.Tensor: + assert not any([ + t is None for t in [ + moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale, + moe_tensors.w2_scale, moe_tensors.a_scale + ] + ]) + + kwargs = { + 'a': moe_tensors.a, + 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] + 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] + 'topk_weights': topk_weights, + 'topk_ids_': topk_ids, + 'ab_strides1': moe_tensors.ab_strides1, + 'c_strides1': moe_tensors.c_strides1, + 'ab_strides2': moe_tensors.ab_strides2, + 'c_strides2': moe_tensors.c_strides2, + 'w1_scale': moe_tensors.w1_scale, + 'w2_scale': moe_tensors.w2_scale, + 'a1_scale': moe_tensors.a_scale + } + + num_experts = moe_tensors.w1.size(0) + with_ep = num_local_experts is not None or num_local_experts == num_experts + if not with_ep: + 
return cutlass_moe_fp8(**kwargs) + + assert num_local_experts is not None + return run_with_expert_maps( + num_experts, + num_local_experts, # type: ignore[arg-type] + **kwargs) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_8_bit_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, + per_out_ch) + + score = torch.randn((m, e), device="cuda", dtype=torch.half) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. 
+ triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, + topk_ids) + + cutlass_output = run_8_bit(mt, topk_weights, topk_ids) + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_8_bit_cuda_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, + per_out_ch) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. 
+ triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, + topk_ids) + + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + cutlass_output = run_8_bit(mt, topk_weights, topk_ids) + + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=9e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m", [64]) +@pytest.mark.parametrize("n", [1024]) +@pytest.mark.parametrize("k", [4096]) +@pytest.mark.parametrize("e", [16]) +@pytest.mark.parametrize("topk", [1, 8]) +@pytest.mark.parametrize("per_act_token", [True]) +@pytest.mark.parametrize("per_out_channel", [True]) +@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_8_bit_EP( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_channel: bool, + ep_size: int, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, + per_out_channel) + + score = torch.randn((m, e), device="cuda", dtype=torch.half) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. 
+ triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, + topk_ids) + + assert e % ep_size == 0, "Cannot distribute experts evenly" + cutlass_output = run_8_bit(mt, + topk_weights, + topk_ids, + num_local_experts=e // ep_size) + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) diff --git a/tests/kernels/test_moe.py b/tests/kernels/moe/test_moe.py similarity index 72% rename from tests/kernels/test_moe.py rename to tests/kernels/moe/test_moe.py index 6778702e23810ed212c56422c3527a4a2e23a67e..30bcdf1b10c8949adc939f0035d0b3db19ab02f2 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -11,16 +11,14 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, - torch_moe, torch_moe_single) -from vllm import _custom_ops as ops +from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, + torch_moe_single) from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize) + awq_marlin_quantize, marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( quantize_weights) from vllm.model_executor.models.mixtral import MixtralMoE @@ -289,14 +287,17 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, @pytest.mark.skipif(current_platform.is_rocm(), reason="Currently, there is not supported on ROCm.") -@pytest.mark.parametrize("m", [1, 33, 64, 222]) -@pytest.mark.parametrize("n", [128, 2048]) 
-@pytest.mark.parametrize("k", [128, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("m", [1, 33, 123]) +@pytest.mark.parametrize("n", [128, 1024]) +@pytest.mark.parametrize("k", [256, 2048]) +@pytest.mark.parametrize("e", [4, 12]) +@pytest.mark.parametrize("topk", [2, 3]) +@pytest.mark.parametrize("ep_size", [1, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("group_size", [-1, 32, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( @@ -305,9 +306,12 @@ def test_fused_marlin_moe( k: int, e: int, topk: int, + ep_size: int, + dtype: torch.dtype, group_size: int, act_order: bool, num_bits: int, + has_zp: bool, is_k_full: bool, ): current_platform.seed_everything(7) @@ -318,75 +322,110 @@ def test_fused_marlin_moe( return if group_size in (k, n): return + if has_zp: + return else: if not is_k_full: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - dtype = torch.float16 + if has_zp: + # we don't build kernel for int8 with zero + if num_bits == 8: + return + quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + else: + quant_type = scalar_types.uint4b8 \ + if num_bits == 4 else scalar_types.uint8b128 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + if ep_size > 1: + local_e = e // ep_size + e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e] + e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + w1 = 
w1[e_ids] + w2 = w2[e_ids] + else: + e_map = None + w_ref1_l = [] qweight1_l = [] scales1_l = [] + zeros1_l = [] g_idx1_l = [] sort_indices1_l = [] for i in range(w1.shape[0]): - test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size, act_order, - test_perm) - w_ref1_l.append(w_ref1) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - g_idx1_l.append(g_idx1) - sort_indices1_l.append(sort_indices1) + if has_zp: + w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + zeros1_l.append(zeros1) + else: + test_perm = torch.randperm(k) + quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) w_ref1 = stack_and_dev(w_ref1_l) qweight1 = stack_and_dev(qweight1_l).contiguous() scales1 = stack_and_dev(scales1_l) - g_idx1 = stack_and_dev(g_idx1_l) - sort_indices1 = stack_and_dev(sort_indices1_l) + g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None + zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None + sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None w_ref2_l = [] qweight2_l = [] scales2_l = [] + zeros2_l = [] g_idx2_l = [] sort_indices2_l = [] for i in range(w2.shape[0]): - test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size, act_order, - test_perm) - w_ref2_l.append(w_ref2) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - g_idx2_l.append(g_idx2) - sort_indices2_l.append(sort_indices2) + if has_zp: + w_ref2, qweight2, scales2, zeros2 = 
awq_marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + zeros2_l.append(zeros2) + else: + test_perm = torch.randperm(n) + quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) w_ref2 = stack_and_dev(w_ref2_l) qweight2 = stack_and_dev(qweight2_l).contiguous() scales2 = stack_and_dev(scales2_l) - g_idx2 = stack_and_dev(g_idx2_l) - sort_indices2 = stack_and_dev(sort_indices2_l) + g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None + zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None + sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, False) - triton_output = fused_moe( - a, - w_ref1.transpose(1, 2).contiguous(), - w_ref2.transpose(1, 2).contiguous(), - score, - topk, - renormalize=False, - ) + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + marlin_output = torch.ops.vllm.fused_marlin_moe( a, qweight1, @@ -396,113 +435,92 @@ def test_fused_marlin_moe( score, topk_weights, topk_ids, + global_num_experts=e, + expert_map=e_map, g_idx1=g_idx1, g_idx2=g_idx2, sort_indices1=sort_indices1, sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, num_bits=num_bits, - is_k_full=is_k_full, - ) + is_k_full=is_k_full) - assert compute_max_diff(marlin_output, triton_output) < 4e-2 - - if ops.supports_moe_ops: - token_expert_indicies = torch.empty(m, - topk, - dtype=torch.int32, - device=a.device) - - opcheck(torch.ops._moe_C.topk_softmax, ( - topk_weights, - topk_ids, - token_expert_indicies, - score.float(), - )) - - block_size_m = 4 - - 
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, - e) - - max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - zp = torch.empty((0, 0), - dtype=dtype, - device="cuda", - requires_grad=False) - opcheck(torch.ops._moe_C.marlin_gemm_moe, - (a, qweight1, sorted_token_ids, topk_weights, topk_ids, - scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id, - m, 2 * n, k, True, e, topk, block_size_m, True, False)) - + torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) @pytest.mark.skipif(current_platform.is_rocm(), reason="Currently, there is not supported on ROCm.") @pytest.mark.skip("This test is here for the sake of debugging, " "don't run it in automated tests.") -@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) -@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [8, 64]) -@pytest.mark.parametrize("topk", [2, 6]) -@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("m", [1, 33, 123]) +@pytest.mark.parametrize("n", [128, 1024]) +@pytest.mark.parametrize("k", [256, 2048]) +@pytest.mark.parametrize("e", [4, 12]) +@pytest.mark.parametrize("topk", [2, 3]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("group_size", [-1, 32, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False]) -@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") -def test_single_marlin_moe_multiply( - m: int, - n: int, - k: int, - e: int, - topk: int, - group_size: int, - act_order: bool, - num_bits: int, - is_k_full: bool, -): - +def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, 
topk: int, + dtype: torch.dtype, group_size: int, + act_order: bool, num_bits: int, + has_zp: bool, is_k_full: bool): # Filter act_order if act_order: if group_size == -1: return - if group_size == k: + if group_size in (k, n): + return + if has_zp: return else: if not is_k_full: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - dtype = torch.float16 + if has_zp: + quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + else: + quant_type = scalar_types.uint4b8 \ + if num_bits == 4 else scalar_types.uint8b128 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 w_ref_l = [] - qweights_l = [] + qweight_l = [] scales_l = [] + zeros_l = [] g_idx_l = [] sort_indices_l = [] for i in range(w.shape[0]): - test_perm = torch.randperm(k) - w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm) - w_ref_l.append(w_ref) - qweights_l.append(qweight) - scales_l.append(scales) - g_idx_l.append(g_idx) - sort_indices_l.append(sort_indices) + if has_zp: + w_ref, qweight, scales, zeros = awq_marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + zeros_l.append(zeros) + else: + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) w_ref = stack_and_dev(w_ref_l) - qweight = stack_and_dev(qweights_l).contiguous() + qweight = stack_and_dev(qweight_l).contiguous() scales = stack_and_dev(scales_l) - g_idx = stack_and_dev(g_idx_l) - sort_indices = stack_and_dev(sort_indices_l) + g_idx = stack_and_dev(g_idx_l) if g_idx_l else None + zeros = 
stack_and_dev(zeros_l) if zeros_l else None + sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) marlin_output = torch.ops.vllm.single_marlin_moe( @@ -514,13 +532,14 @@ def test_single_marlin_moe_multiply( renormalize=False, g_idx=g_idx, sort_indices=sort_indices, + w_zeros=zeros, num_bits=num_bits, is_k_full=is_k_full, ) - torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + torch_output = torch_moe_single(a, w_ref, score, topk) - assert compute_max_diff(marlin_output, torch_output) < 1e-2 + torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) def test_moe_align_block_size_opcheck(): diff --git a/tests/kernels/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py similarity index 100% rename from tests/kernels/test_triton_moe_ptpc_fp8.py rename to tests/kernels/moe/test_triton_moe_ptpc_fp8.py diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 498da6001ae93e834ffca9f25e3f9a9a2bed7acc..764924f26783db5115b0624aae5a66ba210298a9 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -87,3 +87,63 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ ref_out = (as_float32_tensor(x) * ref_iscale).clamp( fp8_traits_min, fp8_traits_max).to(FP8_DTYPE) return ref_out, ref_scale.view((1, )) + + +def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, + As: torch.Tensor, Bs: torch.Tensor, block_size, + output_dtype): + """This function performs matrix multiplication with block-wise + quantization using native torch. + It is agnostic to the input data type and can be used for both int8 and + fp8 data types. + + It takes two input tensors `A` and `B` (int8) with scales `As` and + `Bs` (float32). + The output is returned in the specified `output_dtype`. 
+ """ + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C diff --git a/tests/kernels/test_allspark_gemm.py b/tests/kernels/quantization/test_allspark_gemm.py similarity index 100% rename from tests/kernels/test_allspark_gemm.py rename to tests/kernels/quantization/test_allspark_gemm.py diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/quantization/test_awq_marlin.py similarity index 100% rename from tests/kernels/test_awq_marlin.py rename to tests/kernels/quantization/test_awq_marlin.py diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py similarity index 100% rename from tests/kernels/test_awq_triton.py rename to 
tests/kernels/quantization/test_awq_triton.py diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py similarity index 99% rename from tests/kernels/test_block_fp8.py rename to tests/kernels/quantization/test_block_fp8.py index c450048bf6651b16d44a4f1a5af0394f76e75042..c57e39f4250646db4247ed0368ad3727dd40e0e5 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -6,6 +6,7 @@ import itertools import pytest import torch +from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe @@ -18,8 +19,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform -from .utils_block import native_w8a8_block_matmul - dg_available = False try: import deep_gemm diff --git a/tests/kernels/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py similarity index 99% rename from tests/kernels/test_block_int8.py rename to tests/kernels/quantization/test_block_int8.py index 9447f9d691650eff3b98974033250e678a889854..104f23fd7cd2f63fd101facb494b9587f2846510 100644 --- a/tests/kernels/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -6,6 +6,7 @@ import itertools import pytest import torch +from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe @@ -13,8 +14,6 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( w8a8_block_int8_matmul) from vllm.platforms import current_platform -from .utils_block import native_w8a8_block_matmul - if current_platform.get_device_capability() < (7, 0): 
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) diff --git a/tests/kernels/test_cutlass_2of4_sparse.py b/tests/kernels/quantization/test_cutlass_2of4_sparse.py similarity index 99% rename from tests/kernels/test_cutlass_2of4_sparse.py rename to tests/kernels/quantization/test_cutlass_2of4_sparse.py index 2890e15d6cbaf65ebc56ff96f62fc07af9725114..d67d2dbb8998101e063df58f576ca5c49c010cf4 100644 --- a/tests/kernels/test_cutlass_2of4_sparse.py +++ b/tests/kernels/quantization/test_cutlass_2of4_sparse.py @@ -7,13 +7,12 @@ Run `pytest tests/kernels/test_semi_structured.py`. import pytest import torch +from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8 from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( sparse_cutlass_supported) from vllm.platforms import current_platform -from .utils import baseline_scaled_mm, to_fp8, to_int8 - CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py similarity index 99% rename from tests/kernels/test_cutlass.py rename to tests/kernels/quantization/test_cutlass_scaled_mm.py index 07eaeee32f259c0b5cc6887f891ee9d9ac09e5d5..c63965b7a4156894dfa1fd7e4d9cb0ca57d33234 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -8,13 +8,11 @@ import random import pytest import torch -from tests.kernels.utils import opcheck +from tests.kernels.utils import baseline_scaled_mm, opcheck, to_fp8, to_int8 from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils import cdiv -from .utils import baseline_scaled_mm, to_fp8, to_int8 - MNK_FACTORS = [ (1, 256, 128), (1, 16384, 1024), diff --git a/tests/kernels/test_gguf.py b/tests/kernels/quantization/test_gguf.py similarity index 100% rename from tests/kernels/test_gguf.py rename to 
tests/kernels/quantization/test_gguf.py diff --git a/tests/kernels/test_int8_kernel.py b/tests/kernels/quantization/test_int8_kernel.py similarity index 100% rename from tests/kernels/test_int8_kernel.py rename to tests/kernels/quantization/test_int8_kernel.py diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/quantization/test_int8_quant.py similarity index 100% rename from tests/kernels/test_int8_quant.py rename to tests/kernels/quantization/test_int8_quant.py diff --git a/tests/kernels/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py similarity index 100% rename from tests/kernels/test_nvfp4_quant.py rename to tests/kernels/quantization/test_nvfp4_quant.py diff --git a/tests/kernels/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py similarity index 100% rename from tests/kernels/test_nvfp4_scaled_mm.py rename to tests/kernels/quantization/test_nvfp4_scaled_mm.py diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py new file mode 100644 index 0000000000000000000000000000000000000000..622079c394457c5ef6eb6f0199ffd57ff0c81684 --- /dev/null +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +import vllm._custom_ops as ops +from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant +from vllm.platforms import current_platform + +DTYPES = [torch.bfloat16, torch.float16] +M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192] +K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0 +N = [1, 2, 3, 4] +SEEDS = [0] + + +@pytest.mark.parametrize("n", [1]) # only test for batch size 1 +@pytest.mark.parametrize("k", K) +@pytest.mark.parametrize("m", M) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not 
current_platform.is_rocm(), + reason="only test for rocm") +@torch.inference_mode() +def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): + torch.manual_seed(seed) + A = torch.rand(n, k, dtype=dtype, device="cuda") + B = torch.rand(m, k, dtype=dtype, device="cuda") + + ref_out = torch.matmul(A, B.t()) + out = ops.LLMM1(B, A, rows_per_block) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n", N) # only test for batch size <= 4 +@pytest.mark.parametrize("k", K + [9216, 10240, 16384]) +@pytest.mark.parametrize("m", [8] + M) # m >= 8 +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + cu_count = current_platform.get_cu_count() + + A = torch.rand(n, k, dtype=dtype, device="cuda") + B = torch.rand(m, k, dtype=dtype, device="cuda") + + ref_out = torch.matmul(A, B.t()) + out = ops.wvSplitK(B, A, cu_count) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n", N) # only test for batch size <= 4 +@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0 +@pytest.mark.parametrize("m", M + [28672]) # m >= 16 +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + + A = torch.rand(n, k, device="cuda") + B = torch.rand(m, k, device="cuda") + + A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) + B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) + + ref_out = torch._scaled_mm(A, + B.t(), + out_dtype=dtype, + scale_a=scale_a, + scale_b=scale_b) + out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, + current_platform.get_cu_count()) + + assert torch.allclose(out, ref_out, rtol=0.01) diff --git 
a/tests/kernels/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py similarity index 100% rename from tests/kernels/test_triton_scaled_mm.py rename to tests/kernels/quantization/test_triton_scaled_mm.py diff --git a/tests/kernels/untest_aqlm.py b/tests/kernels/quantization/untest_aqlm.py similarity index 100% rename from tests/kernels/untest_aqlm.py rename to tests/kernels/quantization/untest_aqlm.py diff --git a/tests/kernels/untest_awq.py b/tests/kernels/quantization/untest_awq.py similarity index 100% rename from tests/kernels/untest_awq.py rename to tests/kernels/quantization/untest_awq.py diff --git a/tests/kernels/untest_fp8_quant.py b/tests/kernels/quantization/untest_fp8_quant.py similarity index 100% rename from tests/kernels/untest_fp8_quant.py rename to tests/kernels/quantization/untest_fp8_quant.py diff --git a/tests/kernels/untest_ggml.py b/tests/kernels/quantization/untest_ggml.py similarity index 100% rename from tests/kernels/untest_ggml.py rename to tests/kernels/quantization/untest_ggml.py diff --git a/tests/kernels/untest_gptq.py b/tests/kernels/quantization/untest_gptq.py similarity index 100% rename from tests/kernels/untest_gptq.py rename to tests/kernels/quantization/untest_gptq.py diff --git a/tests/kernels/test_machete_mm.py b/tests/kernels/quantization/untest_machete_mm.py similarity index 100% rename from tests/kernels/test_machete_mm.py rename to tests/kernels/quantization/untest_machete_mm.py diff --git a/tests/kernels/untest_marlin_gemm.py b/tests/kernels/quantization/untest_marlin_gemm.py similarity index 100% rename from tests/kernels/untest_marlin_gemm.py rename to tests/kernels/quantization/untest_marlin_gemm.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py deleted file mode 100644 index d823818044f48b26c6ef1318a0d46c9ce83bdcfd..0000000000000000000000000000000000000000 --- a/tests/kernels/test_attention_selector.py +++ /dev/null @@ -1,138 +0,0 @@ -# 
SPDX-License-Identifier: Apache-2.0 - -from unittest.mock import patch - -import pytest -import torch - -from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend -from vllm.platforms.cpu import CpuPlatform -from vllm.platforms.cuda import CudaPlatform -from vllm.platforms.rocm import RocmPlatform - -from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL -from vllm.platforms import current_platform - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. - """ - _cached_get_attn_backend.cache_clear() - - -@pytest.mark.parametrize( - "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"] if not current_platform.is_rocm() else ["ROCM_FLASH"]) -@pytest.mark.parametrize("use_v1", [True, False]) -@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) -def test_env( - name: str, - use_v1: bool, - device: str, - monkeypatch: pytest.MonkeyPatch, -): - """Test that the attention selector can be set via environment variable. - Note that we do not test FlashAttn because it is the default backend. 
- """ - - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") - m.setenv(STR_BACKEND_ENV_VAR, name) - - if device == "cpu": - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - 16, False) - assert backend.get_name() == "TORCH_SDPA" - elif device == "hip": - with patch("vllm.attention.selector.current_platform", - RocmPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - 16, False) - EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" - assert backend.get_name() == EXPECTED - else: - if name in ["XFORMERS", "FLASHINFER"]: - with patch("vllm.attention.selector.current_platform", - CudaPlatform()): - backend = get_attn_backend(16, torch.float16, - torch.float16, 16, False) - EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name - assert backend.get_name() == EXPECTED - - -def test_flash_attn(monkeypatch: pytest.MonkeyPatch): - """Test FlashAttn validation.""" - # TODO: When testing for v1, pipe in `use_v1` as an argument to - # get_attn_backend - - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) - - # Unsupported CUDA arch - monkeypatch.setattr(torch.cuda, "get_device_capability", lambda: - (7, 5)) - backend = get_attn_backend(16, torch.float16, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Reset the monkeypatch for subsequent tests - monkeypatch.undo() - - # Unsupported data type - backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Unsupported kv cache data type - backend = get_attn_backend(16, torch.float16, "fp8", 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Unsupported block size - backend = get_attn_backend(16, torch.float16, None, 8, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # flash-attn is not installed - import sys - original_module = 
sys.modules.get('vllm_flash_attn') - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None) - backend = get_attn_backend(16, torch.float16, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Restore the original module if it existed - if original_module is not None: - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', - original_module) - else: - monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False) - - # Unsupported head size - backend = get_attn_backend(17, torch.float16, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Attention-free models should bypass env and use PlaceholderAttention - backend = get_attn_backend(16, torch.float16, torch.float16, 16, True) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - -@pytest.mark.parametrize("use_v1", [True, False]) -def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch): - - with monkeypatch.context() as m, patch( - "vllm.attention.selector.current_platform", CudaPlatform()): - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") - m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) - - # Test with head size 32 - backend = get_attn_backend(32, torch.float16, None, 16, False) - EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN" - assert backend.get_name() == EXPECTED - - # when block size == 16, backend will fall back to XFORMERS - # this behavior is not yet supported on V1. - if use_v1: - # TODO: support fallback on V1! 
- # https://github.com/vllm-project/vllm/issues/14524 - pass - else: - backend = get_attn_backend(16, torch.float16, None, 16, False) - assert backend.get_name() == "XFORMERS" diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..87e4bd4b096b3bd2ebe832554de84cb3dac54256 --- /dev/null +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch +import torch.nn.functional as F +from torch import Tensor + +import vllm._custom_ops as ops +from vllm.platforms import current_platform + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Cutlass MLA Requires compute capability of 10 or above.", + allow_module_level=True) + + +def ref_mla( + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) +): + bs, num_heads, v_head_dim = out.shape + head_dim = query.shape[2] + + for i in range(bs): + # gather and flatten KV-cache + kv = kv_cache[ + block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, + head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + v = kv[:, :, :v_head_dim] + + q = query[i].view(num_heads, 1, head_dim) + o = F.scaled_dot_product_attention(q, + kv, + v, + scale=scale, + enable_gqa=True) + out[i] = o.view(num_heads, v_head_dim) + + return out + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) +@pytest.mark.parametrize("bs", [1, 2, 4]) +@pytest.mark.parametrize("varlen", [False, True]) +@pytest.mark.parametrize("block_size", [16, 64, 128]) +def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, + varlen: bool, block_size: int): + 
torch.set_default_dtype(dtype) + torch.set_default_device('cuda') + torch.manual_seed(42) + + d = 576 + h_q = 128 + dv = 512 + + q_nope_dim = 128 + q_pe_dim = 64 + scale = (q_nope_dim + q_pe_dim)**(-0.5) + if varlen: + seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) + seq_lens = seq_lens.clip(2).to(torch.int32) + else: + seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) + max_seq_len = seq_lens.max().item() + block_num = (max_seq_len + block_size - 1) // block_size + + # Pad block_num so that small blocks can be packed into full 128-sized + # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small + # blocks. + pack_factor = 128 // block_size + block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor + + q = torch.randn(bs, h_q, d) + block_table = torch.randint(0, + bs * block_num, (bs, block_num), + dtype=torch.int32) + + kv_cache = torch.randn(block_table.numel(), block_size, d) + + out_ref = q.new_zeros(bs, h_q, dv) + ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) + out_ans = torch.zeros_like(out_ref) + q_nope = q[:, :, :dv].clone() + q_pe = q[:, :, dv:].clone() + ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens, + block_table, scale) + + torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py deleted file mode 100644 index 3cfed6ae8538fd75a453e1107bbd0ed2327b85a5..0000000000000000000000000000000000000000 --- a/tests/kernels/test_cutlass_moe.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import pytest -import torch - -from vllm import _custom_ops as ops -from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, - fused_topk) -from vllm.platforms import current_platform - -NUM_EXPERTS = 
[40, 64] -TOP_KS = [6, 8] - - -def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif( - (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_no_graph( - m: int, - n: int, - k: int, - e: int, - topk: int, - per_act_token: bool, - per_out_ch: bool, -): - current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - # Get the right scale for tests. 
- _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) - - a_d = a_q.float().mul(a_scale1).to(dtype) - - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - - cutlass_output = cutlass_moe_fp8(a, - w1_q, - w2_q, 
- w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale1) - - #print(triton_output) - #print(cutlass_output) - #print("*") - - torch.testing.assert_close(triton_output, - cutlass_output, - atol=5e-2, - rtol=1e-2) - - -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif( - (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_cuda_graph( - m: int, - n: int, - k: int, - e: int, - topk: int, - per_act_token: bool, - per_out_ch: bool, -): - current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - # Get the right scale for tests. 
- _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) - - a_d = a_q.float().mul(a_scale1).to(dtype) - - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - - stream = torch.cuda.Stream() - graph = 
torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): - cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, - c_strides1, ab_strides2, c_strides2) - torch.cuda.synchronize() - graph.replay() - torch.cuda.synchronize() - - #print(triton_output) - #print(cutlass_output) - #print("*") - - torch.testing.assert_close(triton_output, - cutlass_output, - atol=9e-2, - rtol=1e-2) diff --git a/tests/kernels/test_rocm_attention_selector.py b/tests/kernels/test_rocm_attention_selector.py deleted file mode 100644 index 90b483b4a41a08ed46d25e2df3ec0a9cf57153d0..0000000000000000000000000000000000000000 --- a/tests/kernels/test_rocm_attention_selector.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch - -from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend -from vllm.platforms.rocm import RocmPlatform -from vllm.utils import STR_BACKEND_ENV_VAR - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. 
- """ - _cached_get_attn_backend.cache_clear() - - -def test_selector(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH") - - # Set the current platform to ROCm using monkeypatch - monkeypatch.setattr("vllm.attention.selector.current_platform", - RocmPlatform()) - - # Test standard ROCm attention - backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) - assert (backend.get_name() == "ROCM_FLASH" - or backend.get_name() == "TRITON_ATTN_VLLM_V1") - - # mla test for deepseek related - backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, - False, True) - assert backend.get_name() == "TRITON_MLA" diff --git a/tests/kernels/test_triton_flash_attention.py b/tests/kernels/test_triton_flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2bdc908e420135d17cecbcb39808c67b6364b7 --- /dev/null +++ b/tests/kernels/test_triton_flash_attention.py @@ -0,0 +1,499 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the triton_flash_attention kernel + +Run `pytest tests/kernels/test_triton_flash_attention.py`. 
+""" +import pytest +import torch + +from vllm.attention.ops.triton_flash_attention import (SUPPORTED_LAYOUTS, + MetaData, + compute_alibi_tensor, + scale_fp8, + triton_attention_rocm) +from vllm.platforms import current_platform + + +class ReferenceAttention: + + def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, + input_metadata): + self.Z = Z + self.HQ = HQ + self.HK = HK + self.N_CTX_Q = N_CTX_Q + self.N_CTX_K = N_CTX_K + self.D_HEAD = D_HEAD + self.use_alibi = use_alibi + self.dtype = dtype + self.input_metadata = input_metadata + + def fwd(self, q, k, v): + scores = torch.einsum('bhqd,bhkd->bhqk', q, + k).float() * self.input_metadata.sm_scale + if self.input_metadata.causal: + mask = torch.tril(torch.ones(self.N_CTX_Q, + self.N_CTX_K, + device="cuda"), + diagonal=self.N_CTX_K - self.N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + + if self.input_metadata.bias is not None: + scores += self.input_metadata.bias + + if self.use_alibi: + scores += compute_alibi_tensor(self.input_metadata.alibi_slopes, + self.N_CTX_Q, self.N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if self.input_metadata.causal: + # If N_CTX_Q > N_CTX_K, there's at least one row of all -infs going + # into softmax. This creates a row of NaNs as -inf - -inf == NaN. + # So we fix this by converting the NaNs to 0s, which is what they + # should be out of the softmax. 
+ nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(self.dtype), v) + # compare + if self.input_metadata.layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + return ref_out + + def fwd_fp8(self, q_quantized, k_quantized, v_quantized): + q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to( + self.dtype) + k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to( + self.dtype) + v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to( + self.dtype) + result = self.fwd(q, k, v) + if self.input_metadata.o_scale is not None: + result, _ = scale_fp8(result, self.input_metadata.o_scale) + return result + + def fwd_fp8_kv(self, q, k_quantized, v_quantized): + k_descale, v_descale = (self.input_metadata.k_descale, + self.input_metadata.v_descale) + k_dequantized = (k_quantized.to(torch.float32) * + k_descale.to(torch.float32)).to(self.dtype) + v_dequantized = (v_quantized.to(torch.float32) * + v_descale.to(torch.float32)).to(self.dtype) + return self.fwd(q, k_dequantized, v_dequantized) + + def varlen_fwd(self, q, k, v, is_mqa=False): + ref_out = torch.empty_like(q) + if is_mqa: + # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so + # the size aligns with Q. 
+ k_ref = k.view(k.shape[0], k.shape[1], 1, + k.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) + v_ref = v.view(v.shape[0], v.shape[1], 1, + v.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) + else: + k_ref = k + v_ref = v + + for i in range(0, self.input_metadata.num_contexts): + start_q, start_k = self.input_metadata.cu_seqlens_q[ + i], self.input_metadata.cu_seqlens_k[i] + end_q, end_k = self.input_metadata.cu_seqlens_q[ + i + 1], self.input_metadata.cu_seqlens_k[i + 1] + k_curr = k_ref[start_k:end_k] + v_curr = v_ref[start_k:end_k] + if is_mqa: + k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) + v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], + k_curr).float() + p = torch.softmax(scores * self.input_metadata.sm_scale, + dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + return ref_out + + +def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False): + q_descale = None + if not fp8_kv: + q, q_descale = scale_fp8(q) + k, k_descale = scale_fp8(k) + v, v_descale = scale_fp8(v) + + # In real world use case, the p scale would be a parameter trained by the + # model. + p_scale = None + + o_scale = torch.rand(1, device="cuda", + requires_grad=False) if use_o_scale else None + + return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale + + +def input_helper( + Z, + HQ, + HK, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + layout=None, + use_alibi=None, + causal=None, + is_fp8=False, + fp8_kv=False, + use_o_scale=False, + use_bias=False, +): + assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout." 
+ + current_platform.seed_everything(0) + + # Initialize q, k, v + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts + # 2^(-8/n) + alibi_slopes = torch.tensor( + [2**(-8 / HQ * i) for i in range(1, HQ + 1)], + dtype=torch.float32, + device="cuda").repeat(Z, 1) + else: + alibi_slopes = None + + if use_bias: + bias = torch.randn((1, HQ, N_CTX_Q, N_CTX_K), + dtype=dtype, + device="cuda", + requires_grad=False) + else: + bias = None + + q = torch.randn(q_tensor_shape, + dtype=dtype, + device="cuda", + requires_grad=False) + k = torch.randn(k_tensor_shape, + dtype=dtype, + device="cuda", + requires_grad=False) + v = torch.randn(k_tensor_shape, + dtype=dtype, + device="cuda", + requires_grad=False) + + if is_fp8: + (q, k, v, q_descale, k_descale, v_descale, p_scale, + o_scale) = quantize_input(q, + k, + v, + use_o_scale=use_o_scale, + fp8_kv=fp8_kv) + else: + q_descale = k_descale = v_descale = p_scale = o_scale = None + + input_metadata = MetaData(sm_scale=D_HEAD**-0.5, + max_seqlens_q=N_CTX_Q, + max_seqlens_k=N_CTX_K, + layout=layout, + alibi_slopes=alibi_slopes, + alibi_batch=Z, + alibi_nheads=HQ, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + p_scale=p_scale, + o_scale=o_scale, + bias=bias, + seqlen_q=N_CTX_Q, + seqlen_k=N_CTX_K) + return q, k, v, input_metadata + + +def varlen_input_helper(Z, + HQ, + HK, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + equal_seqlens=False): + current_platform.seed_everything(0) + + # Random sequence lengths. 
Using N_CTX as kind of max of sum of individual + # seqs + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, + max_seqlens_q + 1, (Z, ), + dtype=torch.int32) + seqlens_k = torch.randint(1, + max_seqlens_k + 1, (Z, ), + dtype=torch.int32) + else: + seqlens_q = torch.full((Z, ), N_CTX_Q // Z) + seqlens_k = torch.full((Z, ), N_CTX_K // Z) + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([ + torch.tensor([0], dtype=torch.int32), + seqlens_q.cumsum(dim=0, dtype=torch.int32) + ]) + cu_seqlens_k = torch.cat([ + torch.tensor([0], dtype=torch.int32), + seqlens_k.cumsum(dim=0, dtype=torch.int32) + ]) + cu_seqlens_q = cu_seqlens_q.to(device="cuda") + cu_seqlens_k = cu_seqlens_k.to(device="cuda") + + # Initialize q, k, v with variable lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + return q, k, v, input_metadata + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (1, 48, 12, 1, 1, 64), + (4, 4, 4, 128, 128, 65), + (16, 48, 48, 1, 1, 128), + (64, 48, 24, 3, 3, 128), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('layout', ['bshd']) +def test_op_fwd(Z, + HQ, + HK, + N_CTX_Q, + N_CTX_K, + D_HEAD, + causal, + use_alibi, + layout, + dtype=torch.float16): + current_platform.seed_everything(0) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, 
D_HEAD, + dtype, layout, use_alibi, causal) + + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) + + # Transpose here if layout is bshd so we have same reference code for all + # layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, + -1).reshape(k.shape[0], -1, k.shape[2], + k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, + -1).reshape(v.shape[0], -1, v.shape[2], + v.shape[3]) + + ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, + use_alibi, dtype, input_metadata) + ref_out = ref_impl.fwd(q, k, v) + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +@pytest.mark.parametrize('use_o_scale', [True, False]) +@pytest.mark.skipif(torch.cuda.get_device_capability() < (9, 0), + reason="Triton FP8 requires CUDA 9.0 or higher") +def test_op_fwd_fp8(Z, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + causal, + layout, + use_o_scale, + dtype=torch.float32): + current_platform.seed_everything(0) + + # Disable grad to save memory it won't run into OOM on CI machine. 
+ # q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, + # dtype, layout) + + q_quantized, k_quantized, v_quantized, input_metadata = input_helper( + Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + causal=causal, + layout=layout, + is_fp8=True, + use_o_scale=use_o_scale) + + o = torch.empty_like(q_quantized) if use_o_scale else None + + tri_out, _ = triton_attention_rocm(q_quantized, k_quantized, v_quantized, + o, input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized) + + # compare + torch.testing.assert_close(ref_out.to(torch.float32), + tri_out.to(torch.float32), + atol=7e-2, + rtol=2e-1) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +def test_op_fwd_fp8_kv(Z, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + causal, + layout, + dtype=torch.float32): + current_platform.seed_everything(0) + + q, k_quantized, v_quantized, input_metadata = input_helper(Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + causal=causal, + layout=layout, + is_fp8=True, + fp8_kv=True) + + o = torch.empty_like(q) + + tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, + input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized) + + torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_bias', [True]) 
+@pytest.mark.parametrize('dtype', [torch.bfloat16]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): + current_platform.seed_everything(0) + q, k, v, input_metadata = input_helper(Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + layout='bhsd', + causal=causal, + use_bias=use_bias) + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.fwd(q, k, v) + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +# NOTE: Uses thd layout, so also tests thd. +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(1, 48, 256, 64), + (4, 48, 512, 64), + (16, 48, 512, 64), + (64, 48, 128, 128)]) +@pytest.mark.parametrize('causal', [True, False]) +def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, + D_HEAD, dtype) + + tri_out = torch.empty_like(q) + triton_attention_rocm(q, k, v, tri_out, input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, + input_metadata) + ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False) + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +# NOTE: Uses thd layout, so also tests thd. 
+@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), + (4, 48, 12, 256, 64), + (4, 48, 4, 512, 64), + (4, 64, 16, 128, 128)]) +@pytest.mark.parametrize('causal', [False]) +def test_op_varlen_mqa_fwd(Z, + HQ, + HK, + N_CTX, + D_HEAD, + causal, + dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, + D_HEAD, dtype) + + tri_out = torch.empty_like(q) + triton_attention_rocm(q, k, v, tri_out, input_metadata) + + ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True) + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) diff --git a/tests/kernels/test_utils.py b/tests/kernels/test_utils.py deleted file mode 100644 index 58b0c78a580b2e19f2389a84f04217f3b56bf90b..0000000000000000000000000000000000000000 --- a/tests/kernels/test_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Tests for miscellaneous utilities -""" - -import pytest -import torch - -from tests.kernels.utils import opcheck -from vllm.platforms import current_platform - - -# def test_convert_fp8_opcheck(): -# data = torch.randn((256, 256), dtype=torch.float32, device="cuda") -# result = torch.empty_like(data, dtype=torch.float8_e4m3fn) -# opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8")) - - -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="Only supported for CUDA") -def test_cuda_utils_opcheck(): - opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0)) - opcheck( - torch.ops._C_cuda_utils. - get_max_shared_memory_per_block_device_attribute, (0, )) diff --git a/tests/kernels/untest_machete_gemm.py b/tests/kernels/untest_machete_gemm.py deleted file mode 100644 index 0dfa79e9af8ec9bde34e7b49d0c20e615d941205..0000000000000000000000000000000000000000 --- a/tests/kernels/untest_machete_gemm.py +++ /dev/null @@ -1,282 +0,0 @@ -"""Tests for the machete kernel. 
- -Run `pytest tests/kernels/test_machete_gemm.py`. -""" - -import math -from typing import Optional, Tuple - -import pytest -import torch - -from tests.kernels.utils import opcheck -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) -from vllm.platforms import current_platform -from vllm.scalar_type import ScalarType, scalar_types - -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - -MNK_SHAPES = [ - (1, 128, 128), - (1, 512, 1024), - (1, 4096, 4096), - (13, 8192, 4096), - (26, 4096, 8192), - (1, 4096, 4096), - (257, 128, 4096), - (257, 4224, 4160), - (257, 4096, 4096), - (64, 4096, 4096), - (1024, 4096, 8192), - (1024, 8192, 4096), -] - -ACT_TYPES = [torch.float16, torch.bfloat16] -WTYPE_ZEROPOINTS = [ - # GPTQ style - (scalar_types.uint4b8, False), - (scalar_types.uint8b128, False), - # AWQ style - (scalar_types.uint4, True), - (scalar_types.uint8, True), -] - -# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel -# unit tests to a common utility function. Currently the use of -# `is_quant_method_supported` conflates kernels with quantization methods -# an assumption which is breaking down as quantizations methods can have -# have kernels and some kernels support multiple quantization methods. 
-IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) - - -def rand_data(shape, dtype=torch.float16): - return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3) - - -def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): - return zps if zps is None else -1 * s * (zps.to(s.dtype)) - - -def machete_quantize_and_pack(w: torch.Tensor, - wtype: ScalarType, - group_size: int, - zero_points: bool = False): - assert wtype.is_integer(), "TODO: support floating point weights" - - w_ref, w_q, w_s, w_zp = quantize_weights( - w, - wtype, - group_size, - zero_points=zero_points, - # to match how the kernel applies zps - ref_zero_points_after_scales=True) - - w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) - w_q = w_q.t().contiguous().t() # convert to col major - w_q_machete = ops.machete_prepack_B(w_q, wtype) - - opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype)) - - return w_ref, w_q_machete, w_s, w_zp - - -def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor, - wtype: ScalarType, group_size: int, - zero_points: bool): - w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - output = ops.machete_gemm( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) 
-@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_all_schedules(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - print(f"MNK = {m} {n} {k}") - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - w = rand_data((k, n), atype) - - w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack( - w, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - for schedule in ops.machete_supported_schedules(wtype): - print(f"Testing schedule {schedule}") - output = ops.machete_gemm( - a, - b_q=w_q_machete, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - schedule=schedule, - ) - - opcheck(torch.ops._C.machete_gemm, - (a, w_q_machete, wtype, w_s, maybe_convert_zeropoints( - w_zp, w_s), group_size, None, None, None, schedule)) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\ - f"Schedule failed {schedule}" - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) -@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_heuristic(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: 
Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - b = rand_data((k, n), atype) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working on other devices -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_machete_devices(device: str): - m, n, k = 512, 4096, 4096 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - print(f"MNK = {m} {n} {k}, device = {device}") - - a = rand_data((m, k), torch.float16).to(device) - b = rand_data((k, n), torch.float16).to(device) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working with a subset of A and B -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_subset(): - big_m, big_n, big_k = 1024, 1024, 1024 - m, n, k = 512, 512, 512 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - whole_a = rand_data((big_m, big_k), torch.float16) - whole_b = rand_data((big_k, big_n), torch.float16) - - a = whole_a[0:m, 0:k] - b = whole_b[0:k, 0:n] - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test to make sure cuda graphs work -class MacheteLayer(torch.nn.Module): - - def __init__(self, **kwargs): - super().__init__() - self.kwargs = kwargs - - def forward(self, a): - return ops.machete_gemm(**self.kwargs) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_cuda_graph(): - m, n, k = 512, 4096, 4096 - - a = rand_data((m, k), torch.float16) - b = rand_data((k, n), torch.float16) - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - 
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - # Construct a trivial model with a single layer that calls a machete kernel - model = MacheteLayer( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - output_ref = torch.matmul(a, w_ref) - - # Run the model with a cuda graph - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - output = model(a) - output.zero_() - g.replay() - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/kernels/utils_block.py b/tests/kernels/utils_block.py deleted file mode 100644 index c16cba50967eba7c694f28992c143ec56788d092..0000000000000000000000000000000000000000 --- a/tests/kernels/utils_block.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, - As: torch.Tensor, Bs: torch.Tensor, block_size, - output_dtype): - """This function performs matrix multiplication with block-wise - quantization using native torch. - It is agnostic to the input data type and can be used for both int8 and - fp8 data types. - - It takes two input tensors `A` and `B` (int8) with scales `As` and - `Bs` (float32). - The output is returned in the specified `output_dtype`. 
- """ - A = A.to(torch.float32) - B = B.to(torch.float32) - assert A.shape[-1] == B.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] - assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - - M = A.numel() // A.shape[-1] - N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) - A = A.reshape(M, A.shape[-1]) - As = As.reshape(M, As.shape[-1]) - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] - - C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) - - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) - ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) - ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - a = A_tiles[i] - b = B_tiles[j][i] - c = C_tiles[j] - s = As_tiles[i] * Bs[j][i] - c[:, :] += torch.matmul(a, b.t()) * s - - C = C.reshape(origin_C_shape).to(output_dtype) - return C diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 2cb1309c5dc9f30c7b5749c26bd94ec08194e709..84691b955ff53144fd6ebbd2887ab9e797474d0c 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -48,6 +48,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256, + skip_special_tokens=False, stop=["[/assistant]"]) outputs = llm.generate( prompts, diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 
576d95a471547115deb3bad3b9790dadb8465a0f..52b0834cacb8598b050d4476099edfccfb46d2ea 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -31,6 +31,8 @@ DEVICES = ([ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] if current_platform.is_cuda_alike() else ["cpu"]) +DEFAULT_DTYPE = torch.get_default_dtype() + @pytest.fixture(scope="function", autouse=True) def use_v0_only(monkeypatch: pytest.MonkeyPatch): @@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model): model = dummy_model manager = LoRAModelManager( model, 1, 1, 1, - LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8), - torch.device(DEVICES[0])) + LoRAConfig(max_lora_rank=8, + max_cpu_loras=8, + max_loras=8, + lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0])) model = manager.model assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA) @@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, - max_loras=2), + max_loras=2, + lora_dtype=DEFAULT_DTYPE), device=device) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) @@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, - max_loras=2), + max_loras=2, + lora_dtype=DEFAULT_DTYPE), device=device) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) @@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=2, - max_loras=2), + max_loras=2, + lora_dtype=DEFAULT_DTYPE), device=device) assert all(x is None for x in manager.lora_index_to_id) @@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, 
sql_lora_files, device): - lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + lora_config = LoRAConfig(max_lora_rank=8, + max_cpu_loras=4, + max_loras=4, + lora_dtype=DEFAULT_DTYPE) worker_adapter_manager = LRUCacheWorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, device, @@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, sql_lora_files, device): # Should remove every LoRA not specified in the request. - lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + lora_config = LoRAConfig(max_lora_rank=8, + max_cpu_loras=4, + max_loras=4, + lora_dtype=DEFAULT_DTYPE) worker_adapter_manager = WorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, device, @@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=2, - max_loras=2), + max_loras=2, + lora_dtype=DEFAULT_DTYPE), device=device) model = manager.model diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index f53cb9e4194dc1e08b57d4428e3918782d289ead..8ee8009c6dd1035ea52e4670a00d6eb01b399b28 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -68,8 +68,12 @@ def test_minicpmv_lora(minicpmv_lora_files): max_loras=2, max_lora_rank=8, enforce_eager=True, + max_model_len=2048, + limit_mm_per_prompt={ + "image": 2, + "video": 0 + }, trust_remote_code=True, - enable_chunked_prefill=True, ) output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): @@ -93,9 +97,11 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): max_loras=4, max_lora_rank=64, tensor_parallel_size=4, + limit_mm_per_prompt={ + "image": 2, + "video": 0 + }, 
trust_remote_code=True, - enforce_eager=True, - enable_chunked_prefill=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): @@ -117,8 +123,11 @@ def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): max_lora_rank=8, tensor_parallel_size=4, trust_remote_code=True, + limit_mm_per_prompt={ + "image": 1, + "video": 0 + }, fully_sharded_loras=True, - enable_chunked_prefill=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): diff --git a/tests/lora/test_resolver.py b/tests/lora/test_resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..8ebc2ae98fc4341c162610d7566db0ebb012170c --- /dev/null +++ b/tests/lora/test_resolver.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import pytest + +from vllm.lora.request import LoRARequest +from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry + + +class DummyLoRAResolver(LoRAResolver): + """A dummy LoRA resolver for testing.""" + + async def resolve_lora(self, base_model_name: str, + lora_name: str) -> Optional[LoRARequest]: + if lora_name == "test_lora": + return LoRARequest( + lora_name=lora_name, + lora_path=f"/dummy/path/{base_model_name}/{lora_name}", + lora_int_id=abs(hash(lora_name))) + return None + + +def test_resolver_registry_registration(): + """Test basic resolver registration functionality.""" + registry = LoRAResolverRegistry + resolver = DummyLoRAResolver() + + # Register a new resolver + registry.register_resolver("dummy", resolver) + assert "dummy" in registry.get_supported_resolvers() + + # Get registered resolver + retrieved_resolver = registry.get_resolver("dummy") + assert retrieved_resolver is resolver + + +def test_resolver_registry_duplicate_registration(): + """Test registering a resolver with an existing name.""" + registry = LoRAResolverRegistry + resolver1 = DummyLoRAResolver() + resolver2 = DummyLoRAResolver() 
+ + registry.register_resolver("dummy", resolver1) + registry.register_resolver("dummy", resolver2) + + assert registry.get_resolver("dummy") is resolver2 + + +def test_resolver_registry_unknown_resolver(): + """Test getting a non-existent resolver.""" + registry = LoRAResolverRegistry + + with pytest.raises(KeyError, match="not found"): + registry.get_resolver("unknown_resolver") + + +@pytest.mark.asyncio +async def test_dummy_resolver_resolve(): + """Test the dummy resolver's resolve functionality.""" + dummy_resolver = DummyLoRAResolver() + base_model_name = "base_model_test" + lora_name = "test_lora" + + # Test successful resolution + result = await dummy_resolver.resolve_lora(base_model_name, lora_name) + assert isinstance(result, LoRARequest) + assert result.lora_name == lora_name + assert result.lora_path == f"/dummy/path/{base_model_name}/{lora_name}" + + # Test failed resolution + result = await dummy_resolver.resolve_lora(base_model_name, + "nonexistent_lora") + assert result is None diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py index 46d63b28c79bee71e3ed4b7ea6f4ef2da74d2ea2..2c0a5482eeefb85d7b21c505f889139f517f14de 100644 --- a/tests/lora/test_tokenizer_group.py +++ b/tests/lora/test_tokenizer_group.py @@ -1,22 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 +import os import pytest from transformers import AutoTokenizer, PreTrainedTokenizerBase from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import get_lora_tokenizer -from vllm.transformers_utils.tokenizer_group import get_tokenizer_group -import os -from ..utils import RemoteOpenAIServer, models_path_prefix -from ..conftest import get_tokenizer_pool_config + +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from ..utils import models_path_prefix @pytest.mark.asyncio @pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): 
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) - tokenizer_group = get_tokenizer_group( - get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_group = TokenizerGroup( tokenizer_id=os.path.join(models_path_prefix,"gpt2"), enable_lora=True, max_num_seqs=1, @@ -61,8 +60,7 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path): @pytest.mark.parametrize("max_num_seqs", [1, 2]) @pytest.mark.parametrize("max_loras", [1, 2]) def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras): - tokenizer_group = get_tokenizer_group( - get_tokenizer_pool_config(None), + tokenizer_group = TokenizerGroup( tokenizer_id="gpt2", enable_lora=enable_lora, max_num_seqs=max_num_seqs, diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index 34a26e9edf36ac03129da4bbd946d8c319cead68..67f3866beff55c9b6ebfb3601bf2c129cd934fc2 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -9,7 +9,6 @@ from torch import nn from vllm.lora.utils import (get_adapter_absolute_path, parse_fine_tuned_lora_name, replace_submodule) -from vllm.utils import LRUCache def test_parse_fine_tuned_lora_name_valid(): @@ -40,6 +39,18 @@ def test_parse_fine_tuned_lora_name_valid(): False, False, ), + ( + "language_model.layers.9.mlp.down_proj.lora_A.weight", + "language_model.layers.9.mlp.down_proj", + True, + False, + ), + ( + "language_model.layers.9.mlp.down_proj.lora_B.weight", + "language_model.layers.9.mlp.down_proj", + False, + False, + ), } for name, module_name, is_lora_a, is_bias in fixture: assert (module_name, is_lora_a, @@ -85,114 +96,6 @@ def test_replace_submodule(): assert dict(model.named_modules())["seq1.dense2"] == dense2 -class TestLRUCache(LRUCache): - - def _on_remove(self, key, value): - if not hasattr(self, "_remove_counter"): - self._remove_counter = 0 - self._remove_counter += 1 - - -def test_lru_cache(): - cache = TestLRUCache(3) - - cache.put(1, 1) - assert len(cache) == 1 - - cache.put(1, 1) - assert len(cache) == 1 - - 
cache.put(2, 2) - assert len(cache) == 2 - - cache.put(3, 3) - assert len(cache) == 3 - assert set(cache.cache) == {1, 2, 3} - - cache.put(4, 4) - assert len(cache) == 3 - assert set(cache.cache) == {2, 3, 4} - assert cache._remove_counter == 1 - assert cache.get(2) == 2 - - cache.put(5, 5) - assert set(cache.cache) == {2, 4, 5} - assert cache._remove_counter == 2 - - assert cache.pop(5) == 5 - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.pop(10) - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.get(10) - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.put(6, 6) - assert len(cache) == 3 - assert set(cache.cache) == {2, 4, 6} - assert 2 in cache - assert 4 in cache - assert 6 in cache - - cache.remove_oldest() - assert len(cache) == 2 - assert set(cache.cache) == {2, 6} - assert cache._remove_counter == 4 - - cache.clear() - assert len(cache) == 0 - assert cache._remove_counter == 6 - - cache._remove_counter = 0 - - cache[1] = 1 - assert len(cache) == 1 - - cache[1] = 1 - assert len(cache) == 1 - - cache[2] = 2 - assert len(cache) == 2 - - cache[3] = 3 - assert len(cache) == 3 - assert set(cache.cache) == {1, 2, 3} - - cache[4] = 4 - assert len(cache) == 3 - assert set(cache.cache) == {2, 3, 4} - assert cache._remove_counter == 1 - assert cache[2] == 2 - - cache[5] = 5 - assert set(cache.cache) == {2, 4, 5} - assert cache._remove_counter == 2 - - del cache[5] - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.pop(10) - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache[6] = 6 - assert len(cache) == 3 - assert set(cache.cache) == {2, 4, 6} - assert 2 in cache - assert 4 in cache - assert 6 in cache - - # Unit tests for get_adapter_absolute_path @patch('os.path.isabs') def 
test_get_adapter_absolute_path_absolute(mock_isabs): diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index ac2e0f3542e789752b73d72b10e02c614faca987..2d9cf1d48fd5fefad6ccb234df6e286bb1d9c53c 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -11,6 +11,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( dispatch_fused_experts_func, dispatch_topk_func, torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts, vllm_topk_softmax) +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) @@ -100,11 +102,10 @@ def test_enabled_ops_invalid(env: str): def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) topk_func = dispatch_topk_func() - + is_rocm_aiter_moe_enabled.cache_clear() if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_topk_softmax) - assert topk_func == rocm_aiter_topk_softmax else: assert topk_func == vllm_topk_softmax @@ -116,11 +117,11 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) + is_rocm_aiter_moe_enabled.cache_clear() fused_experts_func = dispatch_fused_experts_func(inplace) if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_fused_experts) - assert fused_experts_func == rocm_aiter_fused_experts elif inplace: assert fused_experts_func == torch_vllm_inplace_fused_experts diff --git a/tests/models/decoder_only/audio_language/test_granite_speech.py 
b/tests/models/decoder_only/audio_language/test_granite_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..7c14845ec54d465832ccc623e8bcfda77f6b4eeb --- /dev/null +++ b/tests/models/decoder_only/audio_language/test_granite_speech.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from typing import Optional + +import pytest +from transformers import AutoModelForSpeechSeq2Seq + +from vllm.lora.request import LoRARequest +from vllm.sequence import SampleLogprobs + +from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets +from ...registry import HF_EXAMPLE_MODELS +from ...utils import check_logprobs_close + +HF_AUDIO_PROMPT = "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><|audio|>can you transcribe the speech into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501 + + +def vllm_to_hf_output( + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], +) -> tuple[list[int], str, Optional[SampleLogprobs]]: + """Sanitize hf output to be comparable with vllm output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|end_of_text|>" + + return output_ids, hf_output_str, out_logprobs + + +MODEL_NAME = "ibm-granite/granite-speech-3.3-8b" +# Audio lora co-exists directly in the model directory, but +# currently still needs to be passed directly to vLLM. 
+audio_lora_path = MODEL_NAME +models = [MODEL_NAME] + + +def run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + inputs: Sequence[tuple[list[str], PromptAudioInput]], + model: str, + *, + max_model_len: int, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the audio fixtures for the test are from AUDIO_ASSETS. + For huggingface runner, we provide the audio as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
+ # max_model_len should be greater than image_feature_size + with vllm_runner( + model, + task="generate", + max_model_len=max_model_len, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"audio": 1}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=64, + enforce_eager=True, + ) as vllm_model: + lora_request = LoRARequest("audio", 1, audio_lora_path) + vllm_outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=audios, + lora_request=lora_request) + for prompts, audios in inputs + ] + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: + + hf_processor = hf_model.processor + eos_token_id = hf_processor.tokenizer.eos_token_id + + hf_outputs_per_case = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=[audios], + eos_token_id=eos_token_id) + for prompts, audios in inputs + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(output) for output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_model_len", [2048]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets, + dtype: str, max_model_len: int, max_tokens: int, + num_logprobs: int) -> None: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + + audio, sr = audio_assets[0].audio_and_sample_rate + # This model expects 16k sample rate, which our test audio + # already is; if this 
changes, it may break this test, + # so we check it directly + assert sr == 16000 + run_test( + hf_runner, + vllm_runner, + [ + ([HF_AUDIO_PROMPT], [audio]), + ], + model, + dtype=dtype, + max_model_len=max_model_len, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index fe466f7524fa29fa679fc9d3bc4d7b1a2af00e82..12980cb89c093a16ed4c1271b5791d254445be7b 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +import json +from typing import Any, Optional import numpy as np import pytest @@ -9,10 +10,11 @@ import os import pytest_asyncio from transformers import AutoModel, AutoTokenizer -from vllm.multimodal.audio import resample_audio +from vllm.multimodal.audio import resample_audio_librosa from vllm.sequence import SampleLogprobs -from ....conftest import HfRunner, VllmRunner + +from ....conftest import HfRunner, VllmRunner, _AudioAssets from ....utils import RemoteOpenAIServer, models_path_prefix from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close @@ -32,31 +34,34 @@ CHUNKED_PREFILL_KWARGS = { } -@pytest.fixture(scope="session") -def audio_assets(): - from vllm.assets.audio import AudioAsset - return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] - - @pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call")) def audio(request): from vllm.assets.audio import AudioAsset return AudioAsset(request.param) +def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]: + """Convert kwargs to CLI args.""" + args = [] + for key, value in params_kwargs.items(): + if isinstance(value, bool): + if value: + args.append(f"--{key.replace('_','-')}") + else: + 
args.append(f"--{key.replace('_','-')}={value}") + return args + + @pytest.fixture(params=[ pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def server(request, audio_assets): +def server(request, audio_assets: _AudioAssets): args = [ - "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", - f"--limit-mm-per-prompt=audio={len(audio_assets)}", - "--trust-remote-code" - ] + [ - f"--{key.replace('_','-')}={value}" - for key, value in request.param.items() - ] + "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager", + "--limit-mm-per-prompt", + json.dumps({"audio": len(audio_assets)}), "--trust-remote-code" + ] + params_kwargs_to_cli_args(request.param) with RemoteOpenAIServer(MODEL_NAME, args, @@ -137,9 +142,9 @@ def run_test( [hf_prompt], max_tokens, num_logprobs=num_logprobs, - audios=[(resample_audio(audio[0], - orig_sr=audio[1], - target_sr=16000), 16000)]) + audios=[(resample_audio_librosa(audio[0], + orig_sr=audio[1], + target_sr=16000), 16000)]) for _, hf_prompt, audio in prompts_and_audios ] @@ -222,8 +227,9 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, - max_tokens: int, num_logprobs: int, +def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets, + dtype: str, max_tokens: int, + num_logprobs: int, vllm_kwargs: dict) -> None: vllm_prompt = _get_prompt(len(audio_assets), @@ -242,7 +248,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, @pytest.mark.asyncio -async def test_online_serving(client, audio_assets): +async def test_online_serving(client, audio_assets: _AudioAssets): """Exercises online serving with/without chunked prefill enabled.""" messages = [{ diff --git a/tests/models/decoder_only/language/test_hybrid.py 
b/tests/models/decoder_only/language/test_hybrid.py index 21adf439afbd942e7fdd0e8a5296cdce76e0678f..885a0054adfc1148662d858853f28380d1e969f8 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -7,76 +7,85 @@ from tests.utils import multi_gpu_test from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams -from ...utils import check_outputs_equal from ....utils import models_path_prefix - -# This test is for the hybrid models -MODELS = [os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-dev"), os.path.join(models_path_prefix, "Zyphra/Zamba2-1.2B-instruct")] -# Bamba at Fp32 is too big for the CI (L4 GPU). -# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) +from ...utils import check_logprobs_close, check_outputs_equal + +# NOTE: The first model in each list is taken as the primary model, +# meaning that it will be used in all tests in this file +# The rest of the models will only be tested by test_models + +SSM_MODELS = [ + os.path.join(models_path_prefix, "state-spaces/mamba-130m-hf"), + os.path.join(models_path_prefix, "tiiuae/falcon-mamba-tiny-dev"), + # TODO: Compare to a Mamba2 model. The HF transformers implementation of + # Mamba2 is buggy for Codestral as it doesn't handle n_groups. + # See https://github.com/huggingface/transformers/pull/35943 + # "mistralai/Mamba-Codestral-7B-v0.1", +] + +HYBRID_MODELS = [ + os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-dev"), + # NOTE: Running Plamo2 in transformers implementation requires to install + # causal-conv1d package, which is not listed as a test dependency as it's + # not compatible with pip-compile. 
+ os.path.join(models_path_prefix, "pfnet/plamo-2-1b"), + os.path.join(models_path_prefix, "Zyphra/Zamba2-1.2B-instruct"), + os.path.join(models_path_prefix, "ibm-ai-platform/Bamba-9B"), +] + +# Avoid OOM +MAX_NUM_SEQS = 4 + + +@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_models( hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, + num_logprobs: int, ) -> None: + with hf_runner(model) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) - # numeric error produces different generation - if "Bamba" in model: - example_prompts.pop(3) - - model_kwargs = { - "use_mamba_kernels": False, # mamba kernels are not installed so HF - # don't use them - } - if "Zamba2" in model: - # Zamba2 HF implementation automatically checks if mamba kernels are - # installed - model_kwargs = {} + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) 
-@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_batching( vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, + num_logprobs: int, ) -> None: - # To pass the small model tests, we need full precision. for_loop_outputs = [] - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for prompt in example_prompts: - for_loop_outputs.append( - vllm_model.generate_greedy([prompt], max_tokens)[0]) + single_output, = vllm_model.generate_greedy_logprobs([prompt], + max_tokens, + num_logprobs) + for_loop_outputs.append(single_output) - batched_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) + batched_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - check_outputs_equal( + check_logprobs_close( outputs_0_lst=for_loop_outputs, outputs_1_lst=batched_outputs, name_0="for_loop_vllm", @@ -84,74 +93,35 @@ def test_batching( ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float16"]) -@pytest.mark.parametrize("max_tokens", [10]) -def test_mamba_prefill_chunking_with_parallel_sampling( - hf_runner, vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int) -> None: - # Tests prefill chunking in conjunction with n>1, in this case, - # prefill is populated with decoding tokens and we test that it - # doesn't fail This test might fail if cache is not allocated - # correctly for n > 1 decoding steps inside a - # chunked prefill forward pass (where we have both prefills - # and decoding together ) - sampling_params = SamplingParams(n=3, - temperature=1, - seed=0, - max_tokens=max_tokens) - with vllm_runner( - model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=30, - max_num_seqs=10 # forces prefill chunks with 
decoding - ) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [7]) -def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, - model: str, dtype: str, - max_tokens: int) -> None: - # numeric error during prefill chunking produces different generation - # compared to w/o prefill chunking for those examples, removed them for now - if "Jamba" in model: - example_prompts.pop(7) - example_prompts.pop(2) - example_prompts.pop(1) - elif "Bamba" in model: - example_prompts.pop(6) - example_prompts.pop(3) - example_prompts.pop(2) - dtype = "half" # use a different dtype for Bamba - elif "Zamba2" in model: - example_prompts.pop(7) - dtype = "half" - - model_kwargs = { - "use_mamba_kernels": False, # mamba kernels are not installed so HF - # don't use them - } - if "Zamba2" in model: - # Zamba2 HF implementation automatically checks if mamba kernels are - # installed - model_kwargs = {} - - with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: - non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +def test_chunked_prefill( + vllm_runner, + example_prompts, + model: str, + max_tokens: int, + num_logprobs: int, + chunked_prefill_token_size: int, +) -> None: + max_num_seqs = chunked_prefill_token_size + max_num_batched_tokens = chunked_prefill_token_size with vllm_runner(model, - dtype=dtype, enable_chunked_prefill=True, - max_num_batched_tokens=5, - max_num_seqs=2) as vllm_model: - chunked = vllm_model.generate_greedy(example_prompts, - max_tokens=max_tokens) + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as 
vllm_model: + chunked = vllm_model.generate_greedy_logprobs(example_prompts, + max_tokens, num_logprobs) - check_outputs_equal( + with vllm_runner(model, + enable_chunked_prefill=False, + max_num_seqs=max_num_seqs) as vllm_model: + non_chunked = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( outputs_0_lst=chunked, outputs_1_lst=non_chunked, name_0="chunked", @@ -159,64 +129,59 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [15]) -def test_parallel_sampling( +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_chunked_prefill_with_parallel_sampling( vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, ) -> None: - - with vllm_runner(model, dtype=dtype) as vllm_model: - for_loop_outputs = [] - for _ in range(10): - for_loop_outputs.append( - # using example_prompts index 1 instead of 0 since with 0 the - # logprobs get really close and the test doesn't pass - vllm_model.generate_greedy([example_prompts[1]], max_tokens) - [0]) - sampling_params = SamplingParams(n=10, - temperature=0.001, - seed=0, - max_tokens=max_tokens) - n_lt_1_outputs = vllm_model.generate([example_prompts[1]], - sampling_params) - token_ids, texts = n_lt_1_outputs[0] - n_lt_1_outputs = [(token_id, text) - for token_id, text in zip(token_ids, texts)] - - check_outputs_equal( - outputs_0_lst=n_lt_1_outputs, - outputs_1_lst=for_loop_outputs, - name_0="vllm_n_lt_1_outputs", - name_1="vllm", - ) + """ + Tests chunked prefill in conjunction with n > 1. + + In this case, prefill is populated with decoding tokens and + we test that it doesn't fail. 
+ + This test might fail if cache is not allocated correctly for n > 1 + decoding steps inside a chunked prefill forward pass + (where we have both prefill and decode together) + """ + sampling_params = SamplingParams(n=3, + temperature=1, + seed=0, + max_tokens=max_tokens) + with vllm_runner( + model, + enable_chunked_prefill=True, + # forces prefill chunks with decoding + max_num_batched_tokens=MAX_NUM_SEQS * 3, + max_num_seqs=MAX_NUM_SEQS, + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) -@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [20]) def test_mamba_cache_cg_padding( vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, ) -> None: - # This test is for verifying that mamba cache is padded to CG captured - # batch size. If it's not, a torch RuntimeError will be raised because - # tensor dimensions aren't compatible - vllm_config = EngineArgs(model=model).create_engine_config() + """ + This test is for verifying that mamba cache is padded to CG captured + batch size. If it's not, a torch RuntimeError will be raised because + tensor dimensions aren't compatible. 
+ """ + vllm_config = EngineArgs(model=model, + trust_remote_code=True).create_engine_config() while len(example_prompts) == vllm_config.pad_for_cudagraph( len(example_prompts)): example_prompts.append(example_prompts[0]) try: - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) except RuntimeError: pytest.fail( @@ -225,28 +190,24 @@ def test_mamba_cache_cg_padding( "Could be related to mamba cache not padded correctly") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [20]) def test_models_preemption_recompute( - hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, ) -> None: - # Tests that outputs are identical with and w/o preemtions (recompute) - assert dtype == "float" - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = True + """ + Tests that outputs are identical with and w/o preemptions (recompute). 
+ """ + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + scheduler = vllm_model.model.llm_engine.scheduler[0] + scheduler.ENABLE_ARTIFICIAL_PREEMPT = True preempt_vllm_outputs = vllm_model.generate_greedy( example_prompts, max_tokens) - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = False + scheduler.ENABLE_ARTIFICIAL_PREEMPT = False vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( @@ -257,40 +218,43 @@ def test_models_preemption_recompute( ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( vllm_runner, - model: str, - dtype: str, example_prompts, + model: str, ) -> None: - # This test is for verifying that the hybrid inner state management doesn't - # collapse in case where the number of incoming requests and - # finished_requests_ids is larger than the maximum mamba block capacity. - # This could generally happen due to the fact that hybrid does support - # statelessness mechanism where it can cleanup new incoming requests in - # a single step. + """ + This test is for verifying that the hybrid inner state management doesn't + collapse in case where the number of incoming requests and + finished_requests_ids is larger than the maximum mamba block capacity. + + This could generally happen due to the fact that hybrid does support + statelessness mechanism where it can cleanup new incoming requests in + a single step. 
+ """ try: - with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: pytest.fail("Hybrid inner state wasn't cleaned up properly between" "steps finished requests registered unnecessarily ") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) def test_state_cleanup( vllm_runner, - model: str, - dtype: str, example_prompts, + model: str, ) -> None: - # This test is for verifying that the Hybrid state is cleaned up between - # steps, If its not cleaned, an error would be expected. + """ + This test is for verifying that the Hybrid state is cleaned up between + steps. + + If its not cleaned, an error would be expected. + """ try: - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: @@ -298,28 +262,14 @@ def test_state_cleanup( "could be related to finished_requests_ids") -@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_multistep( +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_multistep_correctness( vllm_runner, - model: str, - dtype: str, example_prompts, + model: str, + max_tokens: int, ) -> None: - # This test is verifying that multistep works correctly - #on mamba-like models - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 10, 1) - - -@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.") -@pytest.mark.parametrize("model", MODELS) 
-@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -def test_multistep_correctness(vllm_runner, model: str, dtype: str, - max_tokens: int, example_prompts) -> None: with vllm_runner(model, num_scheduler_steps=8, max_num_seqs=2) as vllm_model: vllm_outputs_multistep = vllm_model.generate_greedy( @@ -339,18 +289,21 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str, @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) def test_hybrid_distributed_produces_identical_generation( - vllm_runner, model: str, dtype: str, max_tokens: int, - example_prompts) -> None: - - with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model: + vllm_runner, + example_prompts, + model: str, + max_tokens: int, +) -> None: + with vllm_runner(model, tensor_parallel_size=2, + max_num_seqs=2) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, max_tokens) - with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model: + with vllm_runner(model, tensor_parallel_size=1, + max_num_seqs=2) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index ccf6cc1edf912ab66aa0dea3bb7323205c822513..a692115848b0f77b2620806b6d0d8c9216e9b1b1 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -11,8 +11,8 @@ import jsonschema.exceptions import pytest import os -from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa - MistralToolParser) +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( + MistralToolCall, MistralToolParser) from vllm.sampling_params import 
GuidedDecodingParams, SamplingParams from ...utils import check_logprobs_close @@ -196,7 +196,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, ) -@pytest.mark.skip("RE-ENABLE: test is currently failing on main.") @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @@ -248,10 +247,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, assert "�" not in outputs[0].outputs[0].text.strip() -@pytest.mark.skip("RE-ENABLE: test is currently failing on main.") +@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("model", - MISTRAL_FORMAT_MODELS) # v1 can't do func calling def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: with vllm_runner(model, dtype=dtype, @@ -272,7 +269,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: parsed_message = tool_parser.extract_tool_calls(model_output, None) assert parsed_message.tools_called - assert parsed_message.tool_calls[0].id == "0UAqFzWsD" + + assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id) assert parsed_message.tool_calls[ 0].function.name == "get_current_weather" assert parsed_message.tool_calls[ @@ -283,28 +281,38 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -def test_mistral_guided_decoding(vllm_runner, model: str, - guided_backend: str) -> None: - with vllm_runner(model, dtype='bfloat16', - tokenizer_mode="mistral") as vllm_model: +def test_mistral_guided_decoding( + monkeypatch: pytest.MonkeyPatch, + vllm_runner, + model: str, + guided_backend: str, +) -> None: + with monkeypatch.context() as m: + # Guided JSON not supported in xgrammar + V1 yet + 
m.setenv("VLLM_USE_V1", "0") - guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA, - backend=guided_backend) - params = SamplingParams(max_tokens=512, - temperature=0.7, - guided_decoding=guided_decoding) - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {SAMPLE_JSON_SCHEMA}" - }] - outputs = vllm_model.model.chat(messages, sampling_params=params) + with vllm_runner( + model, + dtype='bfloat16', + tokenizer_mode="mistral", + guided_decoding_backend=guided_backend, + ) as vllm_model: + guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA) + params = SamplingParams(max_tokens=512, + temperature=0.7, + guided_decoding=guided_decoding) + + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {SAMPLE_JSON_SCHEMA}" + }] + outputs = vllm_model.model.chat(messages, sampling_params=params) generated_text = outputs[0].outputs[0].text json_response = json.loads(generated_text) diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index d6f2f2707bf2249d6ab7024a426e5b98579239f1..9a786f113de79ff16f4d85707980a3a8cd3acb08 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -10,6 +10,8 @@ import torch from vllm.platforms import current_platform +from ....utils import large_gpu_mark +from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close from ....utils import models_path_prefix @@ -27,7 +29,7 @@ REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] AITER_MODEL_LIST = [ "meta-llama/Llama-3.2-1B-Instruct", "openbmb/MiniCPM3-4B", - "Qwen/Qwen-7B", + "Qwen/Qwen-7B-Chat", "Qwen/Qwen2.5-0.5B-Instruct", 
"ehristoforu/Falcon3-MoE-2x7B-Insruct", ] @@ -62,7 +64,8 @@ AITER_MODEL_LIST = [ pytest.param( os.path.join(models_path_prefix, "openbmb/MiniCPM3-4B"), # fused_moe not supported on CPU - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, + large_gpu_mark(min_gb=32)], ), pytest.param( os.path.join(models_path_prefix, "facebook/opt-125m"), # opt @@ -73,7 +76,7 @@ AITER_MODEL_LIST = [ marks=[pytest.mark.core_model], ), pytest.param( - os.path.join(models_path_prefix, "Qwen/Qwen-7B"), # qwen (text-only) + os.path.join(models_path_prefix, "Qwen/Qwen-7B-Chat"), # qwen (text-only) ), pytest.param( os.path.join(models_path_prefix, "Qwen/Qwen2.5-0.5B-Instruct"), # qwen2 @@ -83,17 +86,21 @@ AITER_MODEL_LIST = [ pytest.param(os.path.join(models_path_prefix, "bigcode/starcoder2-3b")), # starcoder2 pytest.param( os.path.join(models_path_prefix, "ehristoforu/Falcon3-MoE-2x7B-Insruct"), # mixtral - marks=[pytest.mark.cpu_model], + marks=[pytest.mark.cpu_model, + large_gpu_mark(min_gb=48)], ) ]) -@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize( "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int, - use_rocm_aiter: bool, monkeypatch) -> None: + max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, + monkeypatch) -> None: + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") if model in REQUIRES_V0: monkeypatch.setenv("VLLM_USE_V1", "0") @@ -107,15 +114,17 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, # in parts of the operators pytest.skip(f"Skipping '{model}' model test with AITER kernel.") - with hf_runner(model, dtype=dtype) as hf_model: - if model.startswith("THUDM/chatglm3"): - 
hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.transformer.output_layer - + with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner( + model, + tokenizer_name=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + max_num_seqs=2, + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 91f8a8ec1efd6efdaa9d875bedc26987b657cb76..8f6a62307506489fb2b774620de1cd06e8d18753 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -142,6 +142,23 @@ VLM_TEST_SETTINGS = { image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + "qwen2_5_omni": VLMTestInfo( + models=["Qwen/Qwen2.5-Omni-7B"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>", # noqa: E501 + video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForVision2Seq, + vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, + image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), #### Extended model tests "aria": VLMTestInfo( models=[os.path.join(models_path_prefix, "rhymes-ai/Aria")], @@ -321,6 +338,18 @@ VLM_TEST_SETTINGS = { 
use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "kimi_vl": VLMTestInfo( + models=["moonshotai/Kimi-VL-A3B-Instruct"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>", # noqa: E501 + img_idx_to_prompt=lambda _: "<|media_start|>image<|media_content|><|media_pad|><|media_end|>", # noqa: E501 + max_model_len=8192, + max_num_seqs=2, + dtype="bfloat16", + tensor_parallel_size=1, + vllm_output_post_proc=model_utils.kimiv_vl_vllm_to_hf_output, + marks=[large_gpu_mark(min_gb=48)], + ), "llama4": VLMTestInfo( models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"], prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501 diff --git a/tests/models/decoder_only/vision_language/test_phi4mm.py b/tests/models/decoder_only/vision_language/test_phi4mm.py index 3cd830015076d38f5e4143065b2af4020a9c28ee..11460a1a8d2b53b51b75db4c0475b8e4588a78c5 100644 --- a/tests/models/decoder_only/vision_language/test_phi4mm.py +++ b/tests/models/decoder_only/vision_language/test_phi4mm.py @@ -181,7 +181,7 @@ def run_test( ], ) @pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [4096]) +@pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @@ -225,7 +225,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ], ) @pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [10000]) +@pytest.mark.parametrize("max_model_len", [25600]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def 
test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @@ -258,7 +258,7 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [10000]) +@pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 3520345c9679c5d7139d80069e8a6fbf1841f965..49305332726e4529952ff2abf700cc154b37bc71 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -68,6 +68,17 @@ def qwen2_vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs +def kimiv_vl_vllm_to_hf_output( + vllm_output: RunnerOutput, + model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + """Sanitize vllm output [kimi_vl models] to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|im_end|>[EOS]" + + return output_ids, hf_output_str, out_logprobs + + def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: config = AutoConfig.from_pretrained(model) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index 0bf2e418af05b6288d56bca3ce263a9368c27498..3cbb50e08142be9bff96eba1e7576227c7f8ca34 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -59,24 +59,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch): def server_embedding(): # GritLM embedding implementation is only supported by XFormers backend. 
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with pytest.MonkeyPatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest.fixture(scope="module") def server_generate(): args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with pytest.MonkeyPatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest_asyncio.fixture -async def client_embedding(monkeypatch: pytest.MonkeyPatch, - server_embedding: RemoteOpenAIServer): - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") - async with server_embedding.get_async_client() as async_client: - yield async_client +async def client_embedding(server_embedding: RemoteOpenAIServer): + async with server_embedding.get_async_client() as async_client: + yield async_client @pytest_asyncio.fixture diff --git a/tests/models/embedding/language/test_jina.py b/tests/models/embedding/language/test_jina.py index 881d0a75b15843bf0983640b5fb2279814fbb311..1e234368f3b317a5569bfabef841a4cded4793c1 100644 --- a/tests/models/embedding/language/test_jina.py +++ b/tests/models/embedding/language/test_jina.py @@ -153,14 +153,24 @@ def test_matryoshka( with vllm_runner(model, task="embed", dtype=dtype, max_model_len=None) as vllm_model: - vllm_outputs = vllm_model.encode( - example_prompts, - pooling_params=PoolingParams(dimensions=dimensions)) - - check_embeddings_close( - embeddings_0_lst=hf_outputs, - embeddings_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - tol=1e-2, - ) + matryoshka_dimensions = ( + vllm_model.model.llm_engine.model_config.matryoshka_dimensions) + assert matryoshka_dimensions is not None + + if 
dimensions not in matryoshka_dimensions: + with pytest.raises(ValueError): + vllm_model.encode( + example_prompts, + pooling_params=PoolingParams(dimensions=dimensions)) + else: + vllm_outputs = vllm_model.encode( + example_prompts, + pooling_params=PoolingParams(dimensions=dimensions)) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/embedding/language/test_snowflake_arctic_embed.py b/tests/models/embedding/language/test_snowflake_arctic_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..2b884fceec80ce332d2ae4a451aac62d97c552cf --- /dev/null +++ b/tests/models/embedding/language/test_snowflake_arctic_embed.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Compare the embedding outputs of HF and vLLM models. + +Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`. +""" +import pytest + +from tests.models.embedding.utils import EmbedModelInfo + +from ..utils import check_embeddings_close + +EMBEDDING_PROMPTS = [ + 'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!', + 'Mexico City of Course!' 
+] + +MODELS = [ + EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", + is_matryoshka=False, + architecture="BertModel", + enable_test=True), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-s", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", + is_matryoshka=False, + architecture="NomicBertModel", + enable_test=True), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + architecture="BertModel", + enable_test=True), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", + is_matryoshka=True, + architecture="XLMRobertaModel", + enable_test=True), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", + is_matryoshka=True, + architecture="GteModel", + enable_test=True), +] + + +@pytest.mark.parametrize("model_info", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model_info: EmbedModelInfo, + dtype: str, + monkeypatch, +) -> None: + if not model_info.enable_test: + # A model family has many models with the same architecture, + # and we don't need to test each one. 
+ pytest.skip("Skipping test.") + + example_prompts = example_prompts + EMBEDDING_PROMPTS + + vllm_extra_kwargs = { + "hf_overrides": { + "is_matryoshka": model_info.is_matryoshka + } + } + + with hf_runner(model_info.name, dtype=dtype, + is_sentence_transformer=True) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model_info.name, + task="embed", + dtype=dtype, + max_model_len=None, + **vllm_extra_kwargs) as vllm_model: + + assert (vllm_model.model.llm_engine.model_config.is_matryoshka == + model_info.is_matryoshka) + + if model_info.architecture: + assert (model_info.architecture + in vllm_model.model.llm_engine.model_config.architectures) + + vllm_outputs = vllm_model.encode(example_prompts) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py index 5aeeb51785402a3df05e59a0c62c7d039cae562c..6d4df2c265c4d7f168f1c9467a698ab090a653fc 100644 --- a/tests/models/embedding/utils.py +++ b/tests/models/embedding/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Sequence +from typing import NamedTuple, Optional import torch import torch.nn.functional as F @@ -37,3 +38,29 @@ def matryoshka_fy(tensor, dimensions): tensor = tensor[..., :dimensions] tensor = F.normalize(tensor, p=2, dim=1) return tensor + + +class EmbedModelInfo(NamedTuple): + name: str + is_matryoshka: bool + matryoshka_dimensions: Optional[list[int]] = None + architecture: str = "" + enable_test: bool = True + + +def correctness_test(hf_model, + inputs, + vllm_outputs: Sequence[list[float]], + dimensions: Optional[int] = None): + + hf_outputs = hf_model.encode(inputs) + if dimensions: + hf_outputs = matryoshka_fy(hf_outputs, dimensions) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, 
+ ) diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index bd836b903dde23eb23b8f411ac9a668d658ece9c..e73ad10ce91042293d1d9fbc6944d48e2ca49b90 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -15,12 +15,12 @@ from ...utils import check_logprobs_close from ....utils import models_path_prefix MODELS = [os.path.join(models_path_prefix, "microsoft/Florence-2-base")] -# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer -# Therefore, we borrow the BartTokenizer from the original Bart model -TOKENIZER = os.path.join(models_path_prefix, "facebook/bart-base") +# Florence-2 model repo's tokenizer config is missing some special tokens. +# Therefore, we use a converted tokenizer from a forked repo +TOKENIZER = os.path.join(models_path_prefix, "Isotr0py/Florence-2-tokenizer") HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": - "", # special task token + "", # special task token which will output special tokens "cherry_blossom": "Describe in detail what is shown in the image.", }) @@ -47,7 +47,6 @@ def hf_to_vllm_output(hf_output: tuple[list[int], str, output_ids, output_str, out_logprobs = hf_output output_str = output_str.replace("", "").replace("", "") - output_ids = [ids for ids in output_ids if ids not in [0, 2]] return output_ids, output_str, out_logprobs @@ -73,8 +72,11 @@ def run_test( enforce_eager=True) as vllm_model: vllm_outputs_per_case = [ vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs=num_logprobs) - for prompts in inputs + prompts, + max_tokens, + num_logprobs=num_logprobs, + skip_special_tokens=False, + ) for prompts in inputs ] hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs] @@ -95,6 +97,7 @@ def run_test( outputs_1_lst=vllm_outputs, name_0="hf", name_1="vllm", + 
num_outputs_0_skip_tokens=1, ) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 8ec7d0887bd4ca56bf6e03746882e1a331f817a7..5a4215a70d2492bf6940619355bef083c56f50b7 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -254,10 +254,12 @@ def _test_processing_correctness_mistral( "adept/fuyu-8b", "google/gemma-3-4b-it", "THUDM/glm-4v-9b", + "ibm-granite/granite-speech-3.3-8b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", "HuggingFaceM4/Idefics3-8B-Llama3", "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "moonshotai/Kimi-VL-A3B-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct", "llava-hf/llava-1.5-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf", @@ -273,12 +275,14 @@ def _test_processing_correctness_mistral( "nvidia/NVLM-D-72B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", + "microsoft/Phi-4-multimodal-instruct", "mistralai/Pixtral-12B-2409", "mistral-community/pixtral-12b", "Qwen/Qwen-VL-Chat", "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", + "Qwen/Qwen2.5-Omni-7B", "Skywork/Skywork-R1V-38B", "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", diff --git a/tests/models/multimodal/processing/test_phi4mm.py b/tests/models/multimodal/processing/test_phi4mm.py new file mode 100644 index 0000000000000000000000000000000000000000..797986adba4afc7d07b12b29ce61a970ea250220 --- /dev/null +++ b/tests/models/multimodal/processing/test_phi4mm.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for phi4mm's multimodal preprocessing kwargs.""" +import pytest + +from vllm.multimodal import MULTIMODAL_REGISTRY + +from ....conftest import _ImageAssets +from ...utils import build_model_context + + +@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"]) +# yapf: disable +@pytest.mark.parametrize( + ("mm_processor_kwargs", 
"expected_toks_per_img"), + [ + ({"dynamic_hd": 4}, 1329), + ({"dynamic_hd": 16}, 4433), + # the default num_crops of phi-4-multimodal is 36 + ({}, 9585), + ]) +# yapf: enable +@pytest.mark.parametrize("num_imgs", [1, 2]) +@pytest.mark.parametrize("kwargs_on_init", [True, False]) +def test_processor_override( + image_assets: _ImageAssets, + model_id: str, + mm_processor_kwargs: dict[str, int], + expected_toks_per_img: int, + num_imgs: int, + kwargs_on_init: bool, +): + """Ensure Phi4MMMultiModalProcessor handles dynamic_hd properly.""" + # Avoid initializing CUDA early + from vllm.model_executor.models.phi4mm import _IMAGE_PLACEHOLDER_TOKEN_ID + + ctx = build_model_context( + model_id, + mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None, + limit_mm_per_prompt={"image": num_imgs}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs + + # Build the image str / prompt based on the number of images we pass + img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) + prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" + + image_size = ctx.get_hf_config( + ).embd_layer["image_embd_layer"]["crop_size"] + dummy_image_size = (image_size * 7, image_size * 7) + dummy_image = image_assets[0].pil_image.resize(dummy_image_size) + mm_data = {"image": [dummy_image] * num_imgs} + + processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = processed_inputs["prompt_token_ids"].count( + _IMAGE_PLACEHOLDER_TOKEN_ID) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/registry.py b/tests/models/registry.py index bab20a62076ede9f33e7ca0a3b8dcaada52105d3..20581243ef268f6a52b36e77663ddf92ca785996 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -121,9 +121,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = 
{ "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"), - "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"), + "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", + {"1b": "bigscience/bloomz-1b1"}), "ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b", - trust_remote_code=True), + trust_remote_code=True, + max_transformers_version="4.48"), "ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501 trust_remote_code=True), "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", @@ -141,24 +143,26 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), - "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), + "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), - "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it", - min_transformers_version="4.50"), + "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "Glm4ForCausalLM": _HfExamplesInfo( "THUDM/GLM-4-32B-0414", is_available_online=False, min_transformers_version="4.52.dev0" ), - "GPT2LMHeadModel": _HfExamplesInfo("gpt2"), - "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), - "GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"), - "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"), + "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", + {"alias": "gpt2"}), + "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder", + {"tiny": "bigcode/tiny_starcoder_py"}), # noqa: E501 + "GPTJForCausalLM": 
_HfExamplesInfo("Milos/slovak-gpt-j-405M", + {"6b": "EleutherAI/gpt-j-6b"}), + "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m", + {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts", # noqa: E501 - min_transformers_version="4.49"), # noqa: E501 + "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", @@ -186,7 +190,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", trust_remote_code=True), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), - "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501 + "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501 + {"falcon3": "ehristoforu/Falcon3-MoE-2x7B-Insruct"}), # noqa: E501 "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), @@ -194,7 +199,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("shanearora/OLMo-7B-1124-hf"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), - "OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"), + "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m", + {"1b": "facebook/opt-iml-max-1.3b"}), "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", trust_remote_code=True), "PersimmonForCausalLM": 
_HfExamplesInfo("adept/persimmon-8b-chat"), @@ -204,10 +210,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), + "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", + trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", trust_remote_code=True), - "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct", - extras={"2.5": "Qwen/Qwen2.5-7B-Instruct"}), # noqa: E501 + "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct", + extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501 "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo( "Qwen/Qwen3-8B", @@ -233,8 +241,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", is_available_online=False, trust_remote_code=True), - "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct", - min_transformers_version="4.49"), + "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), @@ -245,11 +252,15 @@ _EMBEDDING_EXAMPLE_MODELS = { "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), + "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", + trust_remote_code=True), "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward", trust_remote_code=True), "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), + "NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", 
# noqa: E501 + trust_remote_code=True), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), @@ -273,6 +284,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = { "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 + "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501 } _MULTIMODAL_EXAMPLE_MODELS = { @@ -286,10 +298,11 @@ _MULTIMODAL_EXAMPLE_MODELS = { extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501 - hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 + hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), - "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it", - min_transformers_version="4.50"), + "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), + "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b", # noqa: E501 + min_transformers_version="4.52.0"), # noqa: E501 "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 @@ -302,6 +315,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { trust_remote_code=True), "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 + "KimiVLForConditionalGeneration": 
_HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 + extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 + trust_remote_code=True), "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 min_transformers_version="4.51"), "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", @@ -322,7 +338,6 @@ _MULTIMODAL_EXAMPLE_MODELS = { extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501 trust_remote_code=True), "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501 - min_transformers_version="4.50", # noqa: E501 extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501 "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", max_transformers_version="4.48", @@ -348,8 +363,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501 "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 - "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 - min_transformers_version="4.49"), # noqa: E501 + "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501 + "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B", # noqa: E501 + min_transformers_version="4.52"), # noqa: E501 "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 @@ -358,7 +374,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the 
original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 - tokenizer="facebook/bart-base", + tokenizer="Isotr0py/Florence-2-tokenizer", trust_remote_code=True), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 @@ -379,6 +395,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 + "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 + trust_remote_code=True, + speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + tokenizer="meta-llama/Llama-3.1-8B-Instruct"), } _TRANSFORMERS_MODELS = { diff --git a/tests/models/test_bitblas.py b/tests/models/test_bitblas.py new file mode 100644 index 0000000000000000000000000000000000000000..ae4a52214ad0c7e7f52b57da55e7479256d2ee05 --- /dev/null +++ b/tests/models/test_bitblas.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Compare the outputs of a GPTQ model to a bitblas model. + +Note: GPTQ and bitblas do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +bitblas/GPTQ models are in the top 3 selections of each other. + +Note: bitblas internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for bitblas. As a result, we re-run the +test up to 3 times to see if we pass. + +Run `pytest tests/models/test_bitblas.py`. 
+""" +from dataclasses import dataclass + +import pytest + +from .utils import check_logprobs_close + + +@dataclass +class ModelPair: + model_bitblas: str + model_gptq: str + + +model_pairs = [ + ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas", + model_gptq="hxbgsyxh/opt-125m-4bit-128g"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner(model_pair.model_bitblas, + dtype=dtype, + quantization="bitblas") as bitblas_model: + bitblas_outputs = bitblas_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=bitblas_outputs, + name_0="gptq", + name_1="bitblas", + ) diff --git a/tests/models/test_gptq_bitblas.py b/tests/models/test_gptq_bitblas.py new file mode 100644 index 0000000000000000000000000000000000000000..d28442120ea6931737f6752ef84a782eabb7c166 --- /dev/null +++ b/tests/models/test_gptq_bitblas.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Compare the outputs of a GPTQ model to a bitblas model. + +Note: GPTQ and bitblas do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +bitblas/GPTQ models are in the top 3 selections of each other. + +Note: bitblas internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for bitblas. 
As a result, we re-run the +test up to 3 times to see if we pass. + +Run `pytest tests/models/test_bitblas.py`. +""" +from dataclasses import dataclass + +import pytest + +from .utils import check_logprobs_close + + +@dataclass +class ModelPair: + model_gptq: str + + +model_pairs = [ + ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="bitblas") as bitblas_model: + bitblas_outputs = bitblas_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=bitblas_outputs, + name_0="gptq", + name_1="gptq_bitblas", + ) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index cd2b8f00d521b6870f704fe0215dcfe638c55e6f..446c4efbf6af0176c944815bb0589fc5940e05f5 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -24,10 +24,7 @@ def test_can_initialize(model_arch): def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update(model_info.hf_overrides) - if hasattr(hf_config, "text_config"): - text_config: PretrainedConfig = hf_config.text_config - else: - text_config = hf_config + text_config = hf_config.get_text_config() text_config.update({ "num_layers": 1, diff --git 
a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index f1ed8a04cfa08ae4bd0d62e155e8a3b91f237d51..b45a87d94b8687defc5ac1a71ec600969a5bc894 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -18,10 +18,9 @@ def test_plugin( m.setenv("VLLM_USE_V1", "0") m.setenv("VLLM_PLUGINS", "") - with pytest.raises(Exception) as excinfo: + match = "Cannot find model module" + with pytest.raises(ValueError, match=match): LLM(model=dummy_opt_path, load_format="dummy") - error_msg = "has no vLLM implementation and the Transformers implementation is not compatible with vLLM" # noqa: E501 - assert (error_msg in str(excinfo.value)) @create_new_process_for_each_test() diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index d026e0f4d489a84100920ad547505a09cb2ba56d..aa20bdcaac63957b60954f5c75db70e04b80621a 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -266,16 +266,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token( reason="WNA16 is not supported on ROCm.") @pytest.mark.parametrize( "wNa16_args", - [ - (os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-channel-v2"), "channel", None, 8), - (os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-group128-v2"), "group", 128, 8), - (os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w8a16-per-channel"), "channel", None, 4), - ], + [(os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w4a16-channel-v2"), "channel", None, 8, + True, False), + (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w4a16-group128-v2"), "group", 128, 8, True, + False), + (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8a16-per-channel"), "channel", None, 4, + True, False), + (os.path.join(models_path_prefix, "nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256"), "group", 128, + 
8, False, False), + (os.path.join(models_path_prefix, "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel"), + "channel", None, 8, False, False), + (os.path.join(models_path_prefix, "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder"), + "group", 128, 8, False, True)], ) @pytest.mark.skipif(not current_platform.is_cuda(), reason="The tests are skipped on non-CUDA platform.") def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): - model, strategy, group, pack_factor = wNa16_args + model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args with vllm_runner(model) as llm: def check_model(model): @@ -291,6 +298,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): if group is None else group) assert qkv_proj.scheme.pack_factor == pack_factor + assert qkv_proj.scheme.symmetric == symmetric + assert qkv_proj.scheme.has_g_idx == has_g_idx llm.apply_model(check_model) diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index fffe63a8c9e0525ca2b4471de2ca301d4fd0395f..a4de8a1c70b0cb92def2e9d2d65b8d7726a5ffae 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -7,6 +7,9 @@ Run `pytest tests/samplers/test_beam_search.py`. import pytest import os from ..utils import models_path_prefix +from transformers import AutoModelForSeq2SeqLM + +from vllm.assets.audio import AudioAsset @pytest.fixture(autouse=True) @@ -21,7 +24,8 @@ def v1(run_with_both_engines): # 3. Use the model "huggyllama/llama-7b". MAX_TOKENS = [64] BEAM_WIDTHS = [4] -MODELS = [os.path.join(models_path_prefix, "TinyLlama/TinyLlama-1.1B-Chat-v1.0")] +MM_BEAM_WIDTHS = [2] +MODELS = [os.path.join(models_path_prefix, "TinyLlama/TinyLlama-1.1B-Chat-v1.0")] @pytest.mark.skip_v1 # FIXME: This fails on V1 right now. 
@@ -50,15 +54,90 @@ def test_beam_search_single_input( for i in range(len(example_prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] vllm_output_ids, vllm_output_texts = vllm_outputs[i] - for i, (hf_text, + for j, (hf_text, vllm_text) in enumerate(zip(hf_output_texts, vllm_output_texts)): - print(f">>>{i}-th hf output:") + print(f">>>{j}-th hf output:") print(hf_text) - print(f">>>{i}-th vllm output:") + print(f">>>{j}-th vllm output:") print(vllm_text) assert len(hf_output_ids) == len(vllm_output_ids) for j in range(len(hf_output_ids)): assert hf_output_ids[j] == vllm_output_ids[j], ( f"Test{i} output{j}:\nHF: {hf_output_ids}\n" f"vLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", MAX_TOKENS) +@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS) +def test_beam_search_passes_multimodal_data( + hf_runner, + vllm_runner, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Ensure that beam search passes multimodal data through correctly.""" + # NOTE - this test is primarily to check that mm data is passed to beams + # correctly. As such, we just need to check one extra modality to make + # sure things pass through properly. 
+ audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate] + model = "Qwen/Qwen2-Audio-7B-Instruct" + audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>" + prompts = [ + f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" #noqa: E501 + ] + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + audio_token_id = hf_model.config.audio_token_index + eos_token_id = hf_model.tokenizer.eos_token_id # <|im_end|> + hf_outputs = hf_model.generate_beam_search( + prompts, + beam_width=beam_width, + max_tokens=max_tokens, + audios=audios, + ) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_beam_search( + prompts, + beam_width=beam_width, + max_tokens=max_tokens, + audios=audios, + ) + + seq_with_no_audio_toks = lambda seq: [ + tok for tok in seq if tok != audio_token_id + ] + + for i in range(len(prompts)): + hf_output_ids, hf_output_texts = hf_outputs[i] + vllm_output_ids, vllm_output_texts = vllm_outputs[i] + + for j, (hf_text, + vllm_text) in enumerate(zip(hf_output_texts, + vllm_output_texts)): + print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:") + print(hf_text) + print(f">>>{j}-th vllm output:") + print(vllm_text) + assert len(hf_output_ids) == len(vllm_output_ids) + + for j in range(len(hf_output_ids)): + # Compare everything except for the audio tokens; we do this since + # the IDs returned from the transformers helper expands the audio + # token to match features, while the vLLM helper maintains the + # single audio token in the input text + filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j]) + filtered_vllm_output_ids = seq_with_no_audio_toks( + vllm_output_ids[j]) + + # HF output IDs may contain the end of sequence + if len(filtered_hf_output_ids + ) == len(filtered_vllm_output_ids) + 1: + assert filtered_hf_output_ids[-1] == eos_token_id + filtered_hf_output_ids = filtered_hf_output_ids[:-1] + + assert 
filtered_hf_output_ids == filtered_vllm_output_ids diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py index 986b2b71846d9ef8e41f1442f8347c15d403342b..3f0c10a4df598a900b50c560aea2c54a19e8f5d7 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -64,9 +64,8 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int, scorer_worker = create_worker(Worker, model_name, block_size, num_gpu_blocks, seed) scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer - scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True - scorer_worker.model_runner.model.sampler.\ - should_modify_greedy_probs_inplace = True + scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True + scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True vocab_size = scorer_worker.vocab_size diff --git a/tests/test_config.py b/tests/test_config.py index b2ae3f065db5b33b4b31bc524801903b6d3f6d00..df3d720f1e70407e0c88cd5869013d47f369e90e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,16 +1,38 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import asdict +from dataclasses import MISSING, Field, asdict, dataclass, field import pytest import os -from vllm.config import ModelConfig, PoolerConfig +from vllm.config import ModelConfig, PoolerConfig, get_field from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform from utils import models_path_prefix +def test_get_field(): + + @dataclass + class TestConfig: + a: int + b: dict = field(default_factory=dict) + c: str = "default" + + with pytest.raises(ValueError): + get_field(TestConfig, "a") + + b = get_field(TestConfig, "b") + assert isinstance(b, Field) + assert b.default is MISSING + assert b.default_factory is dict + + c = get_field(TestConfig, "c") + assert isinstance(c, Field) + assert c.default == "default" + assert c.default_factory is 
MISSING + + @pytest.mark.parametrize( ("model_id", "expected_runner_type", "expected_task"), [ diff --git a/tests/test_utils.py b/tests/test_utils.py index 05d2a3d4ed3959760591e7aab9e2fe189f2ad0e3..3a5ad46c6333dc1500d8a7c2ae03b305c4e050a5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,11 +13,11 @@ import torch from vllm_test_utils.monitor import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.utils import (FlexibleArgumentParser, MemorySnapshot, - PlaceholderModule, StoreBoolean, bind_kv_cache, - deprecate_kwargs, get_open_port, memory_profiling, - merge_async_iterators, sha256, supports_kw, - swap_dict_values) +from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, + MemorySnapshot, PlaceholderModule, StoreBoolean, + bind_kv_cache, deprecate_kwargs, get_open_port, + memory_profiling, merge_async_iterators, sha256, + supports_kw, swap_dict_values) from .utils import create_new_process_for_each_test, error_on_warning from .utils import models_path_prefix @@ -419,6 +419,129 @@ def test_bind_kv_cache_pp(): assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0] +class TestLRUCache(LRUCache): + + def _on_remove(self, key, value): + if not hasattr(self, "_remove_counter"): + self._remove_counter = 0 + self._remove_counter += 1 + + +def test_lru_cache(): + cache = TestLRUCache(3) + assert cache.stat() == CacheInfo(hits=0, total=0) + assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(2, 2) + assert len(cache) == 2 + + cache.put(3, 3) + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache.put(4, 4) + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + + assert cache.get(2) == 2 + assert cache.stat() == CacheInfo(hits=1, total=1) + assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) + + assert cache[2] 
== 2 + assert cache.stat() == CacheInfo(hits=2, total=2) + assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) + + cache.put(5, 5) + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + assert cache.pop(5) == 5 + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + assert cache.get(-1) is None + assert cache.stat() == CacheInfo(hits=2, total=3) + assert cache.stat(delta=True) == CacheInfo(hits=0, total=1) + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.get(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.put(6, 6) + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache + + cache.remove_oldest() + assert len(cache) == 2 + assert set(cache.cache) == {2, 6} + assert cache._remove_counter == 4 + + cache.clear() + assert len(cache) == 0 + assert cache._remove_counter == 6 + assert cache.stat() == CacheInfo(hits=0, total=0) + assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) + + cache._remove_counter = 0 + + cache[1] = 1 + assert len(cache) == 1 + + cache[1] = 1 + assert len(cache) == 1 + + cache[2] = 2 + assert len(cache) == 2 + + cache[3] = 3 + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache[4] = 4 + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache[2] == 2 + + cache[5] = 5 + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + del cache[5] + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache[6] = 6 + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + 
assert 6 in cache + + def test_placeholder_module_error_handling(): placeholder = PlaceholderModule("placeholder_1234") diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py index 576a8fa482b124d346823c3c7347d5874ee20e19..38933011ec692f5c6bab99c58587e1954b331dc7 100644 --- a/tests/tokenization/test_cached_tokenizer.py +++ b/tests/tokenization/test_cached_tokenizer.py @@ -1,26 +1,46 @@ # SPDX-License-Identifier: Apache-2.0 - +import pickle from copy import deepcopy import os +import pytest from transformers import AutoTokenizer - -from vllm.transformers_utils.tokenizer import get_cached_tokenizer from ..utils import models_path_prefix -def test_cached_tokenizer(): - reference_tokenizer = AutoTokenizer.from_pretrained(os.path.join(models_path_prefix, "gpt2")) +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + get_cached_tokenizer) + + +@pytest.mark.parametrize("model_id", [os.path.join(models_path_prefix, "gpt2"), os.path.join(models_path_prefix, "THUDM/chatglm3-6b")]) +def test_cached_tokenizer(model_id: str): + reference_tokenizer = AutoTokenizer.from_pretrained(model_id, + trust_remote_code=True) reference_tokenizer.add_special_tokens({"cls_token": ""}) reference_tokenizer.add_special_tokens( {"additional_special_tokens": [""]}) + cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer)) + _check_consistency(cached_tokenizer, reference_tokenizer) + + pickled_tokenizer = pickle.dumps(cached_tokenizer) + unpickled_tokenizer = pickle.loads(pickled_tokenizer) + _check_consistency(unpickled_tokenizer, reference_tokenizer) + + +def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer): + assert isinstance(target, type(expected)) + + # Cached attributes + assert target.all_special_ids == expected.all_special_ids + assert target.all_special_tokens == expected.all_special_tokens + assert (target.all_special_tokens_extended == + expected.all_special_tokens_extended) + assert 
target.get_vocab() == expected.get_vocab() + assert len(target) == len(expected) + + # Other attributes + assert getattr(target, "padding_side", + None) == getattr(expected, "padding_side", None) - assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( - "prompt") - assert set(reference_tokenizer.all_special_ids) == set( - cached_tokenizer.all_special_ids) - assert set(reference_tokenizer.all_special_tokens) == set( - cached_tokenizer.all_special_tokens) - assert set(reference_tokenizer.all_special_tokens_extended) == set( - cached_tokenizer.all_special_tokens_extended) + assert target.encode("prompt") == expected.encode("prompt") diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 9addeca6767223ab48832080e18820455645e6c6..873ef9ae03e9834d944d2ad971f805fed7107417 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -3,19 +3,26 @@ from collections.abc import Generator from typing import Any, Optional -import pytest import os -from transformers import AutoTokenizer +import pytest +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) from vllm.inputs import token_inputs from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.detokenizer import (Detokenizer, - detokenize_incrementally) -from vllm.transformers_utils.tokenizer_group import get_tokenizer_group - +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, + IncrementalDetokenizer, + SlowIncrementalDetokenizer) from ..utils import models_path_prefix +SPECIAL_TOKS_TRUTH = [ + "Some text with adjacent special tokens <|padding|><|padding|>other text", # noqa +] + TRUTH = 
[ "Hello here, this is a simple test", "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa @@ -25,7 +32,8 @@ TRUTH = [ # incomplete UTF-8 characters # see https://github.com/vllm-project/vllm/pull/9625 "ပုံပြင်လေးပြောပြပါ်", -] +] + SPECIAL_TOKS_TRUTH + TOKENIZERS = [ os.path.join(models_path_prefix, "facebook/opt-125m"), os.path.join(models_path_prefix, "gpt2"), @@ -41,26 +49,37 @@ TOKENIZERS = [ ] -def _run_incremental_decode(tokenizer, all_input_ids, - skip_special_tokens: bool, starting_index: int): - decoded_text = "" - offset = 0 - token_offset = 0 - prev_tokens = None - for i in range(starting_index, len(all_input_ids)): - new_tokens, text, offset, token_offset = detokenize_incrementally( - tokenizer, - all_input_ids[:i + 1], - prev_tokens, - offset, - token_offset, - skip_special_tokens=skip_special_tokens) - decoded_text += text - if prev_tokens is None: - prev_tokens = new_tokens - else: - prev_tokens += new_tokens - return decoded_text +def _run_incremental_decode(tokenizer, + all_input_ids, + skip_special_tokens: bool, + starting_index: int, + spaces_between_special_tokens: bool = True, + fast: Optional[bool] = None): + + prompt_token_ids = all_input_ids[:starting_index] + + params = SamplingParams( + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + request = EngineCoreRequest("", prompt_token_ids, None, None, None, params, + None, 0.0, None) + + if fast is None: + detokenizer = IncrementalDetokenizer.from_new_request( + tokenizer, request) + elif fast: + detokenizer = FastIncrementalDetokenizer(tokenizer, request) + else: + detokenizer = SlowIncrementalDetokenizer(tokenizer, request) + + output_text = "" + for i, token_id in enumerate(all_input_ids[starting_index:]): + detokenizer.update([token_id], False) + finished = i == len(all_input_ids) - 1 + output_text += 
detokenizer.get_next_output_text(finished, delta=True) + + return output_text, detokenizer.output_token_ids @pytest.fixture @@ -88,11 +107,13 @@ def test_mistral_edge_case(tokenizer, truth): starting_index = 0 all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids - decoded_text = _run_incremental_decode(tokenizer, - all_input_ids, - skip_special_tokens=True, - starting_index=starting_index) + decoded_text, out_ids = _run_incremental_decode( + tokenizer, + all_input_ids, + skip_special_tokens=True, + starting_index=starting_index) assert decoded_text == truth + assert out_ids == all_input_ids[starting_index:] @pytest.fixture @@ -109,45 +130,91 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: @pytest.mark.parametrize("with_prompt", [True, False]) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True) -def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens): +@pytest.mark.parametrize("spaces_between_special_tokens", (True, False)) +@pytest.mark.parametrize("fast", (True, False)) +def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, + spaces_between_special_tokens, fast): + if fast and not isinstance(tokenizer, PreTrainedTokenizerFast): + pytest.skip() + + if skip_special_tokens and not spaces_between_special_tokens: + pytest.skip() + + if not fast and isinstance(tokenizer, PreTrainedTokenizerFast): + # Fix up inconsistency in fast/slow tokenizer behaviour. 
+ tokenizer.add_special_tokens({ + "additional_special_tokens": [ + at for at in + tokenizer._tokenizer.get_added_tokens_decoder().values() + if at.special + ] + }) + + extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \ + else {"spaces_between_special_tokens": spaces_between_special_tokens} + + truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids + if tokenizer.bos_token_id is not None: + truth_tokens.insert(0, tokenizer.bos_token_id) + truth_tokens.append(tokenizer.eos_token_id) + + new_truth = tokenizer.decode(truth_tokens, + skip_special_tokens=skip_special_tokens, + **extra_decode_args) + if with_prompt: - truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids - prompt_input_ids = truth_tokens[:len(truth) // 2] - generated_input_ids = truth_tokens[len(truth) // 2:] + num_prompt_tokens = len( + tokenizer(truth[:len(truth) // 2], + add_special_tokens=False).input_ids) + if tokenizer.bos_token_id is not None: + num_prompt_tokens += 1 + + prompt_input_ids = truth_tokens[:num_prompt_tokens] + generated_input_ids = truth_tokens[num_prompt_tokens:] all_input_ids = prompt_input_ids + generated_input_ids starting_index = len(prompt_input_ids) prompt = tokenizer.decode(prompt_input_ids, - skip_special_tokens=skip_special_tokens) - generated = truth[len(prompt):] + skip_special_tokens=skip_special_tokens, + **extra_decode_args) + + generated = new_truth[len(prompt):] else: - generated = truth + generated = new_truth starting_index = 0 - all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids - if skip_special_tokens: - if tokenizer.bos_token_id is not None: - all_input_ids = [tokenizer.bos_token_id] + all_input_ids - starting_index += 1 - all_input_ids = all_input_ids + [tokenizer.eos_token_id] + all_input_ids = truth_tokens - decoded_text = _run_incremental_decode( + decoded_text, out_ids = _run_incremental_decode( tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens, - 
starting_index=starting_index) + starting_index=starting_index, + spaces_between_special_tokens=spaces_between_special_tokens, + fast=fast) assert decoded_text == generated + assert out_ids == all_input_ids[starting_index:] + + +@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) +@pytest.mark.parametrize("fast", (True, False)) +def test_oov_decode(tokenizer, fast): + if fast and not isinstance(tokenizer, PreTrainedTokenizerFast): + pytest.skip() - decoded_text = _run_incremental_decode( + decoded_text, out_ids = _run_incremental_decode( tokenizer, [len(tokenizer)], - skip_special_tokens=skip_special_tokens, - starting_index=starting_index) + skip_special_tokens=True, + starting_index=0, + spaces_between_special_tokens=True, + fast=fast) assert decoded_text == '' + assert out_ids == [len(tokenizer)] @pytest.fixture def detokenizer(tokenizer_name: str) -> Detokenizer: - init_kwargs = dict( + tokenizer_group = TokenizerGroup( tokenizer_id=tokenizer_name, enable_lora=False, max_num_seqs=100, @@ -157,26 +224,20 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: revision=None, ) - tokenizer_group = get_tokenizer_group( - None, - **init_kwargs, - ) - return Detokenizer(tokenizer_group) @pytest.fixture(name="complete_sequence_token_ids") def create_complete_sequence_token_ids(complete_sequence: str, tokenizer) -> list[int]: - complete_sequence_token_ids = tokenizer(complete_sequence).input_ids - return complete_sequence_token_ids + return tokenizer(complete_sequence, add_special_tokens=False).input_ids def create_sequence(prompt_token_ids=None): - prompt_token_ids = prompt_token_ids or [1] + prompt_token_ids = prompt_token_ids or [] return Sequence( seq_id=0, - inputs=token_inputs(prompt_token_ids, prompt=""), + inputs=token_inputs(prompt_token_ids), block_size=16, ) @@ -227,7 +288,7 @@ def test_decode_sequence_logprobs(complete_sequence: str, assert sequential_result == "".join(sequential_logprobs_text_chosen_token) assert sequential_result != 
"".join(sequential_logprobs_text_other_token) - if skip_special_tokens: + if not skip_special_tokens: # Text for logprobs for the chosen token should be the same as the # generated text. Note that this will only be true if we skip # special tokens. @@ -236,10 +297,23 @@ def test_decode_sequence_logprobs(complete_sequence: str, @pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int], +def test_decode_prompt_logprobs(complete_sequence: str, + complete_sequence_token_ids: list[int], detokenizer: Detokenizer): + + # We want to use skip_special_tokens=False here but Mistral tokenizers + # don't support that. + if complete_sequence not in SPECIAL_TOKS_TRUTH: + skip_special_tokens = True + elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None), + MistralTokenizer): + skip_special_tokens = False + else: + pytest.skip("MistralTokenizers don't support " + "skip_special_tokens=False") + return """Verify Detokenizer decodes prompt logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=True, + sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, prompt_logprobs=1) # Run sequentially. @@ -259,8 +333,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int], # decoded_prompt_logprobs doesn't contain the first token. 
token_ids = complete_sequence_token_ids tokenizer = detokenizer.get_tokenizer_for_seq(seq) - text_full = tokenizer.decode(token_ids, skip_special_tokens=True) - text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True) + text_full = tokenizer.decode(token_ids, + skip_special_tokens=skip_special_tokens) + text_first = tokenizer.decode(token_ids[0], + skip_special_tokens=skip_special_tokens) text = text_full[len(text_first):] # Text for logprobs for the chosen token should be the same as the diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py index 6717795a1387d63141502848c353d738e63b8b7c..3e3db33e22c9d0cbba1c7b39d6f9389234ec82ae 100644 --- a/tests/tokenization/test_tokenizer_group.py +++ b/tests/tokenization/test_tokenizer_group.py @@ -1,41 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 -import asyncio import os -import sys -from typing import Optional -from unittest.mock import patch - import pytest from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from vllm.transformers_utils.tokenizer_group import (TokenizerGroup, - get_tokenizer_group) -from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( - RayTokenizerGroupPool) - -from ..conftest import get_tokenizer_pool_config from ..utils import models_path_prefix - -class CustomTokenizerGroup(TokenizerGroup): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._i = 0 - - def encode(self, *args, **kwargs): - self._i += 1 - return super().encode(*args, **kwargs) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup @pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", - [None, "ray", CustomTokenizerGroup]) -async def test_tokenizer_group(tokenizer_group_type): +async def test_tokenizer_group(): reference_tokenizer = AutoTokenizer.from_pretrained(os.path.join(models_path_prefix, "gpt2")) - tokenizer_group = get_tokenizer_group( - 
get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_group = TokenizerGroup( tokenizer_id=os.path.join(models_path_prefix, "gpt2"), enable_lora=False, max_num_seqs=1, @@ -50,159 +26,3 @@ async def test_tokenizer_group(tokenizer_group_type): PreTrainedTokenizerBase) assert tokenizer_group.get_lora_tokenizer( None) == await tokenizer_group.get_lora_tokenizer_async(None) - if tokenizer_group_type is CustomTokenizerGroup: - assert tokenizer_group._i > 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) -async def test_tokenizer_group_pool(tokenizer_group_type): - reference_tokenizer = AutoTokenizer.from_pretrained(os.path.join(models_path_prefix, "gpt2")) - tokenizer_group_pool = get_tokenizer_group( - get_tokenizer_pool_config(tokenizer_group_type), - tokenizer_id=os.path.join(models_path_prefix, "gpt2"), - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - ) - # Send multiple requests to the tokenizer group pool - # (more than the pool size) - # and check that all requests are processed correctly. 
- num_requests = tokenizer_group_pool.pool_size * 5 - requests = [ - tokenizer_group_pool.encode_async(prompt=f"prompt {i}", - lora_request=None) - for i in range(num_requests) - ] - results = await asyncio.gather(*requests) - expected_results = [ - reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests) - ] - assert results == expected_results - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) -async def test_tokenizer_group_ray_pool_env_var_propagation( - tokenizer_group_type): - """Test that env vars from caller process are propagated to - tokenizer Ray actors.""" - env_var = "MY_ENV_VAR" - - class EnvVarCheckerTokenizerGroup(TokenizerGroup): - - def ping(self): - assert os.environ.get(env_var) == "1" - return super().ping() - - class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool): - _worker_cls = EnvVarCheckerTokenizerGroup - - tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) - tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id=os.path.join(models_path_prefix, "gpt2"), - enable_lora=False, - max_num_seqs=1, - max_input_length=None) - with pytest.raises(AssertionError): - tokenizer_pool.ping() - - with patch.dict(os.environ, {env_var: "1"}): - tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) - tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id=os.path.join(models_path_prefix, "gpt2"), - enable_lora=False, - max_num_seqs=1, - max_input_length=None) - tokenizer_pool.ping() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) -async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): - """Test that Ray tokenizer pool group can recover from failures and - if that's not possible, mark itself as unhealthy.""" - - class FailingTokenizerGroup(TokenizerGroup): - - def __init__(self, - *args, - fail_at: 
Optional[list[int]] = None, - **kwargs): - super().__init__(*args, **kwargs) - self.i = 0 - self.fail_at = fail_at or [] - - def encode(self, *args, **kwargs): - self.i += 1 - if self.i in self.fail_at: - sys.exit(1) - return super().encode(*args, **kwargs) - - class FailingRayTokenizerGroupPool(RayTokenizerGroupPool): - _worker_cls = FailingTokenizerGroup - - # Fail at first iteration - fail_at = [1] - tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) - tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id=os.path.join(models_path_prefix, "gpt2"), - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - fail_at=fail_at) - tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy() - - # Modify fail at to not fail at all (will be re-read when actor is - # re-initialized). - fail_at[0] = 1000 - - # We should recover successfully. - await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None) - await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None) - - # Check that we have a new actor - assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors) - assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors - - # Fail at first iteration - fail_at = [1] - tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id=os.path.join(models_path_prefix, "gpt2"), - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - fail_at=fail_at) - - # We should fail after re-initialization. - with pytest.raises(RuntimeError): - await tokenizer_group_pool.encode_async(prompt="prompt", - lora_request=None) - - # check_health should raise the same thing - with pytest.raises(RuntimeError): - tokenizer_group_pool.check_health() - - # Ensure that non-ActorDiedErrors are still propagated correctly and do not - # cause a re-initialization. 
- fail_at = [] - tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id=os.path.join(models_path_prefix, "gpt2"), - enable_lora=False, - max_num_seqs=1, - max_input_length=2, - fail_at=fail_at) - tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy() - - # Prompt too long error - with pytest.raises(ValueError): - await tokenizer_group_pool.encode_async(prompt="prompt" * 100, - lora_request=None) - await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None) - # Actors should stay the same. - assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 44fa4e059d9af4cf83db300f09c025c6284cf87c..506fff2562f3372b9438651dd9efa42129d6409f 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -101,6 +101,20 @@ CONFIGS: dict[str, ServerConfig] = { "extended": True }, + "llama4_json": { + "model": + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "arguments": [ + "--enforce-eager", "--no-enable-prefix-caching", "-tp", "4", + "--distributed-executor-backend", "mp", "--tool-call-parser", + "llama4_json", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja") + ], + "supports_parallel": + True, + "extended": + True + }, "mistral": { "model": os.path.join(models_path_prefix, "mistralai/Mistral-7B-Instruct-v0.3"), diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index a4a571b180c6bddd45fce1a9e81836cac0539587..e73e08e74b0d2c4bb38ba9211fdb7a8bc26aa953 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -37,7 +37,6 @@ def make_request(request_id, return Request( request_id=request_id, - prompt=None, prompt_token_ids=prompt_token_ids, multi_modal_inputs=multi_modal_inputs, multi_modal_hashes=mm_hashes, @@ -311,7 +310,7 @@ def test_metrics(): def stats(requests, queries, hits): return 
PrefixCacheStats(requests=requests, queries=queries, hits=hits) - metrics = PrefixCachingMetrics(interval=5) + metrics = PrefixCachingMetrics(max_recent_requests=5) assert metrics.hit_rate == 0.0 metrics.observe(stats(1, 20, 9)) @@ -496,8 +495,7 @@ def test_allocate_with_lookahead(): # Test case 1: Requires additional lookahead tokens kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100, - num_preallocate_tokens=0) + max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, num_tokens=3, @@ -507,25 +505,19 @@ def test_allocate_with_lookahead(): # Test case 2: With precomputed blocks kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100, - num_preallocate_tokens=4) - # num_preallocate_blocks = 4 // 4 - 2 // 4 = 1 + max_model_len=100) # required_blocks = ceil((3 + 2) /4) = 2 - # total_blocks = 1 + 2 = 3 blocks = kv_cache_manager.allocate_slots( request, num_tokens=3, num_lookahead_tokens=2, ) - assert len(blocks) == 3 + assert len(blocks) == 2 # Test case 3: With precomputed blocks - # num_preallocate_blocks = 4 // 4 - 4 // 4 = 0 # required_blocks = ceil((3 + 4) / 4) = 2 - # total_blocks = 0 + 2 = 2 kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100, - num_preallocate_tokens=4) + max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, num_tokens=3, diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 80dd275a90b87c5fbc391c00fc14825c9f428a40..b2e8ff61450c4e7ce93c435d6322eea07f55f3e5 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,7 +8,7 @@ import torch from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams -from vllm.utils import cdiv, sha256 +from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils 
import (BlockHashType, KVCacheBlock, @@ -29,7 +29,6 @@ def make_request(request_id, return Request( request_id=request_id, - prompt=None, prompt_token_ids=prompt_token_ids, multi_modal_inputs=multi_modal_inputs, multi_modal_hashes=mm_hashes, @@ -61,7 +60,6 @@ def test_prefill(hash_algo): max_model_len=8192, enable_caching=True, caching_hash_algo=hash_algo, - num_preallocate_tokens=16, ) # choose the hash function according to the parameter @@ -80,7 +78,7 @@ def test_prefill(hash_algo): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] + assert [b.block_id for b in blocks] == [1, 2, 3, 4] # Check full block metadata parent_block_hash = None @@ -92,8 +90,8 @@ def test_prefill(hash_algo): assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value - # Check partial/preallocated block metadata - for block_id in (4, 5): + # Check partial block metadata + for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 @@ -107,12 +105,12 @@ def test_prefill(hash_algo): assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [6, 7] + assert [b.block_id for b in blocks] == [5] for block in computed_blocks: assert block.ref_cnt == 2 - # At this point, we should have 3 free blocks left. - assert manager.block_pool.free_block_queue.num_free_blocks == 3 + # At this point, we should have 5 free blocks left. + assert manager.block_pool.free_block_queue.num_free_blocks == 5 manager.free(req0) manager.free(req1) @@ -120,14 +118,14 @@ def test_prefill(hash_algo): # All blocks should be available. 
assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be - # [unallocated (8, 9, 10)] - # [unique_req0 (5, 4)] - # [unique_req1 (7, 6)] + # [unallocated (6, 7, 8, 9, 10)] + # [unique_req0 (4)] + # [unique_req1 (5)] # [common (3, 2, 1)] assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1] + ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) @@ -139,29 +137,29 @@ def test_prefill(hash_algo): assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [8, 9] + assert [b.block_id for b in blocks] == [6] - # Although we only have 5 free blocks, we have 8 blocks in + # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. - assert manager.block_pool.free_block_queue.num_free_blocks == 5 + assert manager.block_pool.free_block_queue.num_free_blocks == 6 assert all([ b.ref_cnt == 0 for b in manager.block_pool.free_block_queue.get_all_free_blocks() ]) assert len([ b for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ]) == 5 + ]) == 6 manager.free(req2) # Cache miss and eviction. - req3 = make_request("3", [99] * (16 * 9)) + req3 = make_request("3", [99] * (16 * 10)) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) + blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks) # This block ID order also checks the eviction order. 
- assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1] + assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None @@ -178,7 +176,6 @@ def test_prefill_plp(): make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=16, ) # the default hash function is hash hash_fn = hash @@ -197,7 +194,7 @@ def test_prefill_plp(): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] + assert [b.block_id for b in blocks] == [1, 2, 3, 4] req0_block_hashes = [b.block_hash for b in blocks] # Check full block metadata @@ -210,8 +207,8 @@ def test_prefill_plp(): assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value - # Check partial/preallocated block metadata - for block_id in (4, 5): + # Check partial block metadata + for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 @@ -226,12 +223,12 @@ def test_prefill_plp(): assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [6, 7] + assert [b.block_id for b in blocks] == [5] for block in computed_blocks: assert block.ref_cnt == 2 - # At this point, we should have 3 free blocks left. - assert manager.block_pool.free_block_queue.num_free_blocks == 3 + # At this point, we should have 5 free blocks left. + assert manager.block_pool.free_block_queue.num_free_blocks == 5 manager.free(req0) manager.free(req1) @@ -239,14 +236,14 @@ def test_prefill_plp(): # All blocks should be available. 
assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be - # [unallocated (8, 9, 10)] - # [unique_req0 (5, 4)] - # [unique_req1 (7, 6)] + # [unallocated (6, 7, 8, 9, 10)] + # [unique_req0 (4)] + # [unique_req1 (5)] # [common (3, 2, 1)] assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1] + ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Request #2 is a prompt-logprobs request: # NO cache hit in the common prefix; duplicates request #0 cached blocks @@ -262,7 +259,7 @@ def test_prefill_plp(): block_ids = [b.block_id for b in blocks] # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks] == req0_block_hashes - assert block_ids != [1, 2, 3, 4, 5] + assert block_ids != [1, 2, 3, 4] # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. @@ -277,7 +274,6 @@ def test_decode(): make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=16, ) # Complete 3 blocks (48 tokens) @@ -291,7 +287,7 @@ def test_decode(): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] + assert [b.block_id for b in blocks] == [1, 2, 3, 4] # Append slots without allocating a new block. req0.num_computed_tokens = 55 @@ -299,28 +295,18 @@ def test_decode(): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 4) assert new_blocks is not None and len(new_blocks) == 0 - assert manager.req_to_blocks[req0.request_id][-2].block_hash is None + assert manager.req_to_blocks[req0.request_id][-1].block_hash is None - # Append slots without allocating a new block, but start using the - # preallocated block. + # Append slots with allocating a new block. 
req0.num_computed_tokens = 59 - # 6 tokens to fill the previous block, and 10 tokens to fill + # 9 tokens to fill the previous block, and 10 tokens to fill # the preallocated block. - for _ in range(5 + 10): + for _ in range(9 + 10): req0.append_output_token_ids(7) - new_blocks = manager.allocate_slots(req0, 15) - assert new_blocks is not None and len(new_blocks) == 0 + new_blocks = manager.allocate_slots(req0, 19) + assert new_blocks is not None and len(new_blocks) == 1 assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None - - # Append slots with allocating a new block. - req0.num_computed_tokens = 74 - # 6 tokens to fill the previous block, and 10 tokens to fill - # the preallocated block. - for _ in range(6 + 11): - req0.append_output_token_ids(12) - new_blocks = manager.allocate_slots(req0, 17) - # Plus one preallocated block. - assert new_blocks is not None and len(new_blocks) == 2 + assert manager.req_to_blocks[req0.request_id][-1].block_hash is None def test_evict(): @@ -328,7 +314,6 @@ def test_evict(): make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=16, ) last_token_id = 5 * 16 + 7 @@ -337,7 +322,7 @@ def test_evict(): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) - assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated + assert len(blocks) == 6 # 5 full + 1 partial # 3 blocks. 
req1 = make_request("1", list(range(last_token_id, @@ -349,7 +334,8 @@ def test_evict(): assert len(blocks) == 3 # 3 full blocks last_token_id += 3 * 16 - assert manager.block_pool.free_block_queue.num_free_blocks == 0 + # 10 - (6 + 3) == 1 + assert manager.block_pool.free_block_queue.num_free_blocks == 1 manager.free(req0) manager.free(req1) @@ -357,7 +343,7 @@ def test_evict(): assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8] + ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) @@ -365,8 +351,8 @@ def test_evict(): assert [b.block_id for b in computed_blocks] == [1, 2] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) - assert [b.block_id for b in blocks] == [7, 6] - assert manager.block_pool.free_block_queue.num_free_blocks == 6 + assert [b.block_id for b in blocks] == [10] + assert manager.block_pool.free_block_queue.num_free_blocks == 7 def test_hash_block_correct_reuse(): @@ -379,7 +365,6 @@ def test_hash_block_correct_reuse(): make_kv_cache_config(16, 2), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=0, ) # Allocate 1 block and cache it. @@ -416,7 +401,6 @@ def test_computed_blocks_not_evicted(): make_kv_cache_config(block_size, 3), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=0, ) # Allocate a block and cache it. 
@@ -465,7 +449,6 @@ def test_basic_prefix_caching_disabled(): make_kv_cache_config(block_size, 5), max_model_len=8192, enable_caching=False, - num_preallocate_tokens=0, ) req1 = make_request("1", list(range(10))) # 2 blocks and some more @@ -496,40 +479,6 @@ def test_basic_prefix_caching_disabled(): assert not blocks -@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8))) -@pytest.mark.parametrize("block_size", [4]) -def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): - """ - This tests that the preallocated blocks are correctly added. - """ - manager = KVCacheManager( - make_kv_cache_config(block_size, 11), - max_model_len=8192, - enable_caching=True, - num_preallocate_tokens=num_preallocate_tokens, - ) - num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size) - - req = make_request("0", list(range(block_size * 30))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks - assert num_computed_tokens == 0 - # Just ask for 1 block. - blocks = manager.allocate_slots(req, block_size, computed_blocks) - req.num_computed_tokens = block_size - assert len(blocks) == 1 + num_preallocated_blocks - - # Assume all computed, only when num_preallocate_tokens > 0, we need to - # consume the previously preallocated blocks. - if num_preallocated_blocks > 0: - manager.allocate_slots(req, block_size * (len(blocks) - 1)) - req.num_computed_tokens = block_size * len(blocks) - - # Append 1 block. 
def test_prefix_cache_stats_disabled():
    """Check that prefix_cache_stats stays None when log_stats is False."""
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        log_stats=False,  # Stats collection is switched off.
    )
    assert manager.prefix_cache_stats is None

    # Exercise the manager entry points used by a normal request lifecycle
    # (lookup, allocation, cache reset) with stats disabled.
    request = make_request("0", list(range(block_size)))
    cached_blocks, cached_tokens = manager.get_computed_blocks(request)
    assert not cached_blocks
    assert cached_tokens == 0
    manager.allocate_slots(request, block_size, cached_blocks)
    manager.reset_prefix_cache()

    # None of the calls above may have created a stats object.
    assert manager.prefix_cache_stats is None


def test_eagle_enabled_removes_last_block():
    """Check the prefix-cache hit for a block-aligned prompt with Eagle on.

    A prompt of exactly 3 full blocks is cached and freed, then an identical
    request is looked up. With ``use_eagle=True`` the test expects only ONE
    computed block (16 tokens) to be reported for the second request.
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # Prompt with 3 full blocks (48 tokens).
    prompt = [0] * (3 * block_size)

    # Prime the cache, then release the request so its blocks stay cached.
    warmup = make_request("divisible_request", prompt)
    warmup_blocks, _ = manager.get_computed_blocks(warmup)
    manager.allocate_slots(warmup, len(prompt), warmup_blocks)
    manager.free(warmup)

    # Same tokens again, still with Eagle enabled.
    probe = make_request("eagle_divisible", prompt)
    hit_blocks, hit_tokens = manager.get_computed_blocks(probe)

    # Expected: 1 block retained out of the 3 cached ones.
    # NOTE(review): per the original author's notes the lookup first drops
    # the last block hash of a fully matching prompt (3 -> 2) and the Eagle
    # path then drops one more block (2 -> 1) -- confirm against
    # KVCacheManager.get_computed_blocks.
    assert len(hit_blocks) == 1
    assert hit_tokens == block_size  # 16 tokens


def test_eagle_with_partial_blocks():
    """Check the Eagle prefix-cache hit for a prompt ending in a partial block.

    The prompt holds 2 full blocks plus 5 extra tokens; the repeated lookup
    is expected to report a single computed block (16 tokens).
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # 2 full blocks + 5 tokens (length not divisible by the block size).
    prompt = [0] * (2 * block_size + 5)

    # Prime the cache, then release the request.
    warmup = make_request("partial_block_test", prompt)
    warmup_blocks, _ = manager.get_computed_blocks(warmup)
    manager.allocate_slots(warmup, len(prompt), warmup_blocks)
    manager.free(warmup)

    # Identical request with Eagle enabled.
    probe = make_request("partial_eagle", prompt)
    hit_blocks, hit_tokens = manager.get_computed_blocks(probe)

    # 2 full blocks can match; one fewer is expected to be reported.
    assert len(hit_blocks) == 1
    assert hit_tokens == block_size
num_blocks: int = 10000, + block_size: int = 16, + max_model_len: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, ) -> Scheduler: '''Create scheduler under test. @@ -39,12 +46,15 @@ def create_scheduler( Returns: :class:`Scheduler` instance ''' + if max_model_len is None: + max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_num_batched_tokens, + max_model_len=max_model_len, long_prefill_token_threshold=long_prefill_token_threshold, disable_chunked_mm_input=disable_chunked_mm_input, + enable_chunked_prefill=True, ) model_config = ModelConfig( model=model, @@ -60,31 +70,42 @@ def create_scheduler( 'enable_prefix_caching': enable_prefix_caching }) cache_config = CacheConfig( - block_size=16, + block_size=block_size, gpu_memory_utilization=0.9, swap_space=0, cache_dtype="auto", **kwargs_cache, ) + kv_transfer_config = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) if use_kv_connector else None + + speculative_config: Optional[SpeculativeConfig] = None + if num_speculative_tokens is not None: + speculative_config = SpeculativeConfig( + model="ngram", num_speculative_tokens=num_speculative_tokens) + vllm_config = VllmConfig( scheduler_config=scheduler_config, model_config=model_config, cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + speculative_config=speculative_config, ) kv_cache_config = KVCacheConfig( - num_blocks=10000, # A large number of blocks to hold all requests + num_blocks=num_blocks, # A large number of blocks to hold all requests tensors={}, kv_cache_groups=[ KVCacheGroupSpec(['layer'], - FullAttentionSpec(16, 1, 1, torch.float32, False)) + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) ], ) - cache_config.num_gpu_blocks = 10000 + cache_config.num_gpu_blocks = num_blocks return 
Scheduler( - scheduler_config, - model_config, - cache_config, - lora_config=None, + vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), @@ -111,7 +132,6 @@ def create_requests(num_requests: int, mm_inputs = None request = Request( request_id=f"{i}", - prompt=None, prompt_token_ids=[i] * num_tokens, sampling_params=sampling_params, multi_modal_inputs=mm_inputs, @@ -286,6 +306,7 @@ def test_no_mm_input_chunking(): model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, disable_chunked_mm_input=True, + max_model_len=2048, ) mm_positions = [[PlaceholderRange(offset=400, length=800)]] requests = create_requests(num_requests=1, @@ -414,7 +435,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): def test_stop_via_update_from_output(): """Test stopping behavior through update_from_output""" - scheduler = create_scheduler() + scheduler = create_scheduler(num_speculative_tokens=1) # Test case 1: Stop on EOS token requests = create_requests(num_requests=2, max_tokens=10) @@ -422,7 +443,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids.add(req.request_id) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -466,7 +486,7 @@ def test_stop_via_update_from_output(): assert list(requests[1].output_token_ids) == [10, 11] # Test case 2: Stop on custom stop token - scheduler = create_scheduler() + scheduler = create_scheduler(num_speculative_tokens=2) requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43]) @@ -474,7 +494,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids.add(req.request_id) scheduler_output = 
SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -518,13 +537,12 @@ def test_stop_via_update_from_output(): assert list(requests[1].output_token_ids) == [13, 14] # Test case 3: Stop on max tokens - scheduler = create_scheduler() + scheduler = create_scheduler(num_speculative_tokens=2) requests = create_requests(num_requests=2, max_tokens=2) for req in requests: req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids.add(req.request_id) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -568,13 +586,12 @@ def test_stop_via_update_from_output(): assert list(requests[1].output_token_ids) == [13] # Test case 4: Ignore EOS flag - scheduler = create_scheduler() + scheduler = create_scheduler(num_speculative_tokens=2) requests = create_requests(num_requests=1, max_tokens=10) requests[0].sampling_params.ignore_eos = True requests[0].num_computed_tokens = requests[0].num_tokens scheduler.requests[requests[0].request_id] = requests[0] scheduler.running.append(requests[0]) - scheduler.scheduled_req_ids.add(requests[0].request_id) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], @@ -671,13 +688,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", [ - ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match - ([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch - ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences - ([[1]], [[1, 2]], (1, 1)), # single token sequence - ([[]], [[5]], (0, 0)), # empty sequence + ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match + ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch + ([[1, 2], [3]], [[1, 2, 5], [3, 4]], + (2, 3, 3, [2, 1])), # multiple sequences + ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence + ([[]], [[5]], (0, 0, 0, [0])), # empty 
def _assert_right_scheduler_output(
    output: SchedulerOutput,
    num_requests: int,
    expected_num_scheduled_tokens: int,
):
    """Check if SchedulerOutput is correct after remote KV cache hit.

    Args:
        output: Scheduler output to inspect.
        num_requests: Number of requests expected in the connector metadata.
        expected_num_scheduled_tokens: Token count every scheduled request
            must have.

    Raises:
        AssertionError: If any of the checks does not hold.
    """

    # We should inject the kv_connector_metadata.
    assert len(output.kv_connector_metadata.requests) == num_requests

    # Only num_tokens - matched_num_new_tokens should be scheduled.
    for num_scheduled_tokens in output.num_scheduled_tokens.values():
        assert num_scheduled_tokens == expected_num_scheduled_tokens


def _assert_right_kv_cache_manager(
    scheduler: Scheduler,
    req_ids: list[str],
    num_tokens: int,
    block_size: int,
    num_requests: int,
    num_total_blocks: int,
):
    """Check whether KVCacheManager is correct after allocate.

    Every request in ``req_ids`` is expected to own exactly
    ``num_tokens // block_size`` blocks, block hashes and cached blocks, and
    the block pool is expected to be short that many blocks per request.

    Args:
        scheduler: Object exposing the ``kv_cache_manager`` under test.
        req_ids: IDs of the requests to check.
        num_tokens: Number of tokens allocated for each request.
        block_size: Number of tokens per KV cache block.
        num_requests: Number of requests that hold blocks.
        num_total_blocks: Free block count of the pool before allocation.

    Raises:
        AssertionError: If any of the checks does not hold.
    """
    manager = scheduler.kv_cache_manager

    # Single source of truth for the per-request block count. Floor division
    # keeps it an int, so the free-block check below is an exact integer
    # comparison instead of an int-vs-float one.
    blocks_per_req = num_tokens // block_size

    # Make sure the request stats are right.
    for req_id in req_ids:
        assert manager.num_cached_block[req_id] == blocks_per_req
        assert len(manager.req_to_blocks[req_id]) == blocks_per_req
        assert len(manager.req_to_block_hashes[req_id]) == blocks_per_req

    # Make sure we actually touched all the blocks.
    assert (manager.block_pool.get_num_free_blocks() == num_total_blocks -
            num_requests * blocks_per_req)


def _step_until_done(
    scheduler: Scheduler,
    output: SchedulerOutput,
    model_runner_output: ModelRunnerOutput,
):
    """Loop over schedule(), update_from_output() until finished.

    Args:
        scheduler: Scheduler to drive.
        output: Output of the step the caller has already scheduled.
        model_runner_output: Model output replayed on every step.

    Raises:
        AssertionError: If a step has no running request, schedules anything
            other than one token per request, or carries connector requests.
    """
    # Consume the step that the caller scheduled before handing over.
    scheduler.update_from_output(output, model_runner_output)

    while True:
        # Schedule + a few iterations until stopping.
        output = scheduler.schedule()
        assert len(scheduler.running)
        for num_scheduled_tokens in output.num_scheduled_tokens.values():
            # We should be in the decode phase now.
            assert num_scheduled_tokens == 1
        assert len(output.kv_connector_metadata.requests) == 0

        ecos = scheduler.update_from_output(output, model_runner_output)
        # Stop once every reported output carries a finish reason.
        if all(eco.finish_reason is not None for eco in ecos.outputs):
            return
+ _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + NUM_REQUESTS, NUM_TOTAL_BLOCKS) + + # Continue Generation until done. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + _ = scheduler.schedule() + # Confirm we clean up the memory properly. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_TOTAL_BLOCKS + + ###################################################### + # SECOND SET OF REQUESTS - Local And External Hit + NUM_TOKENS_PREFIX = NUM_TOKENS + # We will get a local prefix cache hit for the first + # NUM_TOKENS_PREFIX tokens since they are used above. + NUM_TOKENS = NUM_TOKENS_PREFIX * 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # We should get a local cache hit of NUM_TOKENS_PREFIX and + # a remote KV cache hit of NUM_MATCHED_NEW_TOKENS. + output = scheduler.schedule() + _assert_right_scheduler_output( + output=output, + num_requests=NUM_REQUESTS, + # Just the incremental tokens after local + remote cache hit. + expected_num_scheduled_tokens=(NUM_TOKENS - NUM_TOKENS_PREFIX - + NUM_MATCHED_NEW_TOKENS)) + + # Ensure KVCacheManager is correct. + _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + NUM_REQUESTS, NUM_TOTAL_BLOCKS) + + # Continue Generation until done. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + _ = scheduler.schedule() + # Confirm we clean up the memory properly. 
+ assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_TOTAL_BLOCKS + + +def test_kv_connector_unable_to_allocate(): + """ + Test whether scheduler with KVConnector is able to handle + unable to allocate (run out of blocks in allocate_slots(). + """ + + # Setup Scheduler With Mock External Cache Hit. + BLOCK_SIZE = 4 + NUM_BLOCKS = 10 + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=True, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + ) + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS) + + # Create two requests. The second request will not be able to + # allocate slots because it will not have enough blocks. + NUM_REQUESTS = 2 + NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE + MAX_TOKENS = 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # Just one request should be running. + output = scheduler.schedule() + _assert_right_scheduler_output(output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - + NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # All memory should be freed, with one request waiting. 
+ _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Just one request should be running. + output = scheduler.schedule() + _assert_right_scheduler_output(output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - + NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # All memory should be freed, with no requests waiting / running. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 + + +def test_kv_connector_handles_preemption(): + """ + Test whether scheduler with KVConnector is able to handle + unable to allocate (run out of blocks in allocate_slots(). + """ + + # Setup Scheduler With Mock External Cache Hit. + BLOCK_SIZE = 2 + # NOTE: there is 1 null block, so this is 6 blocks. + NUM_BLOCKS = 7 + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=True, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + ) + + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS) + + # Create two requests. + # Both can be scheduled at first, but the second request + # will be preempted and re-scheduled. 
+ NUM_REQUESTS = 2 + NUM_TOKENS = BLOCK_SIZE * 2 + 1 + MAX_TOKENS = BLOCK_SIZE * 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # All can be scheduled - 1st token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # 2 remote kv cache hits. + num_requests=2, + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 2 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # All can be scheduled - 2nd token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 2 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # This will generate a new block and cause a preemption - 3rd token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Only 1 can be scheduled - 4th (and last token). 
+ output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + # All memory should be freed since nothing is running. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + + # Restarts the preempted request - generate 3rd token. + # This will have a local and remote cache hit. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # 1 remote kv_cache hit! + num_requests=1, + # Only 1 block was preempted and there is a single + # remote hit. So only single new token is scheduled. + expected_num_scheduled_tokens=1, + ) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Only 1 can be scheduled - 4th (and last token). + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 0 + # All memory should be freed since nothing is running. 
+ assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index a8079dcce5e2f3efe7f66a5e46c4349031351487..48c265560348c80e540d71c1e05f3d9906b398a9 100644 --- a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -1,13 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 +import pytest + from vllm import LLM, SamplingParams +from ...utils import fork_new_process_for_each_test + -def test_cascade_attention(example_system_message, monkeypatch): +@fork_new_process_for_each_test +@pytest.mark.parametrize("attn_backend", + ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"]) +def test_cascade_attention(example_system_message, monkeypatch, attn_backend): prompt = "\n: Implement fibonacci sequence in Python.\n:" with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") sampling_params = SamplingParams(temperature=0.0, max_tokens=100) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 673714980592a1eec4732a14a7f60a807735e228..2fad37d6801bb5d731a1afd9ae13a949e9aa88bc 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -44,18 +44,20 @@ def test_prompts(): @pytest.fixture def sampling_config(): - # Only support greedy for now return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) @pytest.fixture def model_name(): - return "meta-llama/Meta-Llama-3-8B-Instruct" + return "meta-llama/Llama-3.1-8B-Instruct" -@pytest.fixture def eagle_model_name(): - return "yuhuili/EAGLE-LLaMA3-Instruct-8B" + return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + +def eagle3_model_name(): + return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" def test_ngram_correctness( @@ -102,12 +104,13 @@ def test_ngram_correctness( del spec_llm +@pytest.mark.parametrize("use_eagle3", [False, True], 
ids=["eagle", "eagle3"]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_name: str, - eagle_model_name: str, + use_eagle3: bool, ): ''' Compare the outputs of a original LLM and a speculative LLM @@ -116,18 +119,22 @@ def test_eagle_correctness( with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - ref_llm = LLM(model=model_name, max_model_len=1024) + ref_llm = LLM(model=model_name, max_model_len=2048) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + spec_model_name = eagle3_model_name( + ) if use_eagle3 else eagle_model_name() spec_llm = LLM( model=model_name, + trust_remote_code=True, speculative_config={ - "method": "eagle", - "model": eagle_model_name, + "method": "eagle3" if use_eagle3 else "eagle", + "model": spec_model_name, "num_speculative_tokens": 3, + "max_model_len": 2048, }, - max_model_len=1024, + max_model_len=2048, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 @@ -140,7 +147,7 @@ def test_eagle_correctness( print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") - # Heuristic: expect at least 70% of the prompts to match exactly + # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. 
- assert matches > int(0.7 * len(ref_outputs)) + assert matches > int(0.66 * len(ref_outputs)) del spec_llm diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py index 8872f0388dd249dd17e14cb34a8664a4c974bd56..f8addd920d577667be09aa7b2526efc60697e8ce 100644 --- a/tests/v1/engine/conftest.py +++ b/tests/v1/engine/conftest.py @@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: tokenizer=tokenizer, tokenizer_group=init_tokenizer_from_configs( vllm_config.model_config, vllm_config.scheduler_config, - vllm_config.parallel_config, vllm_config.lora_config), + vllm_config.lora_config), vllm_config=vllm_config, full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], prompt_tokens=prompt_tokens, diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 9e0acf11c23d3c96598acd9714a8bc48716d5a16..b44a1c9ab4473f1ab4c8664be17392eaf676b750 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -3,17 +3,20 @@ import asyncio from contextlib import ExitStack from typing import Optional +from unittest.mock import MagicMock import os import pytest from vllm import SamplingParams from vllm.assets.image import ImageAsset +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.metrics.loggers import LoggingStatLogger from ...utils import models_path_prefix if not current_platform.is_cuda(): @@ -218,3 +221,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int, # Assert only the last output has the finished flag set assert all(not out.finished for out in outputs[:-1]) assert outputs[-1].finished + + +class MockLoggingStatLogger(LoggingStatLogger): + + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + 
super().__init__(vllm_config, engine_index) + self.log = MagicMock() + + +@pytest.mark.asyncio +async def test_customize_loggers(monkeypatch): + """Test that we can customize the loggers. + If a customized logger is provided at the init, it should + be used directly. + """ + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + engine = AsyncLLM.from_engine_args( + TEXT_ENGINE_ARGS, + stat_loggers=[MockLoggingStatLogger], + ) + after.callback(engine.shutdown) + + await engine.do_log_stats() + + assert len(engine.stat_loggers) == 1 + assert len(engine.stat_loggers[0]) == 1 + engine.stat_loggers[0][0].log.assert_called_once() diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 99071e4abe5cd6e77b2c2beacf2f1235440054bf..c20acc506d7bd9f88ce708bf7d463263d74f4c21 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import copy -import threading import time import uuid -from concurrent.futures import Future +from concurrent.futures import Future, ThreadPoolExecutor import os import pytest @@ -34,8 +33,7 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids def make_request() -> EngineCoreRequest: return EngineCoreRequest( - request_id=uuid.uuid4(), - prompt=PROMPT, + request_id=str(uuid.uuid4()), prompt_token_ids=PROMPT_TOKENS, mm_inputs=None, mm_hashes=None, @@ -246,33 +244,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): self, kv_cache_configs: list[KVCacheConfig]) -> None: super().initialize_from_config(kv_cache_configs) - # This executor actually can only run 1 batch at a time - self.semaphore = threading.Semaphore(1) + # Create a thread pool with a single worker + self.thread_pool = ThreadPoolExecutor(max_workers=1) def execute_model( self, scheduler_output, ) -> Future[ModelRunnerOutput]: """Make execute_model non-blocking.""" - future: Future[ModelRunnerOutput] = Future() - 
def _thread_wrapper(scheduler_output, future): - with self.semaphore: - output = self.collective_rpc("execute_model", - args=(scheduler_output, )) - # Make a copy because output[0] may be reused - # by the next batch. - output = copy.deepcopy(output[0]) - future.set_result(output) + def _execute(): + output = self.collective_rpc("execute_model", + args=(scheduler_output, )) + # Make a copy because output[0] may be reused + # by the next batch. + return copy.deepcopy(output[0]) - threading.Thread(target=_thread_wrapper, - args=(scheduler_output, future)).start() - return future + # Use the thread pool instead of creating a new thread + return self.thread_pool.submit(_execute) @property def max_concurrent_batches(self) -> int: return 2 + def shutdown(self): + if hasattr(self, 'thread_pool'): + self.thread_pool.shutdown(wait=False) + with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -301,14 +299,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): # Schedule Batch 1: (10, req0) assert engine_core.step_with_batch_queue() is None assert engine_core.batch_queue.qsize() == 1 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[0] == 10 + # num_computed_tokens should have been updated immediately. + assert engine_core.scheduler.requests[ + req0.request_id].num_computed_tokens == 10 + + # Schedule Batch 2: (2, req0), (8, req1) assert engine_core.step_with_batch_queue() is None assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[0] == 2 + assert scheduler_output.num_scheduled_tokens[1] == 8 + # num_computed_tokens should have been updated immediately. + assert engine_core.scheduler.requests[0].num_computed_tokens == 12 + assert engine_core.scheduler.requests[1].num_computed_tokens == 8 + assert engine_core.scheduler.get_num_unfinished_requests() == 2 - # Loop through both requests. 
- while engine_core.scheduler.get_num_unfinished_requests() == 2: - engine_core.step_with_batch_queue() + # Batch queue is full. Finish Batch 1. + engine_core.step_with_batch_queue() + + # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled + # because it is in the decoding stage now. + engine_core.step_with_batch_queue() + assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[1] == 4 - # Reaching here when got the result of the first request. - while engine_core.scheduler.get_num_unfinished_requests() == 1: - engine_core.step_with_batch_queue() + # Batch queue is full. Finish Batch 2. Get first token of req0. + output = engine_core.step_with_batch_queue() + assert output is not None + assert len(output.outputs) == 1 + assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 + + # Schedule Batch 4: (1, req0). + engine_core.step_with_batch_queue() + assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[0] == 1 + + # Batch queue is full. Finish Batch 3. Get first token of req1. + output = engine_core.step_with_batch_queue() + assert output is not None + assert len(output.outputs) == 1 + assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 + + # Schedule Batch 5: (1, req1). + engine_core.step_with_batch_queue() + assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[1] == 1 + + # Loop until req0 is finished. + step = 0 + req_id = 0 + expected_num_tokens = [ + engine_core.scheduler.requests[0].num_tokens + 1, + engine_core.scheduler.requests[1].num_tokens + 1, + ] + while engine_core.scheduler.get_num_unfinished_requests() == 2: + output = engine_core.step_with_batch_queue() + if step % 2 == 0: + # Even steps consumes an output. 
+ assert output is not None + assert len(output.outputs) == 1 + if req_id in engine_core.scheduler.requests: + assert engine_core.scheduler.requests[ + req_id].num_tokens == expected_num_tokens[req_id] + expected_num_tokens[req_id] += 1 + req_id = (req_id + 1) % 2 + else: + # Odd steps schedules a new batch. + assert output is None + step += 1 diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 859867b2ed6e76e7385a69ca5f5884dece01d837..2b2d59bec8b15b56122eb01556cf020e2ce0430e 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -37,7 +37,6 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids def make_request(params: SamplingParams) -> EngineCoreRequest: return EngineCoreRequest( request_id=str(uuid.uuid4()), - prompt=PROMPT, prompt_token_ids=PROMPT_TOKENS, mm_inputs=None, mm_hashes=None, diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 54ada51e43f67956617a3196d8564245cd9faf03..98ece27e897ff2ac7c5c0e8ae414207eeb7047e4 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -51,7 +51,6 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, # Make N requests. requests = [ EngineCoreRequest(request_id=f"request-{idx}", - prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -65,14 +64,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, output_kind=request_output_kind, stop=[], include_stop_str_in_output=False, - )) for idx, (prompt, prompt_tokens) in enumerate( - zip(dummy_test_vectors.prompt_strings, - dummy_test_vectors.prompt_tokens)) + )) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add requests to the detokenizer. 
- for request in requests: - output_processor.add_request(request) + for request, prompt in zip(requests, dummy_test_vectors.prompt_strings): + output_processor.add_request(request, prompt) gen_strings = {} gen_tokens = {} @@ -399,7 +397,6 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ] requests = [ EngineCoreRequest(request_id=request_id_list[idx], - prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -415,14 +412,13 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, include_stop_str_in_output=False, logprobs=num_sample_logprobs, prompt_logprobs=num_prompt_logprobs, - )) for idx, (prompt, prompt_tokens) in enumerate( - zip(dummy_test_vectors.prompt_strings, - dummy_test_vectors.prompt_tokens)) + )) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add requests to the detokenizer. - for request in requests: - output_processor.add_request(request) + for request, prompt in zip(requests, dummy_test_vectors.prompt_strings): + output_processor.add_request(request, prompt) gen_tokens = {} gen_logprobs = {} @@ -563,7 +559,6 @@ def test_stop_token(include_stop_str_in_output: bool, request_id = "request-0" request = EngineCoreRequest( request_id=request_id, - prompt=prompt_string, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -584,7 +579,7 @@ def test_stop_token(include_stop_str_in_output: bool, )) # Add request to the detokenizer. 
- output_processor.add_request(request) + output_processor.add_request(request, prompt_string) # Loop over engine core steps; run output processor gen_string = "" @@ -660,7 +655,6 @@ def test_stop_string(include_stop_str_in_output: bool, requests = [ EngineCoreRequest( request_id=request_id_list[idx], - prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -676,14 +670,13 @@ def test_stop_string(include_stop_str_in_output: bool, include_stop_str_in_output=include_stop_str_in_output, logprobs=num_sample_logprobs, prompt_logprobs=None, - )) for idx, (prompt, prompt_tokens) in enumerate( - zip(dummy_test_vectors.prompt_strings, - dummy_test_vectors.prompt_tokens)) + )) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add requests to the detokenizer. - for request in requests: - output_processor.add_request(request) + for request, prompt in zip(requests, dummy_test_vectors.prompt_strings): + output_processor.add_request(request, prompt) gen_strings = {} gen_tokens = {} @@ -775,7 +768,6 @@ def test_iteration_stats(dummy_test_vectors): requests = [ EngineCoreRequest( request_id=f"request-{idx}", - prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -784,15 +776,13 @@ def test_iteration_stats(dummy_test_vectors): eos_token_id=None, lora_request=None, sampling_params=SamplingParams(), - ) for idx, (prompt, prompt_tokens) in enumerate( - zip(dummy_test_vectors.prompt_strings, - dummy_test_vectors.prompt_tokens)) + ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add all requests except one to the OutputProcessor. num_active = len(dummy_test_vectors.generation_tokens) - 1 for request in requests[:num_active]: - output_processor.add_request(request) + output_processor.add_request(request, None) inactive_request = requests[num_active] # First iteration has 2 prefills. 
@@ -818,7 +808,7 @@ def test_iteration_stats(dummy_test_vectors): assert iteration_stats.num_generation_tokens == num_active # Add a new request - prefill and 2 decodes in this step. - output_processor.add_request(inactive_request) + output_processor.add_request(inactive_request, None) num_active += 1 outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() @@ -922,3 +912,84 @@ async def test_request_output_collector(): # Cumulative logprobs should be the last one. cumulative_logprob_expected = 1.0 * num_to_put assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected + + +@pytest.mark.asyncio +async def test_cumulative_output_collector_n(): + """Test collector correctly handles multiple outputs by index.""" + collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE) + outputs = [ + RequestOutput( + request_id="my-request-id", + prompt=None, + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[ + CompletionOutput( + index=0, + text="a", + token_ids=[0], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ), + CompletionOutput( + index=1, + text="b", + token_ids=[1], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ), + ], + finished=False, + ), + RequestOutput( + request_id="my-request-id", + prompt=None, + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[ + CompletionOutput( + index=0, + text="ab", + token_ids=[0, 1], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ), + CompletionOutput( + index=2, + text="c", + token_ids=[2], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ), + ], + finished=False, + ), + ] + for output in outputs: + collector.put(output) + + # Get the output and check that the text and token_ids are correct. 
+ result = await collector.get() + # We are expecting + # [{index: 0, text: "ab"}, {index: 1, text: "b"}, {index: 2, text: "c"}] + assert len(result.outputs) == 3 + # First is the one where index is 0 + first = [k for k in result.outputs if k.index == 0] + assert len(first) == 1 + assert first[0].text == "ab" + + # Second is the one where index is 1 + second = [k for k in result.outputs if k.index == 1] + assert len(second) == 1 + assert second[0].text == "b" + assert second[0].token_ids == [1] + + # Third is the one where index is 2 + third = [k for k in result.outputs if k.index == 2] + assert len(third) == 1 + assert third[0].text == "c" diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 1ee93c72cd2636361073a9f9dc53e18fe033e3f3..4a23e0c1b212e91e8e8604a8c5c625b0d028af34 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -8,8 +8,7 @@ import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs -from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.outputs import LogprobsLists, LogprobsTensors @@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors( class DummyOutputProcessorTestVectors: """Dummy test vectors for output processor tests""" tokenizer: GeneralTokenizerType - tokenizer_group: BaseTokenizerGroup + tokenizer_group: TokenizerGroup vllm_config: EngineArgs full_tokens: list[list[int]] # Prompt + generated tokens prompt_tokens: list[list[int]] diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index 6d4278b4c87191be861ef15b2bccbf03b86d6914..d84b2b22db121f21113dd583748f41a5e2ec8c3e 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -47,6 +47,14 @@ def sample_json_schema(): "type": "string", } }, + 
"grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + "email": { + "type": "string", + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + }, "work_history": { "type": "array", "items": { @@ -56,17 +64,20 @@ def sample_json_schema(): "type": "string" }, "duration": { - "type": "number" + "type": "number", + "minimum": 0.0, + "maximum": 100.0, # Numeric range }, "position": { "type": "string" } }, - "required": ["company", "position"] + "required": ["company", "duration", "position"] } } }, - "required": ["name", "age", "skills", "work_history"] + "required": + ["name", "age", "skills", "grade", "email", "work_history"] } @@ -78,27 +89,18 @@ def unsupported_json_schema(): "properties": { "score": { "type": "integer", - "minimum": 0, - "maximum": 100 # Numeric range - }, - "grade": { - "type": "string", - "pattern": "^[A-D]$" # Regex pattern - }, - "email": { - "type": "string", - "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + "multipleOf": 5 # Numeric multiple }, "tags": { "type": "array", "items": { "type": "string", - "pattern": - "^[a-z]{1,10}$" # Combining length and pattern restrictions + "minLength": 10, + "maxLength": 20 } } }, - "required": ["score", "grade", "email", "tags"] + "required": ["score", "tags"] } diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index b179dc3b4747c0bf2948f27c932351378a5a2572..19960c13c856db945260d057f4b47c07e65d120b 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -13,6 +13,7 @@ from pydantic import BaseModel from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput +from vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ @@ -63,10 +64,13 @@ def test_structured_output( ): 
monkeypatch.setenv("VLLM_USE_V1", "1") + # Don't use eager execution on TPUs because we want to test for no + # recompilation at runtime + enforce_eager = bool(not current_platform.is_tpu()) # Use a single LLM instance for several scenarios to # speed up the test suite. llm = LLM(model=model_name, - enforce_eager=True, + enforce_eager=enforce_eager, max_model_len=1024, guided_decoding_backend=guided_decoding_backend, tokenizer_mode=tokenizer_mode) @@ -346,6 +350,7 @@ def test_structured_output( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams(json=json_schema)) + outputs = llm.generate( prompts="Generate a description of a frog using 50 characters.", sampling_params=sampling_params, @@ -364,6 +369,106 @@ def test_structured_output( output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=json_schema) + # + # Test 11: Generate structured output using structural_tag format + # + structural_tag_config = { + "type": + "structural_tag", + "structures": [{ + "begin": "", + "schema": { + "type": "object", + "properties": { + "city": { + "type": "string" + } + } + }, + "end": "" + }], + "triggers": ["{parameters}{end_tag} +where + +start_tag => ` a JSON dict with the function argument name + as key and function argument value as value. +end_tag => `` + +Here is an example, +{"example_name": "example_value"} + +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line +- Always add your sources when using search results to answer the user query + +You are a helpful assistant. + +Given the previous instructions, what is the weather in New York City? 
+""" + + # Change this once other backends support structural_tag + if guided_decoding_backend.startswith("xgrammar"): + outputs = llm.generate(prompts=prompt, + sampling_params=sampling_params, + use_tqdm=True) + assert outputs is not None + else: + outputs = [] + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + generated_text = output.outputs[0].text + assert generated_text is not None + + # Search for function call pattern in the response + function_call_pattern = r'(.*?)' + matches = re.findall(function_call_pattern, generated_text) + + if not matches: + print(f"Warning: No function calls found in response: " + f"{generated_text!r}") + continue + + # Take the first function call if multiple are found + json_str = matches[0] + try: + json_content = json.loads(json_str) + assert "city" in json_content + assert isinstance(json_content["city"], str) + print(f"Found valid function call: {generated_text!r}") + except (json.JSONDecodeError, AssertionError) as e: + pytest.fail("Invalid function call format: " + f"{generated_text!r}\nError: {str(e)}") + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("model_name, tokenizer_mode", @@ -386,13 +491,21 @@ def test_structured_output_auto_mode( max_tokens=1000, guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) + prompts = ("Give an example JSON object for a grade " + "that fits this schema: " + f"{unsupported_json_schema}") # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically. - outputs = llm.generate(prompts=("Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}"), + outputs = llm.generate(prompts=prompts, sampling_params=sampling_params, use_tqdm=True) + # Make sure `auto` backend handling doesn't mess up sampling_params + # and that we can reuse it without error. 
+ outputs.extend( + llm.generate(prompts=prompts, + sampling_params=sampling_params, + use_tqdm=True)) + assert outputs is not None for output in outputs: assert output is not None @@ -404,3 +517,59 @@ def test_structured_output_auto_mode( # Parse to verify it is valid JSON parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) + + +@pytest.mark.skip_global_cleanup +def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_V1", "1") + + backend = 'guidance:no-additional-properties,disable-any-whitespace' + llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", + max_model_len=1024, + guided_decoding_backend=backend) + + schema = { + 'type': 'object', + 'properties': { + 'a1': { + 'type': 'string' + }, + 'a2': { + 'type': 'string' + }, + 'a3': { + 'type': 'string' + } + }, + 'required': ['a1', 'a2', 'a3'], + } + + prompt = ( + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a " + "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " + "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" + "<|im_end|>\n<|im_start|>assistant\n") + + def generate_with_backend(backend): + guided_params = GuidedDecodingParams(json=schema, backend=backend) + sampling_params = SamplingParams(temperature=0, + max_tokens=256, + guided_decoding=guided_params) + + outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + assert outputs is not None + generated_text = outputs[0].outputs[0].text + assert generated_text is not None + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) + jsonschema.validate(instance=parsed_json, schema=schema) + return parsed_json + + generated = generate_with_backend( + 'guidance:no-additional-properties,disable-any-whitespace') + assert "a1" in generated + assert "a2" in generated + assert "a3" in generated + assert "a4" not in generated + assert "a5" not in generated + assert "a6" not in generated diff 
--git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py new file mode 100644 index 0000000000000000000000000000000000000000..ed368fe828d07181f115f19db101d33f4b95356f --- /dev/null +++ b/tests/v1/shutdown/test_delete.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test that we handle a startup Error and shutdown.""" + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC) +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.sampling_params import RequestOutputKind +from vllm.utils import cuda_device_count_stateless +from vllm.v1.engine.async_llm import AsyncLLM + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +@pytest.mark.asyncio +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("send_one_request", [False, True]) +async def test_async_llm_delete(model: str, tensor_parallel_size: int, + send_one_request: bool) -> None: + """Test that AsyncLLM frees GPU memory upon deletion. + AsyncLLM always uses an MP client. 
+ + Args: + model: model under test + tensor_parallel_size: degree of tensor parallelism + send_one_request: send one request to engine before deleting + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + engine_args = AsyncEngineArgs(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + # Instantiate AsyncLLM; make request to complete any deferred + # initialization; then delete instance + async_llm = AsyncLLM.from_engine_args(engine_args) + if send_one_request: + async for _ in async_llm.generate( + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=1, output_kind=RequestOutputKind.DELTA)): + pass + del async_llm + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("enable_multiprocessing", [True]) +@pytest.mark.parametrize("send_one_request", [False, True]) +def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int, + enable_multiprocessing: bool, + send_one_request: bool) -> None: + """Test that LLM frees GPU memory upon deletion. + TODO(andy) - LLM without multiprocessing. 
+ + Args: + model: model under test + tensor_parallel_size: degree of tensor parallelism + enable_multiprocessing: enable workers in separate process(es) + send_one_request: send one request to engine before deleting + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + with monkeypatch.context() as m: + MP_VALUE = "1" if enable_multiprocessing else "0" + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) + + # Instantiate LLM; make request to complete any deferred + # initialization; then delete instance + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + if send_one_request: + llm.generate("Hello my name is", + sampling_params=SamplingParams(max_tokens=1)) + del llm + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py new file mode 100644 index 0000000000000000000000000000000000000000..9fedbe4f9a01a6396edb53510972f1b8ac17e681 --- /dev/null +++ b/tests/v1/shutdown/test_forward_error.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test that we handle an Error in model forward and shutdown.""" + +import asyncio + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC) +from vllm import LLM, AsyncEngineArgs, SamplingParams +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.utils import cuda_device_count_stateless +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.exceptions import EngineDeadError + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +def evil_forward(self, *args, **kwargs): + """Evil forward method that raise an 
exception after 10 calls.""" + NUMBER_OF_GOOD_PASSES = 10 + + if not hasattr(self, "num_calls"): + self.num_calls = 0 + + if (self.num_calls == NUMBER_OF_GOOD_PASSES + and get_tensor_model_parallel_rank() == 0): + raise Exception("Simulated illegal memory access on Rank 0!") + self.num_calls += 1 + + return self.model(*args, **kwargs) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("model", MODELS) +async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int, + model: str) -> None: + """Test that AsyncLLM propagates a forward pass error and frees memory. + + AsyncLLM always uses an MP client. + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + # Monkeypatch an error in the model. + monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward) + + engine_args = AsyncEngineArgs(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + async_llm = AsyncLLM.from_engine_args(engine_args) + + async def generate(request_id: str): + generator = async_llm.generate("Hello my name is", + request_id=request_id, + sampling_params=SamplingParams()) + try: + async for _ in generator: + pass + except Exception as e: + return e + + NUM_REQS = 3 + tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)] + outputs = await asyncio.gather(*tasks) + + # Every request should get an EngineDeadError. + for output in outputs: + assert isinstance(output, EngineDeadError) + + # AsyncLLM should be errored. + assert async_llm.errored + + # We should not be able to make another request. + with pytest.raises(EngineDeadError): + async for _ in async_llm.generate("Hello my name is", + request_id="abc", + sampling_params=SamplingParams()): + raise Exception("We should not get here.") + + # Confirm all the processes are cleaned up. 
+ wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=2 * 2**30, + timeout_s=60, + ) + + # NOTE: shutdown is handled by the API Server if an exception + # occurs, so it is expected that we would need to call this. + async_llm.shutdown() + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("enable_multiprocessing", [True]) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("model", MODELS) +def test_llm_model_error(monkeypatch, tensor_parallel_size: int, + enable_multiprocessing: bool, model: str) -> None: + """Test that LLM propagates a forward pass error and frees memory. + TODO(andy) - LLM without multiprocessing; LLM with multiprocessing + and >1 rank + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + with monkeypatch.context() as m: + + MP_VALUE = "1" if enable_multiprocessing else "0" + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) + + # Monkeypatch an error in the model. + m.setattr(LlamaForCausalLM, "forward", evil_forward) + + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + with pytest.raises( + EngineDeadError if enable_multiprocessing else Exception): + llm.generate("Hello my name is Robert and I") + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) diff --git a/tests/v1/shutdown/test_processor_error.py b/tests/v1/shutdown/test_processor_error.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe48da475c6a49142690541b2549e43cb711f10 --- /dev/null +++ b/tests/v1/shutdown/test_processor_error.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test error handling in Processor. 
Should not impact other reqs.""" + +import asyncio + +import pytest + +from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.inputs.data import TokensPrompt +from vllm.sampling_params import RequestOutputKind +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.exceptions import EngineGenerateError + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +@pytest.mark.asyncio +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +async def test_async_llm_processor_error(model: str) -> None: + """Test that AsyncLLM propagates a processor error. + Test empty tokens prompt (failure) and non-empty prompt (no failure.) + AsyncLLM always uses an MP client. + """ + engine_args = AsyncEngineArgs(model=model, enforce_eager=True) + async_llm = AsyncLLM.from_engine_args(engine_args) + + async def generate(request_id: str): + # [] is not allowed and will raise a ValueError in Processor. + generator = async_llm.generate(TokensPrompt([]), + request_id=request_id, + sampling_params=SamplingParams()) + try: + async for _ in generator: + pass + except Exception as e: + return e + + NUM_REQS = 3 + tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)] + outputs = await asyncio.gather(*tasks) + + # Every request should have get an EngineGenerateError. + for output in outputs: + with pytest.raises(EngineGenerateError): + raise output + + # AsyncLLM should be errored. + assert not async_llm.errored + + # This should be no problem. 
+ EXPECTED_TOKENS = 5 + outputs = [] + async for out in async_llm.generate( + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=EXPECTED_TOKENS, + output_kind=RequestOutputKind.DELTA)): + outputs.append(out) + + generated_tokens = [] + for out in outputs: + generated_tokens.extend(out.outputs[0].token_ids) + assert len(generated_tokens) == EXPECTED_TOKENS + + async_llm.shutdown() diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py new file mode 100644 index 0000000000000000000000000000000000000000..1bba19102ec611f3bbd8d4c392974e1e1d3a5bde --- /dev/null +++ b/tests/v1/shutdown/test_startup_error.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test that we handle a startup Error and shutdown.""" + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC) +from vllm import LLM +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.utils import cuda_device_count_stateless +from vllm.v1.engine.async_llm import AsyncLLM + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +def evil_method(self, *args, **kwargs): + """Evil method that raises an exception.""" + + if get_tensor_model_parallel_rank() == 0: + raise Exception("Simulated Error in startup!") + + return self.model(*args, **kwargs, intermediate_tensors=None) + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) +def test_async_llm_startup_error(monkeypatch, model: str, + tensor_parallel_size: int, + failing_method: str) -> None: + """Test that AsyncLLM propagates an __init__ error & frees memory. 
+ Test profiling (forward()) and load weights failures. + AsyncLLM always uses an MP client. + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + # Monkeypatch an error in the model. + monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) + + engine_args = AsyncEngineArgs(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + # Confirm we get an exception. + with pytest.raises(Exception, match="initialization failed"): + _ = AsyncLLM.from_engine_args(engine_args) + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("enable_multiprocessing", [True]) +@pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) +def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int, + enable_multiprocessing: bool, + failing_method: str) -> None: + """Test that LLM propagates an __init__ error and frees memory. + Test profiling (forward()) and load weights failures. + TODO(andy) - LLM without multiprocessing. + """ + if model != "meta-llama/Llama-3.2-1B": + pytest.skip(reason="Only test meta-llama/Llama-3.2-1B") + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + with monkeypatch.context() as m: + + MP_VALUE = "1" if enable_multiprocessing else "0" + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) + + # Monkeypatch an error in the model. 
+ monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) + + with pytest.raises( + Exception, + match="initialization failed" + if enable_multiprocessing else "Simulated Error in startup!"): + _ = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) diff --git a/tests/v1/shutdown/utils.py b/tests/v1/shutdown/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7c0380d407f12620fd01b066f05e945fce7d0b --- /dev/null +++ b/tests/v1/shutdown/utils.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shutdown test utils""" + +SHUTDOWN_TEST_TIMEOUT_SEC = 120 +SHUTDOWN_TEST_THRESHOLD_BYTES = 2 * 2**30 diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py new file mode 100644 index 0000000000000000000000000000000000000000..f577fb4ab3295991b7d52607ad66074b63f09574 --- /dev/null +++ b/tests/v1/spec_decode/test_max_len.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test whether spec decoding handles the max model length properly.""" + +import pytest + +from vllm import LLM, SamplingParams + +_PROMPTS = [ + "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", + "Repeat the following sentence 10 times: Consistency is key to mastering any skill.", # noqa: E501 + "Who won the Turing Award in 2018, and for what contribution? Describe in detail.", # noqa: E501 +] + + +@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) +def test_ngram_max_len( + monkeypatch: pytest.MonkeyPatch, + num_speculative_tokens: int, +): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM( + model="facebook/opt-125m", + max_model_len=100, + enforce_eager=True, # For faster initialization. 
+ speculative_config={ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": num_speculative_tokens, + }, + ) + sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) + llm.generate(_PROMPTS, sampling_params) + + +@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) +def test_eagle_max_len( + monkeypatch: pytest.MonkeyPatch, + num_speculative_tokens: int, +): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + enforce_eager=True, # For faster initialization. + speculative_config={ + "method": "eagle", + "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", + "num_speculative_tokens": num_speculative_tokens, + }, + max_model_len=100, + ) + sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) + llm.generate(_PROMPTS, sampling_params) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index a81b4897e5d650b47661d8e7fa5398f229221437..50548219fff042b26f79a98df6e602314e69acef 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -2,6 +2,7 @@ import numpy as np +from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig from vllm.v1.spec_decode.ngram_proposer import (NgramProposer, _find_subarray_kmp, _kmp_lps_array) @@ -39,50 +40,50 @@ def test_find_subarray_kmp(): def test_ngram_proposer(): - proposer = NgramProposer() + + def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: + # Dummy model config. Just to set max_model_len. + model_config = ModelConfig(model="facebook/opt-125m", + task="generate", + max_model_len=100, + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + dtype="auto", + seed=None, + trust_remote_code=False) + return NgramProposer( + vllm_config=VllmConfig(model_config=model_config, + speculative_config=SpeculativeConfig. 
+ from_dict({ + "prompt_lookup_min": min_n, + "prompt_lookup_max": max_n, + "num_speculative_tokens": k, + "method": "ngram", + }))) # No match. - result = proposer.propose( - context_token_ids=np.array([1, 2, 3, 4, 5]), - min_n=2, - max_n=2, - k=2, - ) + result = ngram_proposer( + 2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) assert result is None # No match for 4-gram. - result = proposer.propose( - context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), - min_n=4, - max_n=4, - k=2, - ) + result = ngram_proposer( + 4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) assert result is None # No match for 4-gram but match for 3-gram. - result = proposer.propose( - context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), - min_n=3, - max_n=4, - k=2, - ) + result = ngram_proposer( + 3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) assert np.array_equal(result, np.array([4, 1])) # Match for both 4-gram and 3-gram. # In this case, the proposer should return the 4-gram match. - result = proposer.propose( - context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]), - min_n=3, - max_n=4, - k=2, - ) + result = ngram_proposer(3, 4, 2).propose( + context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4])) assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] # Match for 2-gram and 3-gram, but not 4-gram. 
- result = proposer.propose( - context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]), - min_n=2, - max_n=4, - k=2, - ) + result = ngram_proposer( + 2, 4, + 2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4])) assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index 0929f990162897aa3369c95dea59ff5b99df0a85..1cefe8726df73c96c4e73a91801f65c8f5d66d70 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -2,17 +2,13 @@ import pytest -from vllm.v1.structured_output.utils import ( +from vllm.v1.structured_output.backend_xgrammar import ( has_xgrammar_unsupported_json_features) @pytest.fixture def unsupported_string_schemas(): return [ - { - "type": "string", - "pattern": "^[a-zA-Z]+$" - }, { "type": "string", "format": "email" @@ -23,22 +19,6 @@ def unsupported_string_schemas(): @pytest.fixture def unsupported_integer_schemas(): return [ - { - "type": "integer", - "minimum": 0 - }, - { - "type": "integer", - "maximum": 120 - }, - { - "type": "integer", - "exclusiveMinimum": 120 - }, - { - "type": "integer", - "exclusiveMaximum": 120 - }, { "type": "integer", "multipleOf": 120 @@ -49,22 +29,6 @@ def unsupported_integer_schemas(): @pytest.fixture def unsupported_number_schemas(): return [ - { - "type": "number", - "minimum": 0 - }, - { - "type": "number", - "maximum": 120 - }, - { - "type": "number", - "exclusiveMinimum": 120 - }, - { - "type": "number", - "exclusiveMaximum": 120 - }, { "type": "number", "multipleOf": 120 @@ -156,13 +120,28 @@ def supported_schema(): "type": "string", "enum": ["sedan", "suv", "truck"] }, + "car_brand": { + "type": "string", + "pattern": "^[a-zA-Z]+$" + }, "short_description": { "type": "string", "maxLength": 50 }, + "mileage": { + "type": "number", + "minimum": 0, + "maximum": 1000000 + }, + "model_year": { + "type": "integer", + "exclusiveMinimum": 1900, + 
"exclusiveMaximum": 2100 + }, "long_description": { "type": "string", - "minLength": 50 + "minLength": 50, + "maxLength": 2000 }, "address": { "type": "object", diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index f0e031969e733694ba223f9420666052d3baa290..ce4c4d198db580606f1d03e956eafbe87fdc20d3 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind): # the engines only synchronize stopping every N steps so # allow a small amount of time here. for _ in range(10): - if core_client.num_engines_running == 0: + if not core_client.engines_running: break await asyncio.sleep(0.5) - assert core_client.num_engines_running == 0 + assert not core_client.engines_running assert not core_client.reqs_in_flight diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index bc0e0cbd85e1aafdd4907418ea4cf33303f8829a..b55018ae8ef033d1d18a3b4eea8f2db5e76cd8e1 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 from collections import UserDict from dataclasses import dataclass +from typing import Optional +import msgspec import numpy as np import torch +from vllm.multimodal.inputs import (MultiModalBatchedField, + MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, + MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -26,6 +32,7 @@ class MyType: large_f_contig_tensor: torch.Tensor small_non_contig_tensor: torch.Tensor large_non_contig_tensor: torch.Tensor + empty_tensor: torch.Tensor def test_encode_decode(): @@ -41,6 +48,10 @@ def test_encode_decode(): torch.rand((1, 10), dtype=torch.float32), torch.rand((3, 5, 4000), dtype=torch.float64), torch.tensor(1984), # test scalar too + # Make sure to test bf16 which numpy doesn't support. 
+ torch.rand((3, 5, 1000), dtype=torch.bfloat16), + torch.tensor([float("-inf"), float("inf")] * 1024, + dtype=torch.bfloat16), ], numpy_array=np.arange(512), unrecognized=UnrecognizedType(33), @@ -48,9 +59,10 @@ def test_encode_decode(): large_f_contig_tensor=torch.rand(1024, 4).t(), small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], + empty_tensor=torch.empty(0), ) - encoder = MsgpackEncoder() + encoder = MsgpackEncoder(size_threshold=256) decoder = MsgpackDecoder(MyType) encoded = encoder.encode(obj) @@ -58,7 +70,7 @@ def test_encode_decode(): # There should be the main buffer + 4 large tensor buffers # + 1 large numpy array. "large" is <= 512 bytes. # The two small tensors are encoded inline. - assert len(encoded) == 6 + assert len(encoded) == 8 decoded: MyType = decoder.decode(encoded) @@ -70,7 +82,7 @@ def test_encode_decode(): encoded2 = encoder.encode_into(obj, preallocated) - assert len(encoded2) == 6 + assert len(encoded2) == 8 assert encoded2[0] is preallocated decoded2: MyType = decoder.decode(encoded2) @@ -78,6 +90,97 @@ def test_encode_decode(): assert_equal(decoded2, obj) +class MyRequest(msgspec.Struct): + mm: Optional[list[MultiModalKwargs]] + + +def test_multimodal_kwargs(): + d = { + "foo": + torch.zeros(20000, dtype=torch.float16), + "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)], + "baz": [ + torch.rand((256), dtype=torch.float16), + [ + torch.rand((1, 12), dtype=torch.float32), + torch.rand((3, 5, 7), dtype=torch.float64), + ], [torch.rand((4, 4), dtype=torch.float16)] + ], + } + + # pack mm kwargs into a mock request so that it can be decoded properly + req = MyRequest(mm=[MultiModalKwargs(d)]) + + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(MyRequest) + + encoded = encoder.encode(req) + + assert len(encoded) == 6 + + total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) + + # expected total encoding length, should be 44559, +-20 for minor changes 
+ assert total_len >= 44539 and total_len <= 44579 + decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] + assert all(nested_equal(d[k], decoded[k]) for k in d) + + +def test_multimodal_items_by_modality(): + e1 = MultiModalFieldElem("audio", "a0", + torch.zeros(1000, dtype=torch.bfloat16), + MultiModalBatchedField()) + e2 = MultiModalFieldElem( + "video", + "v0", + [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], + MultiModalBatchedField(), + ) + e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000, + dtype=torch.int32), + MultiModalSharedField(4)) + e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000, + dtype=torch.int32), + MultiModalBatchedField()) + audio = MultiModalKwargsItem.from_elems([e1]) + video = MultiModalKwargsItem.from_elems([e2]) + image = MultiModalKwargsItem.from_elems([e3, e4]) + mm = MultiModalKwargs.from_items([audio, video, image]) + + # pack mm kwargs into a mock request so that it can be decoded properly + req = MyRequest([mm]) + + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(MyRequest) + + encoded = encoder.encode(req) + + assert len(encoded) == 8 + + total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) + + # expected total encoding length, should be 14255, +-20 for minor changes + assert total_len >= 14235 and total_len <= 14275 + decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] + + # check all modalities were recovered and do some basic sanity checks + assert len(decoded.modalities) == 3 + images = decoded.get_items("image") + assert len(images) == 1 + assert len(images[0].items()) == 2 + assert list(images[0].keys()) == ["i0", "i1"] + + # check the tensor contents and layout in the main dict + assert all(nested_equal(mm[k], decoded[k]) for k in mm) + + +def nested_equal(a: NestedTensors, b: NestedTensors): + if isinstance(a, torch.Tensor): + return torch.equal(a, b) + else: + return all(nested_equal(x, y) for x, y in zip(a, b)) + + def assert_equal(obj1: MyType, obj2: MyType): assert 
torch.equal(obj1.tensor1, obj2.tensor1) assert obj1.a_string == obj2.a_string @@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType): obj2.small_non_contig_tensor) assert torch.equal(obj1.large_non_contig_tensor, obj2.large_non_contig_tensor) + assert torch.equal(obj1.empty_tensor, obj2.empty_tensor) diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py index 8164952fe3823b14325f02684f048d2717de90b1..a4571a554572cb3714515509f4c4c8ec909d5259 100644 --- a/tests/v1/tpu/test_basic.py +++ b/tests/v1/tpu/test_basic.py @@ -22,6 +22,7 @@ MODELS = [ ] TENSOR_PARALLEL_SIZES = [1] +MAX_NUM_REQS = [16, 1024] # TODO: Enable when CI/CD will have a multi-tpu instance # TENSOR_PARALLEL_SIZES = [1, 4] @@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) +@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS) def test_basic( vllm_runner: type[VllmRunner], monkeypatch: pytest.MonkeyPatch, model: str, max_tokens: int, tensor_parallel_size: int, + max_num_seqs: int, ) -> None: prompt = "The next numbers of the sequence " + ", ".join( str(i) for i in range(1024)) + " are:" @@ -51,9 +54,9 @@ def test_basic( # Note: max_num_batched_tokens == 1024 is needed here to # actually test chunked prompt max_num_batched_tokens=1024, - max_model_len=8196, + max_model_len=8192, gpu_memory_utilization=0.7, - max_num_seqs=16, + max_num_seqs=max_num_seqs, tensor_parallel_size=tensor_parallel_size) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py new file mode 100644 index 0000000000000000000000000000000000000000..eb62e0e4b201a73a7b3e510b944bb79112dc54a6 --- /dev/null +++ b/tests/v1/tpu/test_multimodal.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai +import pytest + +from vllm import envs 
+from vllm.multimodal.utils import encode_image_base64, fetch_image +from vllm.platforms import current_platform + +from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS +from ...utils import RemoteOpenAIServer + +if not envs.VLLM_USE_V1: + pytest.skip( + "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.", + allow_module_level=True, + ) + + +@pytest.fixture(scope="session") +def base64_encoded_image() -> dict[str, str]: + return { + image_url: encode_image_base64(fetch_image(image_url)) + for image_url in TEST_IMAGE_URLS + } + + +@pytest.mark.asyncio +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This test needs a TPU") +@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"]) +async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, + str]): + + def whats_in_this_image_msg(b64): + return [{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{b64}" + }, + }, + ], + }] + + server_args = [ + "--max-model-len", + "1024", + "--max-num-seqs", + "16", + "--gpu-memory-utilization", + "0.95", + "--trust-remote-code", + "--max-num-batched-tokens", + "576", + # NOTE: max-num-batched-tokens>=mm_item_size + "--disable_chunked_mm_input", + "--chat-template", + "examples/template_llava.jinja" + ] + + # Server will pre-compile on first startup (takes a long time). 
+ with RemoteOpenAIServer(model_name, server_args, + max_wait_seconds=600) as remote_server: + client: openai.AsyncOpenAI = remote_server.get_async_client() + + # Other requests now should be much faster + for image_url in TEST_IMAGE_URLS: + image_base64 = base64_encoded_image[image_url] + chat_completion_from_base64 = await client.chat.completions\ + .create( + model=model_name, + messages=whats_in_this_image_msg(image_base64), + max_completion_tokens=24, + temperature=0.0) + result = chat_completion_from_base64 + assert result + choice = result.choices[0] + assert choice.finish_reason == "length" + + message = choice.message + message = result.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py index 0147da53351715c84db39379034aece7e4e4f387..c6b492b5a3cc2997a898067f1c403aaf7117bdf2 100644 --- a/tests/v1/tpu/test_sampler.py +++ b/tests/v1/tpu/test_sampler.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import random + import pytest from vllm import LLM, envs @@ -39,3 +41,23 @@ def test_sampler_different(model_name: str): # Unsupported `seed` param. sampling_params = SamplingParams(temperature=0.3, seed=42) output2 = llm.generate(prompts, sampling_params) + + # Batch-case with TopK/P + for B in [4, 16]: + p = prompts * B + sampling_params = [ + SamplingParams( + temperature=0.1, + min_p=0.8, + max_tokens=64, + # Vary number of ks + top_k=random.randint(4, 12), + top_p=random.random()) for _ in range(B) + ] + # Make sure first two reqs have the same K/P + sampling_params[0] = sampling_params[1] + output = llm.generate(p, sampling_params) + # There are natural numerical instabilities that make it difficult + # to have deterministic results over many tokens, tests the first ~20 + # tokens match. 
+ assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20] diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index dce0303e68d558da1c2de3e23f4f69330c59cead..ff9217f8f3cab55125685b4ccbb1f39643408d1e 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -5,7 +5,8 @@ import pytest import torch from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu +from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, + apply_top_k_top_p_tpu) if not current_platform.is_tpu(): pytest.skip("This test needs a TPU.", allow_module_level=True) @@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024 TOLERANCE = 1e-6 +def test_topk_equivalence_to_native_impl(): + with torch.device(xm.xla_device()): + xm.set_rng_state(seed=33) + + logits = torch.rand((BATCH_SIZE, VOCAB_SIZE)) + + # Random top-k values between 1 and 10. + k = torch.randint(1, 10, (BATCH_SIZE, )) + + # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). 
+ k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), + VOCAB_SIZE) + + result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None) + + result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) + assert torch.allclose(result_native, result_tpu) + + def test_topp_result_sums_past_p(): with torch.device(xm.xla_device()): xm.set_rng_state(seed=33) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 8ea8c890613a34218b959049688d261f9d18574a..319b38b4ca09d43671739e90cbbb2166b0cd8b2e 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - prompt="test", mm_inputs=[], mm_hashes=[], mm_positions=[], @@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner): def test_get_paddings(): + # Bucketed padding min_token_size, max_token_size, padding_gap = 16, 512, 64 expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512] + actual_paddings = _get_token_paddings(min_token_size, max_token_size, + padding_gap) + + # Bucketed padding with max_token_size not a power of two. + max_token_size = 317 + expected_paddings = [16, 32, 64, 128, 192, 256, 320] + actual_paddings = _get_token_paddings(min_token_size, max_token_size, + padding_gap) + assert actual_paddings == expected_paddings + + # Exponential padding. + max_token_size, padding_gap = 1024, 0 + expected_paddings = [16, 32, 64, 128, 256, 512, 1024] + actual_paddings = _get_token_paddings(min_token_size, max_token_size, + padding_gap) + assert actual_paddings == expected_paddings + # Exponential padding with max_token_size not a power of two. 
+ max_token_size = 317 + expected_paddings = [16, 32, 64, 128, 256, 512] actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) assert actual_paddings == expected_paddings diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 2486c26c6071af1db5c8300f6e2681ab779b6915..915ec2914a825b6c459100f0e687b065fcef8b0e 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int): return CachedRequestState( req_id=f"req_id_{req_id_suffix}", prompt_token_ids=prompt_token_ids, - prompt=None, sampling_params=_create_sampling_params(), mm_inputs=[], mm_positions=[], diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index dd95a7f53064ea792c6d565b7c170eb7da5d1892..68e34cfacc5886c578759aaeb993511c143cf3f7 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - prompt="test", mm_inputs=[], mm_hashes=[], mm_positions=[], diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3f54a1a4eca49259e8c39821eb683fc1ad415800..9240b34d5383f99131a1a82fff216f9f79daddf1 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1616,6 +1616,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, ssm_states, pad_slot_id) +# ROCm skinny gemms +def LLMM1(a: torch.Tensor, b: torch.Tensor, + rows_per_block: int) -> torch.Tensor: + return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) + + +def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitK(a, b, cu_count) + + +def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, + cu_count: int) 
-> torch.Tensor: + out = torch.empty((b.shape[0], a.shape[0]), + dtype=out_dtype, + device=b.device) + torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) + return out + + # moe def moe_sum(input: torch.Tensor, output: torch.Tensor): torch.ops._moe_C.moe_sum(input, output) @@ -1665,6 +1685,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies, gating_output) +def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], + b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, moe_block_size: int, + top_k: int, mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, size_n: int, + size_k: int, is_k_full: bool, use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.ops._moe_C.moe_wna16_marlin_gemm( + input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace, + sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights, + moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m, + size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, + is_zp_float) + + if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): @register_fake("_moe_C::marlin_gemm_moe") @@ -1683,6 +1726,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): dtype=a.dtype, device=a.device) + @register_fake("_moe_C::moe_wna16_marlin_gemm") + def moe_wna16_marlin_gemm_fake(input: torch.Tensor, + output: Optional[torch.Tensor], + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: 
torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, top_k: int, + mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, + size_n: int, size_k: int, is_k_full: bool, + use_atomic_add: bool, use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.empty((size_m * top_k, size_n), + dtype=input.dtype, + device=input.device) + def reshape_and_cache( key: torch.Tensor, @@ -1904,3 +1970,12 @@ def flash_mla_with_kvcache( num_splits, ) return out, softmax_lse + + +# def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, +# q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, +# seq_lens: torch.Tensor, page_table: torch.Tensor, +# scale: float) -> torch.Tensor: +# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, +# seq_lens, page_table, scale) +# return out diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 32b0b86ba36f4be4578ed35b7e376bb4d8cf0506..133e18b68e25b9e26160d61a12b4397a11c1ea03 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from functools import lru_cache -from typing import Literal +from typing import Literal, Optional import cv2 import numpy as np @@ -10,8 +10,15 @@ import numpy.typing as npt from huggingface_hub import hf_hub_download from PIL import Image +from vllm.utils import PlaceholderModule + from .base import get_cache_dir +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + @lru_cache def download_video_asset(filename: str) -> str: @@ -85,3 +92,12 @@ class VideoAsset: video_path = download_video_asset(self.name) ret = video_to_ndarrays(video_path, self.num_frames) return ret + + def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: + """ + Read audio data from the video asset, used in Qwen2.5-Omni examples. 
+ + See also: examples/offline_inference/qwen2_5_omni/only_thinker.py + """ + video_path = download_video_asset(self.name) + return librosa.load(video_path, sr=sampling_rate)[0] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 82d60f9da7da64e9a3a1713aebc04bc9c5e6634c..f3d6ffaeb8f45959dd4a3de23d2370f3d2f7c246 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -77,6 +77,10 @@ class AttentionBackend(ABC): ) -> Tuple[int, ...]: raise NotImplementedError + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + raise NotImplementedError + @staticmethod @abstractmethod def swap_blocks( @@ -237,6 +241,7 @@ class AttentionLayer(Protocol): _v_scale: torch.Tensor _k_scale_float: float _v_scale_float: float + _prob_scale: torch.Tensor def forward( self, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 7b2b326684eff6367d60c3b04ab6151821a0f44f..b3687a9012b9ee08d06ecb653f3eb680461f730d 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -22,13 +22,13 @@ from vllm.attention.backends.utils import ( compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) -from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -691,7 +691,7 @@ class FlashAttentionImpl(AttentionImpl): assert output is not None, "Output tensor must be 
provided." # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. - if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16: + if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: assert ( layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( "key/v_scale is only supported in FlashAttention 3 with " diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a5ea0db3be48daba577ba6031581b980757a88d1..e8198683c0a8cdb41ab85a69cdb71620c3d5172d 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import dataclasses +import os from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass @@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, is_block_tables_empty) from vllm.attention.layer import Attention from vllm.attention.ops.paged_attn import PagedAttention -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -48,6 +49,9 @@ if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) +FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT", + "NHD").upper() + class FlashInferBackend(AttentionBackend): @@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend): ) -> Tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + cache_layout = FLASHINFER_KV_CACHE_LAYOUT + assert (cache_layout in ("NHD", "HND")) + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, + 2, 4) + return stride_order + @staticmethod def swap_blocks( 
src_kv_cache: torch.Tensor, @@ -128,12 +140,10 @@ def get_per_layer_parameters( to use during `plan`. """ - layers = vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(vllm_config, Attention) per_layer_params: Dict[str, PerLayerParameters] = {} for key, layer in layers.items(): - assert isinstance(layer, Attention) - impl = layer.impl assert isinstance(impl, FlashInferImpl) @@ -187,7 +197,8 @@ class FlashInferState(AttentionState): # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = get_current_vllm_config() + self.vllm_config = self.runner.vllm_config + self._kv_cache_layout = None def _get_workspace_buffer(self): if self._workspace_buffer is None: @@ -197,10 +208,15 @@ class FlashInferState(AttentionState): device=self.runner.device) return self._workspace_buffer + def get_kv_cache_layout(self): + if self._kv_cache_layout is None: + self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT + return self._kv_cache_layout + def _get_prefill_wrapper(self): if self._prefill_wrapper is None: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), "NHD") + self._get_workspace_buffer(), self.get_kv_cache_layout()) return self._prefill_wrapper def _get_decode_wrapper(self): @@ -213,7 +229,7 @@ class FlashInferState(AttentionState): num_qo_heads // num_kv_heads > 4) self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._get_workspace_buffer(), - "NHD", + self.get_kv_cache_layout(), use_tensor_cores=use_tensor_cores) return self._decode_wrapper @@ -274,7 +290,8 @@ class FlashInferState(AttentionState): self._graph_decode_wrapper = \ CUDAGraphBatchDecodeWithPagedKVCacheWrapper( self._graph_decode_workspace_buffer, _indptr_buffer, - self._graph_indices_buffer, _last_page_len_buffer, "NHD", + self._graph_indices_buffer, _last_page_len_buffer, + self.get_kv_cache_layout(), use_tensor_cores) if 
self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = get_current_vllm_config() + self.vllm_config = self.runner.vllm_config def prepare(self): self.slot_mapping: List[int] = [] @@ -1007,6 +1024,7 @@ class FlashInferImpl(AttentionImpl): prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None + stride_order = FlashInferBackend.get_kv_cache_stride_order() if prefill_meta := attn_metadata.prefill_metadata: # We will use flash attention for prefill # when kv_cache is not provided. @@ -1038,7 +1056,7 @@ class FlashInferImpl(AttentionImpl): prefill_output = prefill_meta.prefill_wrapper.run( query, - kv_cache, + kv_cache.permute(*stride_order), k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, ) @@ -1053,7 +1071,7 @@ class FlashInferImpl(AttentionImpl): decode_output = decode_meta.decode_wrapper.run( decode_query, - kv_cache, + kv_cache.permute(*stride_order), k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, ) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 15625612e08e45c670c567de3c721fbf70fbfce4..55a63a81677fd5725d3484f64b2085c01d9c4ae6 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -4,14 +4,14 @@ # Copyright (C) 2024 Habana Labs, Ltd. 
an Intel Company ############################################################################### -import os from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type import torch +import vllm_hpu_extension.kernels as kernels import vllm_hpu_extension.ops as ops -from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, - VLLMKVCache) +from vllm_hpu_extension.flags import enabled_flags +from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, @@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): self.block2batch_matmul = Matmul() self.k_cache = VLLMKVCache() self.v_cache = VLLMKVCache() - ops.pa_impl = ops.pa + self.fused_scaled_dot_product_attention = kernels.fsdpa() + + self.prefill_impl = 'naive' + if "flex_attention" in enabled_flags(): + self.prefill_impl = 'flex' + if "fsdpa" in enabled_flags(): + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + self.prefill_impl = 'fsdpa' self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window @@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', - '0').lower() in ['1', 'true'] - self.fused_scaled_dot_product_attention = None - if self.prefill_usefusedsdpa: + if self.prefill_impl == 'fsdpa': assert alibi_slopes is None, \ 'Prefill with FusedSDPA not supported with alibi slopes!' - try: - from habana_frameworks.torch.hpex.kernels import FusedSDPA - self.fused_scaled_dot_product_attention = ModuleFusedSDPA( - FusedSDPA) - except ImportError: - logger.warning("Could not import HPU FusedSDPA kernel. 
" - "vLLM will use native implementation.") supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() if head_size not in supported_head_sizes: @@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {supported_head_sizes}.") - if attn_type != AttentionType.DECODER: + self.attn_type = attn_type + if self.attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " "are not implemented for " @@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): batch_size, seq_len, hidden_size = query.shape _, seq_len_kv, _ = key.shape - query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) block_indices = attn_metadata.block_indices block_offsets = attn_metadata.block_offsets - if attn_metadata.is_prompt: + key_cache = None + value_cache = None + if attn_metadata.is_prompt and self.attn_type \ + is not AttentionType.ENCODER_ONLY \ + and attn_metadata.block_list is None: key = key.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1)) - if kv_cache is not None: + if kv_cache is not None and isinstance(kv_cache, tuple): key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): if attn_metadata.is_prompt: # Prompt run. - if not self.prefill_usefusedsdpa: - # TODO: move this outside of model - assert attn_metadata.attn_bias is not None, \ - 'attn_bias must be set before calling model.forward!' 
- attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None: - position_bias = _make_alibi_bias(self.alibi_slopes, - self.num_kv_heads, - attn_bias.dtype, - attn_bias.shape[-1]) - attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) - attn_bias.add_(position_bias) - else: - attn_bias = None - query_shape = (batch_size, seq_len, self.num_heads, self.head_size) kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) + + attn_bias = attn_metadata.attn_bias + if attn_bias is not None and self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) + out = ops.prompt_attention( - query.view(query_shape), - key.view(kv_shape), - value.view(kv_shape), + impl=self.prefill_impl, + query=query.view(query_shape), + key=key.view(kv_shape), + value=value.view(kv_shape), + is_causal=True, attn_bias=attn_bias, - p=0.0, - scale=self.scale, - matmul_qk_op=self.matmul_qk, - softmax_op=self.softmax, - matmul_av_op=self.matmul_av, - fsdpa_op=self.fused_scaled_dot_product_attention, - ) + valid_seq_lengths=attn_metadata.seq_lens_tensor, + **self.common_attention_args()) output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. @@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): block_list=attn_metadata.block_list, block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, - block_scales=attn_metadata.block_scales, block_groups=attn_metadata.block_groups, - scale=self.scale, - matmul_qk_op=self.matmul_qk, - matmul_av_op=self.matmul_av, - batch2block_matmul_op=self.batch2block_matmul, - block2batch_matmul_op=self.block2batch_matmul, - keys_fetch_func=self.k_cache.fetch_from_cache, - values_fetch_func=self.v_cache.fetch_from_cache) + **self.common_attention_args()) # Reshape the output tensor. 
return output.view(batch_size, seq_len, hidden_size) + def common_attention_args(self): + fsdpa_op = self.fused_scaled_dot_product_attention.apply \ + if self.fused_scaled_dot_product_attention is not None else None + return { + 'scale': self.scale, + 'matmul_qk_op': self.matmul_qk, + 'matmul_av_op': self.matmul_av, + 'batch2block_matmul_op': self.batch2block_matmul, + 'block2batch_matmul_op': self.block2batch_matmul, + 'fsdpa_op': fsdpa_op, + 'keys_fetch_func': self.k_cache.fetch_from_cache, + 'values_fetch_func': self.v_cache.fetch_from_cache, + 'softmax_op': self.softmax, + } + def _make_alibi_bias( alibi_slopes: torch.Tensor, diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 99917a92af5f97dd4345623b78dfb1c0f7114098..27959caa651a4a54191e2fba936d623007598308 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): value_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer._k_scale_float, + layer._v_scale_float, ) if attn_metadata.is_prompt: @@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer._k_scale_float, + layer._v_scale_float, ) else: # Run PagedAttention V2. @@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer._k_scale_float, + layer._v_scale_float, ) # Reshape the output tensor. 
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 02eba98c16ebe8180a8d714fa4424ee84aaa526b..ff36ccadece58d6617eae63f9aa63d300102e573 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, UnquantizedLinearMethod) @@ -214,7 +215,6 @@ from vllm.multimodal import MultiModalPlaceholderMap from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down -from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version if HAS_TRITON: from vllm.attention.ops.triton_flash_attention import triton_attention @@ -711,12 +711,24 @@ class MLACommonMetadata(AttentionMetadata): self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) + self._ops_advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions) + + def _ops_advance_step(self, num_seqs: int, num_queries: int, + block_size: int, input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor) -> None: + # here we use advance_step_flashinfo to update the paged_kv_* tensors ops.advance_step_flashattn(num_seqs=num_seqs, num_queries=num_queries, block_size=block_size, - input_tokens=model_input.input_tokens, + input_tokens=input_tokens, sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, + input_positions=input_positions, 
seq_lens=self.seq_lens_tensor, slot_mapping=self.slot_mapping, block_tables=self.block_tables) @@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + BLOCK_TABLE_EXTENDER: list[list[int]] = [] def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.input_builder = input_builder @@ -877,8 +890,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): num_seqs = len(seq_lens) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) + self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * + cuda_graph_pad_size) num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( num_seqs, self.block_tables) else: @@ -1043,8 +1058,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): self.q_proj = q_proj self.kv_b_proj = kv_b_proj self.o_proj = o_proj - self.triton_fa_func = triton_attention + self.triton_fa_func = triton_attention # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. 
The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 @@ -1057,6 +1072,82 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9 + and torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 ) + + def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, + return_softmax_lse, **kwargs): + maybe_padded_v = v + if self._pad_v: + # maybe_padded_v = torch.nn.functional.pad( + # v, [0, q.shape[-1] - v.shape[-1]], value=0) + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]] - 32, value=0) + v_tmp = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2]) + + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ + and not return_softmax_lse: + attn_out = self.triton_fa_func( + q, + k, + maybe_padded_v, + None, # output + kwargs["cu_seqlens_q"], + kwargs["cu_seqlens_k"], + kwargs["max_seqlen_q"], + kwargs["max_seqlen_k"], + kwargs["causal"], + softmax_scale, + None, # bias + ) + if is_vllm_fa: + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + else: + # Use return_attn_probs instead of return_softmax_lse for RoCM + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + # v=maybe_padded_v, + v = v_tmp, + return_attn_probs=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results, + # triton always returns (output, softmax_lse), + # 
vllm_flash_attn returns (output, softmax_lse) when + # `return_softmax_lse = True` + # flash_attn (RoCM) returns (output, softmax_lse, ...) when + # `return_attn_probs = True` + rest = None + if isinstance(attn_out, tuple): + attn_out, *rest = attn_out + + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + assert rest is not None + return attn_out, rest[0] + return attn_out + def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -1181,40 +1272,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - if is_vllm_fa: - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. - context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - else: - attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. 
- context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_attn_probs=True, - ) + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) if output is None: output = attn_output @@ -1257,61 +1327,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - # v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - # value=0) - v_padded = torch.nn.functional.pad(v, [0, (q.shape[-1] - v.shape[-1] -32)], - value=0) - v_tmp = v_padded[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2]) - - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: - output = self.triton_fa_func( - q, - k, - v_padded, - None, - prefill_metadata.query_start_loc, - prefill_metadata.query_start_loc, - prefill_metadata.max_prefill_seq_len, - prefill_metadata.max_prefill_seq_len, - True, # causal - self.scale, - None, # attn_mask is None unless applying ALiBi mask - ) - ## triton flash attention always return 2 objects - if not has_context: - output = output[0] - elif is_vllm_fa: - output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=has_context, - ) - else: - output = self.flash_attn_varlen_func( - q=q, - k=k, - 
v=v_tmp if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 else v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_attn_probs=has_context, - ) + output = self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + ) if has_context: # ROCm flash_attn_varlen_func will return 3 objects instead of 2 - suffix_output, suffix_lse, *rest = output + suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ q, kv_c_and_k_pe_cache, attn_metadata) @@ -1324,14 +1355,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): suffix_lse=suffix_lse, ) - # slice by `:v.shape[-1]` in order to remove v headdim padding - # output = output\ - # .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - # .reshape(-1, self.num_heads * v.shape[-1]) - output = output\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(output)[0] + return self.o_proj(output.flatten(start_dim=-2))[0] @abstractmethod def _forward_decode( diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py new file mode 100644 index 0000000000000000000000000000000000000000..6e695b78e0e1536ff9953620585214c5f7e958df --- /dev/null +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -0,0 +1,412 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Type, Union + +import torch + +import 
vllm._custom_ops as ops +import vllm.envs as envs +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.backends.utils import (compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, + get_aiter_mla_metadata) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA" + + @staticmethod + def get_impl_cls() -> Type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["AiterMLAState"]: + return AiterMLAState + + +@dataclass +class AiterMLAMetadata(MLACommonMetadata): + # The following 4 tensors are for current version of AITER MLA + block_table_bound: Optional[torch.Tensor] = None + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self): + prefill_metadata = super().prefill_metadata + self._cached_prefill_metadata = prefill_metadata + + if prefill_metadata is not None: + prefill_metadata.paged_kv_indptr = self.paged_kv_indptr + prefill_metadata.paged_kv_indices = self.paged_kv_indices + prefill_metadata\ + 
.paged_kv_last_page_lens = self.paged_kv_last_page_lens + prefill_metadata.block_table_bound = self.block_table_bound + + # update the cache + self._cached_prefill_metadata = self.__class__( + **prefill_metadata.__dict__) + + return self._cached_prefill_metadata + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + + self._cached_decode_metadata = decode_metadata + + if decode_metadata is not None: + decode_metadata.paged_kv_indptr = self.paged_kv_indptr + decode_metadata.paged_kv_indices = self.paged_kv_indices + decode_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + decode_metadata.block_table_bound = self.block_table_bound + + # update the cache + self._cached_decode_metadata = self.__class__( + **decode_metadata.__dict__) + + return self._cached_decode_metadata + + def _ops_advance_step(self, num_seqs: int, num_queries: int, + block_size: int, input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor) -> None: + + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_lens=self.paged_kv_last_page_lens, + block_table_bound=self.block_table_bound) + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + super().__init__(input_builder) + assert self.runner.model_config.max_model_len == 32768,\ + "AITER MLA requires max model len to be set to 32768" + assert self.block_size == 1, "AITER MLA requires only block size 1." 
+ + def prepare(self): + super().prepare() + self.paged_kv_indices: list[int] = [] + self.paged_kv_indptr: list[int] = [0] + self.paged_kv_last_page_lens: list[int] = [] + self.total_blocks = 0 + + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block, input_positions) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + inter_data.input_positions): + self.input_positions.extend(input_positions) + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. 
+ is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + if is_profile_run: + return + + # Update paged_kv_* tensors only for non-profile run + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_lens.append(last_page_len) + + def build(self, seq_lens: list[int], query_lens: list[int], + cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: + metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + if use_captured_graph: + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) + + # For current version of AITER MLA + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - 
len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device=device, + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device=device, + dtype=torch.int) + paged_kv_last_page_lens_tensor = torch.tensor( + self.paged_kv_last_page_lens, device=device, dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device=device, + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_lens_tensor = None + block_table_bound_tensor = None + + metadata.paged_kv_indptr = paged_kv_indptr_tensor + metadata.paged_kv_indices = paged_kv_indices_tensor + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor + metadata.block_table_bound = block_table_bound_tensor + + return metadata + + +class AiterMLAState(MLACommonState[AiterMLAMetadata]): + + @contextmanager + def graph_capture(self, max_batch_size: int): + kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata( + max_batch_size=max_batch_size, + block_size=self.runner.block_size, + max_block_per_batch=self.runner.get_max_block_per_batch(), + device=self.runner.device) + self._paged_kv_indices_tensor = kv_indices + self._paged_kv_indptr_tensor = kv_indptr + self._paged_kv_last_page_lens_tensor = last_page_lens + + with super().graph_capture(max_batch_size): + yield + + del self._paged_kv_indices_tensor + del self._paged_kv_indptr_tensor + del self._paged_kv_last_page_lens_tensor + + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> AiterMLAMetadata: + + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + + paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1] + paged_kv_indices = self._paged_kv_indices_tensor + paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: + batch_size] + + metadata.paged_kv_indptr = 
paged_kv_indptr + metadata.paged_kv_indices = paged_kv_indices + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens + + return metadata + + def get_graph_input_buffers(self, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers[ + 'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr + input_buffers[ + "paged_kv_indices"] = attn_metadata.\ + decode_metadata.paged_kv_indices + input_buffers[ + "paged_kv_last_page_lens"] = attn_metadata.\ + decode_metadata.paged_kv_last_page_lens + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[ + 0] + input_buffers["paged_kv_indptr"].copy_( + attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True) + input_buffers["paged_kv_indices"][:num_total_blocks].copy_( + attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True) + input_buffers["paged_kv_last_page_lens"].copy_( + attn_metadata.decode_metadata.paged_kv_last_page_lens, + non_blocking=True) + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, 
blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _flash_attn_varlen_diff_headdims( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + softmax_scale: float, return_softmax_lse: bool, + **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + return_lse=return_softmax_lse, + **kwargs, + ) + + return output + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_lens) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 600f4471bd3d5ee02f0b4d95b3761b68180e8ad6..93eb2bd30813f00927223d3254a5b4f08a335afa 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -2,6 +2,7 @@ """Attention layer ROCm GPUs.""" import itertools from dataclasses import dataclass +from functools import cache from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch @@ -29,7 +30,34 @@ logger = init_logger(__name__) _PARTITION_SIZE_ROCM = 256 
+@cache +def is_rocm_aiter_paged_attn_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \ + and envs.VLLM_ROCM_USE_AITER \ + + +@cache +def _get_paged_attn_module() -> PagedAttention: + """ + Initializes the appropriate PagedAttention module from `attention/ops`, + which is used as helper function + by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. + + The choice of attention module depends on whether + AITER paged attention is enabled: + - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. + - Otherwise, it defaults to using the original `PagedAttention`. + """ + if is_rocm_aiter_paged_attn_enabled(): + # Import AITERPagedAttention only when the flag is enabled + from vllm.attention.ops.rocm_aiter_paged_attn import ( + AITERPagedAttention) + return AITERPagedAttention() + return PagedAttention() + + class ROCmFlashAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True @staticmethod def get_name() -> str: @@ -58,8 +86,9 @@ class ROCmFlashAttentionBackend(AttentionBackend): num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + paged_attn = _get_paged_attn_module() + return paged_attn.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -67,14 +96,16 @@ class ROCmFlashAttentionBackend(AttentionBackend): dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + paged_attn = _get_paged_attn_module() + paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + paged_attn = _get_paged_attn_module() + paged_attn.copy_blocks(kv_caches, src_to_dists) @dataclass @@ -504,7 +535,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): assert 
self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - supported_head_sizes = PagedAttention.get_supported_head_sizes() + self.paged_attn_module = _get_paged_attn_module() + supported_head_sizes = self.paged_attn_module.get_supported_head_sizes( + ) + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " @@ -524,12 +558,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) - self.attn_func = triton_attention - - # from vllm.attention.ops.flash_attn_triton_mqa_gqa import ( - # flash_attn_varlen_func) - # self.attn_func = flash_attn_varlen_func - + self.triton_attn_func = triton_attention logger.debug("Using Triton FA in ROCmBackend") if self.sliding_window != (-1, -1): logger.warning("ROCm Triton FA does not currently support " @@ -545,7 +574,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): else: try: from flash_attn import flash_attn_varlen_func # noqa: F401 - self.attn_func = flash_attn_varlen_func + self.fa_attn_func = flash_attn_varlen_func logger.debug("Using CUTLASS FA in ROCmBackend") except ModuleNotFoundError: self.use_naive_attn = True @@ -556,9 +585,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): "ROCm Naive FlashAttention does not support " "attention logits soft capping.") - self.attn_func = _sdpa_attention + self.sdpa_attn_func = _sdpa_attention logger.debug("Using naive (SDPA) attention in ROCmBackend") + self.aiter_kv_scales_initialized = False + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" tokens, n_kv_heads, head_dim = x.shape @@ -627,6 +658,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ + assert output is not None, "Output tensor must be provided." 
+ query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None @@ -635,12 +668,37 @@ class ROCmFlashAttentionImpl(AttentionImpl): else: assert value is None + paged_attn = self.paged_attn_module + + # Reshaping kv tensors is required for AITER paged attention kernel + # because it works on a different tensor shape, + # when the size of one element is one byte (int8/fp8 dtypes). + # This reshaping is only required on the first forward call + # and the kv cache must not be empty. + if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1 + and not self.aiter_kv_scales_initialized + and kv_cache.shape != torch.Size([0])): + num_blocks = kv_cache.shape[1] + block_size = kv_cache.shape[2] // (self.num_kv_heads * + self.head_size) + k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + self.aiter_kv_scales_initialized = True + k_scale.fill_(layer._k_scale.item()) + v_scale.fill_(layer._v_scale.item()) + layer._k_scale = k_scale + layer._v_scale = v_scale + # Only update KV cache for decoder self-attention # and encoder-decoder cross-attention if self.attn_type not in [ AttentionType.ENCODER, AttentionType.ENCODER_ONLY ] and kv_cache.numel() > 0: - key_cache, value_cache = PagedAttention.split_kv_cache( + key_cache, value_cache = paged_attn.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) if key is not None and value is not None: @@ -648,7 +706,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): # cache. If kv_cache is not provided, the new key and value # tensors are not cached. This happens during the initial # memory profiling run. 
- PagedAttention.write_to_paged_cache( + paged_attn.write_to_paged_cache( key, value, key_cache, @@ -670,7 +728,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens - output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. @@ -718,11 +775,17 @@ class ROCmFlashAttentionImpl(AttentionImpl): query.dtype, seq_lens, make_attn_mask=causal_mask) # type: ignore - out, _ = self.attn_func( + use_fp8_scales = (layer._q_scale and layer._k_scale + and layer._v_scale and layer._prob_scale + and self.kv_cache_dtype == "fp8") + full_scales = ( + layer._q_scale, layer._k_scale, layer._v_scale, + layer._prob_scale) if use_fp8_scales else None + self.triton_attn_func( query, key, value, - None, + output[:num_prefill_tokens], query_seq_start_loc, key_seq_start_loc, query_max_seq_len, @@ -731,6 +794,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): self.scale, attn_masks[0][None] if attn_masks is not None else None, + full_scales, ) elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: @@ -747,10 +811,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) # sdpa math backend attention - out = self.attn_func( + self.sdpa_attn_func( query, key, value, + output[:num_prefill_tokens], query_seq_start_loc, num_prefill_tokens, self.num_heads, @@ -759,7 +824,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): attn_masks, ) else: - out = self.attn_func( + # upstream FA does not support an output arg, copy + output[:num_prefill_tokens] = self.fa_attn_func( q=query, k=key, v=value, @@ -774,34 +840,27 @@ class ROCmFlashAttentionImpl(AttentionImpl): softcap=self.logits_soft_cap, ) - # common code for prefill - assert output[:num_prefill_tokens].shape == out.shape - if output.shape[0] > num_prefill_tokens: - 
output[:num_prefill_tokens] = out - else: - output = out else: # prefix-enabled attention - # not applicable for encoder-only models version_key = triton_key() if self.attn_type != AttentionType.ENCODER_ONLY: - output[: - num_prefill_tokens] = PagedAttention.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], - layer._k_scale, - layer._v_scale, - ) + output[:num_prefill_tokens] = paged_attn.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], + layer._k_scale, + layer._v_scale, + ) # Skip decode phase for encoder-only models if (decode_meta := attn_metadata.decode_metadata) and ( self.attn_type != AttentionType.ENCODER_ONLY): @@ -834,14 +893,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): device=output.device, ) max_logits = torch.empty_like(exp_sums) - if num_prefill_tokens > 0: - out = output[num_prefill_tokens:] - else: - out = output query_start_loc = None ops.paged_attention_rocm( - out, + output[num_prefill_tokens:], exp_sums, max_logits, tmp_output, @@ -866,7 +921,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ) else: tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor - output[num_prefill_tokens:] = PagedAttention.forward_decode( + output[num_prefill_tokens:] = paged_attn.forward_decode( decode_query, key_cache, value_cache, @@ -897,7 +952,8 @@ def _sdpa_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - seq_lens: List[int], + output: torch.Tensor, + seq_lens: torch.Tensor, num_tokens: int, num_heads: int, head_size: int, @@ -905,9 +961,9 @@ def _sdpa_attention( attn_masks: 
Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: start = 0 - output = torch.empty((num_tokens, num_heads, head_size), - dtype=query.dtype, - device=query.device) + assert output.shape == (num_tokens, num_heads, head_size) + assert output.dtype == query.dtype + assert output.device == query.device for i, seq_len in enumerate(seq_lens): end = start + seq_len diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index a983b20e001725d5b329ca74d63544067f3b76e7..81bbb22582694c5797aecfc6b6e3278e7c5a0bc3 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -2,8 +2,10 @@ """Attention backend utils""" from collections import defaultdict from contextlib import contextmanager +from dataclasses import dataclass from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -11,6 +13,7 @@ import torch from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) from vllm.attention.backends.abstract import AttentionType +from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -584,3 +587,24 @@ def get_num_prefill_decode_query_kv_tokens( return (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) + + +@dataclass +class MLADims: + q_lora_rank: Optional[int] + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + + +def get_mla_dims(model_config: ModelConfig) -> MLADims: + hf_text_config = model_config.hf_text_config + + return MLADims( + q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), + kv_lora_rank=hf_text_config.kv_lora_rank, + qk_nope_head_dim=hf_text_config.qk_nope_head_dim, + 
qk_rope_head_dim=hf_text_config.qk_rope_head_dim, + v_head_dim=hf_text_config.v_head_dim, + ) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index dbf4723ee1bd724d8fe5ba1cfbf86e9f25888f13..aa218cc37af96d36f4c8235dd679599aa1313a7f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -10,6 +10,9 @@ import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( @@ -87,6 +90,7 @@ class Attention(nn.Module): # FlashAttn doesn't support quantizing the kv-cache only # but requires q to be quantized as well. self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) # We also keep the float32 versions of k/v_scale for attention # backends that don't support tensors (Flashinfer) @@ -329,17 +333,54 @@ class MultiHeadAttention(nn.Module): return out.reshape(bsz, q_len, -1) +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + + connector.wait_for_layer_load(layer_name) + + +def maybe_save_kv_layer_to_connector( + layer_name: str, + kv_cache_layer: List[torch.Tensor], +): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = 
forward_context.attn_metadata + if attn_metadata is None: + return + + connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) + + def unified_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, layer_name: str, ) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output def unified_attention_fake( @@ -367,6 +408,7 @@ def unified_attention_with_output( output: torch.Tensor, layer_name: str, ) -> None: + wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] @@ -379,6 +421,8 @@ def unified_attention_with_output( attn_metadata, output=output) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + def unified_attention_with_output_fake( query: torch.Tensor, diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 49ea420d092cc48a6c77ba320569371c86f9c23a..1dedd2ffc5fa26ff1eb1f374b24d4c36c653b525 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata: block_usage: Optional[torch.Tensor] block_indices: Optional[torch.Tensor] block_offsets: Optional[torch.Tensor] - block_scales: Optional[torch.Tensor] block_groups: Optional[torch.Tensor] diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 36a546841b8fefbce941efb62ba4f515f4d1b718..92833f9cb38b43495e9f501fd469a8943289845b 100644 --- 
a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -20,831 +20,778 @@ NUM_WARPS = 8 # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) -if triton.__version__ >= "2.1.0": - - @triton.jit - def _fwd_kernel( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - SKIP_DECODE: tl.constexpr, - ): - - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - # start position inside of the query - # generally, N goes over kv, while M goes over query_len - block_start_loc = BLOCK_M * start_m - - # initialize offsets - # [N]; starts at 0 - offs_n = tl.arange(0, BLOCK_N) - # [D]; starts at 0 - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - # [M]; starts at current position in query - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # 
[M,D] - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len), - other=0.0) # [M,D] - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], - dtype=tl.float32) # [M,D] - - # compute query against context (no causal mask here) - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) # [N] - # [D,N] - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - # [N,D] - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - if SLIDING_WINDOW > 0: - # (cur_batch_ctx_len + offs_m[:, None]) are the positions of - # Q entries in sequence - # (start_n + 
offs_n[None, :]) are the positions of - # KV entries in sequence - # So the condition makes sure each entry in Q only attends - # to KV entries not more than SLIDING_WINDOW away. - # - # We can't use -inf here, because the - # sliding window may lead to the entire row being masked. - # This then makes m_ij contain -inf, which causes NaNs in - # exp(). - qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, - -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) # [M] - p = tl.exp(qk - m_ij[:, None]) # [M,N] - l_ij = tl.sum(p, 1) # [M] - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) # [M] - alpha = tl.exp(m_i - m_i_new) # [M] - beta = tl.exp(m_ij - m_i_new) # [M] - l_i_new = alpha * l_i + beta * l_ij # [M] - - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - # block_mask is 0 when we're already past the current query length - block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) - - # compute query against itself (with causal mask) - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - 
(cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_query_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk *= sm_scale - # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - if SLIDING_WINDOW > 0: - qk = tl.where( - offs_m[:, None] - (start_n + offs_n[None, :]) - < SLIDING_WINDOW, qk, -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_query_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len)) + +# Here's an example autotuner config for this kernel. This config does provide +# a performance improvement, but dramatically increases first call latency in +# triton 3.2. Because of this tradeoff, it's currently commented out. 
+# @triton.autotune( +# configs=[ +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ +# "num_unroll_cache": 4, \ +# "num_unroll_request": 1 } | \ +# ({"kpack": 2, "waves_per_eu": 2} \ +# if current_platform.is_rocm() else {}), \ +# num_warps=4, \ +# num_stages=1) +# ], +# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] +# ) +@triton.jit +def _fwd_kernel(Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0): + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: return - @triton.jit - def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - 
stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load(Q + off_q, - mask=offs_m[:, None] - < cur_batch_seq_len - cur_batch_ctx_len, + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [BLOCK_SIZE]; starts at 0 + offs_bs_n = tl.arange(0, BLOCK_SIZE) + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < 
cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ + loop_unroll_factor=num_unroll_cache): + start_n = tl.multiple_of(start_n, BLOCK_SIZE) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s) + # [D,BLOCK_SIZE] + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + + # [BLOCK_SIZE,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl) + + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + k_load = tl.load( + K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + else: + k_load = tl.load(K_cache + off_k) + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_bs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. 
+ # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). + qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + v_load = tl.load( + V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + else: + v_load = tl.load(V_cache + off_v) + + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in tl.range(0, \ + block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ + loop_unroll_factor=num_unroll_request): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = 
tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - 
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return - - @triton.jit - def _fwd_kernel_alibi( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - Alibi_slopes, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - 
stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SKIP_DECODE: tl.constexpr, - ): - # attn_bias[] - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - # cur_batch_seq_len: the length of prompts - # cur_batch_ctx_len: the length of prefix - # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) - - q = tl.load(Q + off_q, + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, + qk, -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i 
- m_ij) + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) + return + + +@triton.jit +def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + 
cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * 
stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) - - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = 0 - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - 
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - # init alibi - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = cur_batch_ctx_len - # # init debugger - # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc - # offset_db_k = tl.arange(0, BLOCK_N) - # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = 
tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision='ieee') - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute 
m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@triton.jit +def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, +): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length 
of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: return - @torch.inference_mode() - def context_attention_fwd(q, - k, - v, - o, - kv_cache_dtype: str, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - max_seq_len, - max_input_len, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - alibi_slopes=None, - sliding_window=None, - sm_scale=None, - skip_decode=False): - - q_dtype_is_f32 = q.dtype is torch.float32 + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * 
stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + 
m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + 
start_n) * stride_vbs, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + +@torch.inference_mode() +def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + skip_decode=False): + + q_dtype_is_f32 = q.dtype is torch.float32 + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. 
There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + assert batch + 1 == len(b_start_loc) + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: # need to reduce num. blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK - - # Turing does have tensor core for float32 multiplication - # use ieee as fallback for triton kernels work. 
There is also - # warning on vllm/config.py to inform users this fallback - # implementation - IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None - - # Conversion of FP8 Tensor from uint8 storage to - # appropriate torch.dtype for interpretation by Triton - if "fp8" in kv_cache_dtype: - assert (k_cache.dtype == torch.uint8) - assert (v_cache.dtype == torch.uint8) - - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = current_platform.fp8_dtype() - elif kv_cache_dtype == "fp8_e5m2": - target_dtype = torch.float8_e5m2 - else: - raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) - - k_cache = k_cache.view(target_dtype) - v_cache = v_cache.view(target_dtype) - - if (k_cache.dtype == torch.uint8 - or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): - raise ValueError("kv_cache_dtype='auto' unsupported for\ - FP8 KV Cache prefill kernel") - - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - # round up Lk to a power of 2 - this is required for Triton block size - Lk_padded = triton.next_power_of_2(Lk) - - if sm_scale is None: - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - num_queries_per_kv = q.shape[1] // k.shape[1] - - assert batch + 1 == len(b_start_loc) - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - # 0 means "disable" - if sliding_window is None or sliding_window <= 0: - sliding_window = 0 - - if alibi_slopes is not None: - _fwd_kernel_alibi[grid]( - q, - k, - v, - k_cache, - v_cache, - b_loc, - sm_scale, - k_scale, - v_scale, - b_start_loc, - b_seq_len, - alibi_slopes, - v_cache.shape[3], - k_cache.shape[4], - o, - b_loc.stride(0), - b_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - 
k_cache.stride( - 4 - ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride( - 3), #[num_blocks, num_kv_heads, head_size, block_size] - num_queries_per_kv=num_queries_per_kv, - IN_PRECISION=IN_PRECISION, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_DMODEL_PADDED=Lk_padded, - BLOCK_N=BLOCK, - SKIP_DECODE=skip_decode, - num_warps=NUM_WARPS, - num_stages=1, - ) - return - - _fwd_kernel[grid]( + # batch, head, + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + _fwd_kernel_alibi[grid]( q, k, v, @@ -856,6 +803,7 @@ if triton.__version__ >= "2.1.0": v_scale, b_start_loc, b_seq_len, + alibi_slopes, v_cache.shape[3], k_cache.shape[4], o, @@ -890,9 +838,69 @@ if triton.__version__ >= "2.1.0": BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, - SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, num_warps=NUM_WARPS, num_stages=1, ) return + + max_seq_len = 0 if max_seq_len is None else max_seq_len + extra_kargs = {} + if current_platform.is_rocm(): + extra_kargs = {"kpack": 2, "waves_per_eu": 2} + + grid = lambda META: (batch, head, + triton.cdiv(max_input_len, META["BLOCK_M"])) + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_SIZE=v_cache.shape[3], + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_DMODEL=Lk, + 
BLOCK_DMODEL_PADDED=Lk_padded, + SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, + BLOCK_M=128, + BLOCK_N=64, + num_unroll_cache=4, + num_unroll_request=1, + num_warps=4, + num_stages=1, + **extra_kargs) + return diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py new file mode 100644 index 0000000000000000000000000000000000000000..1c90f8c19b09c64b186d0efd8b409f0f212dbff7 --- /dev/null +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch + + +def get_aiter_mla_metadata(max_batch_size: int, block_size: int, + max_block_per_batch: int, + device: torch.device) -> tuple[torch.Tensor, ...]: + paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch, + dtype=torch.int32, + device=device) + paged_kv_indptr = torch.zeros(max_batch_size + 1, + dtype=torch.int32, + device=device) + paged_kv_last_page_lens = torch.full((max_batch_size, ), + block_size, + dtype=torch.int32) + return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens + + +def aiter_mla_decode_fwd( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + logit_cap: float = 0.0, +): + from aiter.mla import mla_decode_fwd + + mla_decode_fwd(q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap) diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..0f3cf1842c8052a9fa3646c36b34809f355c9964 --- /dev/null +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import aiter as rocm_aiter +import torch + +from 
vllm.attention.ops.paged_attn import PagedAttention +from vllm.platforms import current_platform +from vllm.utils import cdiv + +FP8_DTYPE = current_platform.fp8_dtype() + + +class AITERPagedAttention(PagedAttention): + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> None: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, + v_scale) + else: + kv_cache_torch_dtype = (FP8_DTYPE + if "fp8" in kv_cache_dtype else torch.int8) + key_cache = key_cache.view(kv_cache_torch_dtype) + value_cache = value_cache.view(kv_cache_torch_dtype) + + rocm_aiter.reshape_and_cache_with_pertoken_quant( + key, value, key_cache, value_cache, k_scale, v_scale, + slot_mapping.flatten(), True) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + return PagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + kv_cache_dtype=kv_cache_dtype, + num_kv_heads=num_kv_heads, + scale=scale, + alibi_slopes=alibi_slopes, + k_scale=k_scale, + v_scale=v_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + 
blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step) + + if "fp8" in kv_cache_dtype: + key_cache = key_cache.view(torch.float8_e4m3fnuz) + value_cache = value_cache.view(torch.float8_e4m3fnuz) + + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + max_num_blocks_per_seq = cdiv(max_seq_len, block_size) + + rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, + seq_lens, max_num_blocks_per_seq, k_scale, + v_scale, output) + return output diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py index 690e836438edea9b2ecca6339276d1e2f5fee22b..db6f4b73afeb9ff15fee152752bb7c0b4756289e 100644 --- a/vllm/attention/ops/triton_decode_attention.py +++ b/vllm/attention/ops/triton_decode_attention.py @@ -43,11 +43,12 @@ os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0" logger = logging.getLogger(__name__) -# TODO: Remove this when triton>=3.2.0. This issue will not affect performance -# and accuracy. -logger.warning( - "The following error message 'operation scheduled before its operands' " - "can be ignored.") +# Only print the following warnings when triton version < 3.2.0. +# The issue won't affect performance or accuracy. 
+if triton.__version__ < '3.2.0': + logger.warning( + "The following error message 'operation scheduled before its operands' " + "can be ignored.") @triton.jit diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 6295aee243b73e99dc4f0767014fe79f9cd1e0a7..e98b5254541b60d7f715d146e1186af5832ada8b 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -1,31 +1,237 @@ -#!/usr/bin/env python # SPDX-License-Identifier: Apache-2.0 """ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao -(https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf -Features supported: +Credits: +AMD Triton kernels team +OpenAI kernel team -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. 
- -Not currently supported: +Currently only the forward kernel is supported, and contains these features: -1) Non power of two head dims +1) Fwd with causal masking +2) Arbitrary Q and KV sequence lengths +3) Arbitrary head sizes +4) Multi and grouped query attention +5) Variable sequence lengths +6) ALiBi and matrix bias +7) FP8 support """ +from typing import Optional + import torch import triton import triton.language as tl -torch_dtype: tl.constexpr = torch.float16 +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] + +default_eight_bit_dtype_triton = tl.float8e4b8 +default_eight_bit_dtype_torch = current_platform.fp8_dtype() +default_float8_info = torch.finfo(default_eight_bit_dtype_torch) + +FP8_MIN = triton.language.constexpr(default_float8_info.min) + +# According to https://github.com/vllm-project/vllm/blob/main +# /csrc/quantization/utils.cuh#L31, +# need to make the max for the uz datatype be 224.0 for accuracy reasons. +FP8_MAX = triton.language.constexpr( + default_float8_info.max if default_eight_bit_dtype_torch != + torch.float8_e4m3fnuz else 224.0) + + +class MetaData: + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + num_contexts = 0 + varlen = False + eight_bit = False + layout = None + return_encoded_softmax = False + eight_bit_dtype_triton = default_eight_bit_dtype_triton + eight_bit_dtype_torch = default_eight_bit_dtype_torch + output_dtype = None + + # Note about layouts: + # + # thd - [num_tokens, num_heads, head_size] + # bshd - [batch_size, seq_len, num_heads, head_size] + # bhsd - [batch_size, num_heads, seq_len, head_size] + # + # This is for each tensor, all tensors must have same layout. + # Q can have num_heads and seq_len differ from from K and V, + # however K and V must agree on this. 
+ # + # Notes about varlen and bias: + # Only one or the other is implemented, meaning can't combine + # both varlen and bias right now. + # + # Note about quantization: + # Only 8-bit quantization supported (for now) and specifically fp8. + # Scales must be tensors. + # o_scale: This is 'output scaling', but comes from parameter called + # 'input_scale', this is applied to the output from the kernel. + # o_scale should be None if none of the other quantization parameters + # are used. + # + # NOTE: Object is in a tentatively good state after initialized, however, + # to verify, call check_args(q,k,v,o) where o is the output tensor. + def __init__( + self, + sm_scale=1.0, + layout=None, # layout can be 'bshd', 'bhsd', or 'thd' + output_dtype=None, + max_seqlens_q=0, + max_seqlens_k=0, + # varlen params + cu_seqlens_q=None, # only 'thd' layout supported for varlen + cu_seqlens_k=None, + # quant params + q_descale=None, + k_descale=None, + v_descale=None, + p_scale=None, + o_scale=None, + # bias params + bias=None, # varlen not implemented for bias + seqlen_q=None, + seqlen_k=None, + # alibi params + alibi_slopes=None, + alibi_batch=None, + alibi_nheads=None, + # causal + causal=None, + ): + self.sm_scale = sm_scale + self.output_dtype = output_dtype + self.max_seqlens_q = max_seqlens_q + self.max_seqlens_k = max_seqlens_k + self.layout = layout + if cu_seqlens_q is not None or cu_seqlens_k is not None: + assert cu_seqlens_q is not None and cu_seqlens_k is not None + assert layout is None or layout not in [ + 'bshd', 'bhsd' + ], "Varlen only implemented for thd layout" + self.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale] + if any(x is not None for x in quant_params): + p_descale = 1.0 / p_scale if p_scale is not None else None + self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale, + p_descale, o_scale) + if bias is not None: + self.need_bias(bias, seqlen_q, seqlen_k) + if alibi_slopes 
is not None: + self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads) + if causal is not None and causal: + self.need_causal() + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.layout = 'thd' + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. + assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max( + cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), + self.max_seqlens_q) + self.max_seqlens_k = max( + cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), + self.max_seqlens_k) + + def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale, + p_descale, o_scale): + self.eight_bit = True + self.q_descale = q_descale + self.k_descale = k_descale + self.v_descale = v_descale + self.p_scale = p_scale + self.p_descale = p_descale + self.o_scale = o_scale + self.use_p_scale = (p_scale is not None) and ( + p_descale is not None) and (v_descale is not None) + self.eight_bit_kv = ((q_descale is None) and (k_descale is not None) + and (v_descale is not None)) + self.eight_bit_dtype_torch = default_eight_bit_dtype_torch + + def need_bias(self, bias, seqlen_q, seqlen_k): + assert bias is not None + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout( + q, k, self) + if self.varlen: + assert q.dim() == 3 
+ assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + if self.eight_bit: + if self.eight_bit_kv: + assert (v.dtype == k.dtype + and k.dtype == self.eight_bit_dtype_torch) + assert q.dtype != k.dtype + assert (self.v_descale is not None) and (self.k_descale + is not None) + else: + assert (q.dtype == k.dtype and q.dtype == v.dtype + and q.dtype == self.eight_bit_dtype_torch) + assert (self.q_descale + is not None) and (self.k_descale + is not None) and (self.v_descale + is not None) + if self.use_p_scale: + assert (self.p_scale is not None) and (self.p_descale + is not None) + else: + assert (q.dtype == k.dtype) and (q.dtype == v.dtype) + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen @triton.jit @@ -38,40 +244,85 @@ def max_fn(x, y): return tl.math.max(x, y) +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. 
@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] +def masked_load(ptrs, offset_first, offset_second, boundary_first, + boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor @triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, - stride).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) +def compute_alibi_block(alibi_slope, + seqlen_q, + seqlen_k, + offs_m, + offs_n, + transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to + # the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is + # masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that + # spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, + # offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. 
offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = + # [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]] + # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q - + offs_n[None, :]) + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, - stride) - rng_keep = rng_output > dropout_p - return rng_keep +def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, + device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, + device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - + k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze( + -1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) @triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) - else: - tensor = tl.load(block_ptr) - return tensor +def quant_fp8(x, scale): + x *= scale + x = tl.clamp(x, FP8_MIN, FP8_MAX) + return x @triton.jit @@ -80,58 +331,68 @@ def _attn_fwd_inner( l_i, m_i, q, - K_block_ptr, - V_block_ptr, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, start_m, actual_seqlen_k, - dropout_p, + actual_seqlen_q, philox_seed, batch_philox_offset, - encoded_softmax_block_ptr, + encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, - bias_ptr, + alibi_slope, + q_descale, + 
k_descale, + v_descale, + p_scale, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, + SHOULD_PRE_LOAD_V: tl.constexpr, + SHOULD_MASK_STEPS: tl.constexpr, + SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_PADDED_HEAD: tl.constexpr, + IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, + QK_SCALE: tl.constexpr, + IS_EIGHT_BIT_GEMM: tl.constexpr, + USE_P_SCALE: tl.constexpr, + IS_EIGHT_BIT_KV: tl.constexpr, + QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, ): + # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) - if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + k_offs_n = start_n + tl.arange(0, + BLOCK_N) if SHOULD_MASK_STEPS else None + k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL, + actual_seqlen_k) + if SHOULD_PRE_LOAD_V: + # We can use the same offsets as k, just with dims transposed. + v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + IS_ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. 
- if MASK_STEPS: # noqa: SIM102 + if SHOULD_MASK_STEPS: # noqa: SIM102 # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not + # is_modulo_mn. last step might get wasted but that is okay. # check if this masking works for that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): boundary_m = tl.full([BLOCK_M], @@ -144,167 +405,276 @@ def _attn_fwd_inner( causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- - qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS - and (n_extra_tokens != 0), "zero") - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 + if IS_EIGHT_BIT_GEMM: + qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * + QK_SCALE) + else: + if IS_EIGHT_BIT_KV: + k = (k * k_descale).to(q.type.element_ty) + qk += (tl.dot(q, k) * QK_SCALE) + + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange( + 0, BLOCK_N) if SHOULD_MASK_STEPS else None + bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, + actual_seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an + # additional scale factor of log2(e) which we must also multiply + # the bias with. 
+ qk += (bias * 1.44269504089) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, + actual_seqlen_k, + global_m_positions, + global_n_positions) + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + + # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - philox_offset = (batch_philox_offset + - start_m * BLOCK_M * actual_seqlen_k + start_n - - BLOCK_N) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) - if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, - -p).to(encoded_softmax_block_ptr.type.element_ty), - ) - p = tl.where(keep, p, 0.0) - elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) + if SHOULD_RETURN_ENCODED_SOFTMAX: + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] - if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + if not SHOULD_PRE_LOAD_V: + v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + IS_ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - 
(0, BLOCK_N)) + + if IS_EIGHT_BIT_GEMM: + if USE_P_SCALE: + p = quant_fp8(p, p_scale).to(QUANT_DTYPE) + acc += tl.dot(p, v) + else: + # v is in eight_bit but p is not, we want the gemm in p's type + acc += tl.dot(p, v.to(p.type.element_ty)) + else: + if IS_EIGHT_BIT_KV: + v = (v * v_descale).to(p.type.element_ty) + acc += tl.dot(p.to(v.type.element_ty), v) + + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn + if SHOULD_RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i -@triton.autotune( - configs=[ +def get_cdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 0, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 1, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=8, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 0, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), + num_warps=4), + ], [ + 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', + 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' + ] + + +def get_rdna_autotune_configs(): + return [ triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 0, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'SHOULD_PRE_LOAD_V': 
False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=8, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 0, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 0, - "PRE_LOAD_V": True, + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 0, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 0, - "PRE_LOAD_V": False, + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=8, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 0, - "PRE_LOAD_V": False, + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=8, - ), - # TODO: This config fails with head_size not pow2 with data mismatches. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, - # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + num_warps=2), + # Fall-back config. 
triton.Config( { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 0, - "PRE_LOAD_V": False, + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 1, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), - ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + num_warps=2), + ], [ + 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', + 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' + ] + + +def get_general_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 32, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + ], [ + 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', + 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' + ] + + +def has_cdna_target(): + ROCM_CDNA_TARGETS = ["gfx940", "gfx941", "gfx942", "gfx90a", "gfx908"] + return triton.runtime.driver.active.get_current_target( + ).arch in ROCM_CDNA_TARGETS + + +def is_rocm_cdna(): + return current_platform.is_rocm() and has_cdna_target() + + +def get_autotune_configs(): + if is_rocm_cdna(): + return get_cdna_autotune_configs() + elif current_platform.is_rocm(): + return get_rdna_autotune_configs() + else: + return get_general_autotune_configs() + + +autotune_configs, autotune_keys = get_autotune_configs() + + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, + use_cuda_graph=True, ) @triton.jit def attn_fwd( @@ -312,38 +682,53 @@ def attn_fwd( K, V, bias, - sm_scale, + SM_SCALE: tl.constexpr, L, Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, 
- stride_bh, - stride_bm, - stride_bn, + stride_qz: tl.int64, + stride_qh: tl.int64, + stride_qm: tl.int64, + stride_qk: tl.int64, + stride_kz: tl.int64, + stride_kh: tl.int64, + stride_kn: tl.int64, + stride_kk: tl.int64, + stride_vz: tl.int64, + stride_vh: tl.int64, + stride_vk: tl.int64, + stride_vn: tl.int64, + stride_oz: tl.int64, + stride_oh: tl.int64, + stride_om: tl.int64, + stride_on: tl.int64, + stride_bz: tl.int64, + stride_bh: tl.int64, + stride_bm: tl.int64, + stride_bn: tl.int64, + stride_az: tl.int64, + stride_ah: tl.int64, + q_descale_ptr, + k_descale_ptr, + p_scale_ptr, + p_descale_ptr, + o_descale_ptr, + v_descale_ptr, + q_descale_has_singleton: tl.constexpr, + k_descale_has_singleton: tl.constexpr, + p_descale_has_singleton: tl.constexpr, + v_descale_has_singleton: tl.constexpr, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, + NUM_CU: tl.constexpr, + GRID_CU_MULTIP: tl.constexpr, + B: tl.constexpr, philox_offset_base, encoded_softmax, + alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, + IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, @@ -351,24 +736,39 @@ def attn_fwd( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, + SHOULD_PRE_LOAD_V: tl.constexpr, + USE_BIAS: tl.constexpr, + SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_ALIBI: tl.constexpr, + IS_EIGHT_BIT: tl.constexpr, + USE_P_SCALE: tl.constexpr, + IS_EIGHT_BIT_KV: tl.constexpr, + QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, ): - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) + + if o_descale_ptr is not None: + o_descale = tl.load(o_descale_ptr) + + start_m: tl.int64 = tl.program_id(0) + 
off_h_q: tl.int64 = tl.program_id(1) + off_z: tl.int64 = tl.program_id(2) + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64) + offs_n = tl.arange(0, BLOCK_N).to(tl.int64) + offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64) + + # as we can't have return statements inside while loop in Triton + continue_condition = True + if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. + # We have a one-size-fits-all grid in id(0). Some seqlens might be + # too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: - return + continue_condition = False + # return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start @@ -378,444 +778,598 @@ def attn_fwd( seqlen_q = MAX_SEQLENS_Q seqlen_k = MAX_SEQLENS_K - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. 
- # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? - return - - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q - - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL - - # Compute pointers for all the tensors used in this kernel. 
- q_offset = (off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = (off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = (off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - bias_ptr = None - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base \ - + (off_z * HQ + off_h_q) \ - * seqlen_q * seqlen_k - else: - batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. 
- if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - encoded_softmax_block_ptr = 0 - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do not - # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) - - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its - # value because there is no masking. Similarly we do not need padding. 
- if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, n_full_blocks)) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) - # epilogue - acc = acc / l_i[:, None] - if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). 
We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] - >= out_mask_boundary[None, :]) - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. - # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - - # write back O - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. 
- tl.store(O_block_ptr, acc, boundary_check=(0, 1)) - - -def check_args( - q, - k, - v, - o, - varlen=True, - max_seqlens=None, - cu_seqlens_q=None, - cu_seqlens_k=None, -): - assert q.dim() == k.dim() and q.dim() == v.dim() - if varlen: - assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - assert cu_seqlens_q is not None - assert cu_seqlens_k is not None - assert len(cu_seqlens_q) == len(cu_seqlens_k) - else: - assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - assert max_seqlens > 0 - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert q.dtype == k.dtype and q.dtype == v.dtype - assert head_size <= 256 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 + if continue_condition: + # Now we compute whether we need to exit early due to causal + # masking. This is because for seqlen_q > seqlen_k, M rows of the + # attn scores are completely masked, resulting in 0s written to the + # output, and inf written to LSE. We don't need to do any GEMMs in + # this case. This block of code determines what N is, and if this + # WG is operating on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which + # means the causal mask boundary is bottom right aligned, and + # ends at either the top edge (seqlen_q < seqlen_k) or left + # edge. This captures the decrease in n_blocks if we have a + # rectangular attn matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. 
Otherwise we want to always iterate through all + # n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this + # WG is part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = (o_offset + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to( + [BLOCK_M, BLOCK_DMODEL]) + # We still need to write 0s to the result + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as + # that is statically known. + l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + + off_h_q * MAX_SEQLENS_Q + offs_m) + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this from qk which makes it -inf, such that + # exp(qk - inf) = 0 for these masked blocks. + l_value = tl.full([BLOCK_M], + value=float("inf"), + dtype=tl.float32) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l_value, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be + # handled here too? + continue_condition = False + # return + + if continue_condition: + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL + != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. 
+ q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + q_ptrs = (q_offset + offs_m[:, None] * stride_qm + + offs_d[None, :] * stride_qk) + k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + k_ptrs = (k_offset + offs_d[:, None] * stride_kk + + offs_n[None, :] * stride_kn) + v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + v_ptrs = (v_offset + offs_n[:, None] * stride_vk + + offs_d[None, :] * stride_vn) + # Compute pointers for all scale tensors used in this kernel. + + IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & ( + not IS_EIGHT_BIT_KV) + if IS_EIGHT_BIT: + if k_descale_has_singleton: + k_descale_ptrs = k_descale_ptr + else: + k_descale_ptrs = k_descale_ptr + off_h_k + + if v_descale_has_singleton: + v_descale_ptrs = v_descale_ptr + else: + v_descale_ptrs = v_descale_ptr + off_h_k + + if not IS_EIGHT_BIT_KV: + if q_descale_has_singleton: + q_descale_ptrs = q_descale_ptr + else: + q_descale_ptrs = q_descale_ptr + off_h_q + if USE_P_SCALE: + if p_descale_has_singleton: + p_scale_ptrs = p_scale_ptr + p_descale_ptrs = p_descale_ptr + else: + p_scale_ptrs = p_scale_ptr + off_h_q + p_descale_ptrs = p_descale_ptr + off_h_q + + if USE_BIAS: + bias_offset = off_h_q * stride_bh + bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm + + offs_n[None, :] * stride_bn) + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + batch_philox_offset = 0 + # We can ask to return the dropout mask without doing any + # dropout. In this case, we return an invalid pointer so + # indicate the mask is not valid. 
+ if SHOULD_RETURN_ENCODED_SOFTMAX: + encoded_sm_base = (encoded_softmax + + off_h_q * seqlen_q * seqlen_k) + encoded_sm_ptrs = (encoded_sm_base + + offs_m[:, None] * seqlen_k + + offs_n[None, :]) + else: + encoded_sm_ptrs = None + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do + # not have native e^x support in HW. + QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q_ptrs_mask = offs_m[:, None] < seqlen_q + if USE_PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] + < IS_ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + + if IS_EIGHT_BIT: + k_descale = tl.load(k_descale_ptrs) + v_descale = tl.load(v_descale_ptrs) + q_descale = None if IS_EIGHT_BIT_KV else tl.load( + q_descale_ptrs) + if USE_P_SCALE: + p_scale = tl.load(p_scale_ptrs) + p_descale = tl.load(p_descale_ptrs) + else: + p_scale = None + p_descale = None + else: + q_descale = None + k_descale = None + v_descale = None + p_scale = None + p_descale = None + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked + # blocks. Additionally there might be one more due to + # dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an + # additional block. In this case we might exceed n_blocks so + # pick the min. 
+ masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false + # regardless of its actual value because there is no masking. + # Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + alibi_slope, + q_descale, + k_descale, + v_descale, + p_scale, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, SHOULD_MASK_STEPS, ... + SHOULD_PRE_LOAD_V, + False, + SHOULD_RETURN_ENCODED_SOFTMAX, + USE_PADDED_HEAD, + IS_ACTUAL_BLOCK_DMODEL, + QK_SCALE, + IS_EIGHT_BIT_GEMM, + USE_P_SCALE, + IS_EIGHT_BIT_KV, + QUANT_DTYPE) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. 
+ if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + if SHOULD_RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + q_descale, + k_descale, + v_descale, + p_scale, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, SHOULD_MASK_STEPS, ... + SHOULD_PRE_LOAD_V, + True, + SHOULD_RETURN_ENCODED_SOFTMAX, + USE_PADDED_HEAD, + IS_ACTUAL_BLOCK_DMODEL, + QK_SCALE, + IS_EIGHT_BIT_GEMM, + USE_P_SCALE, + IS_EIGHT_BIT_KV, + QUANT_DTYPE) + + if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: + if USE_P_SCALE: + acc *= p_descale + acc *= v_descale + + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc + # which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + + # If seqlen_q > seqlen_k but the delta is not a multiple of + # BLOCK_M, then we have one block with a row of all NaNs which + # come from computing softmax over a row of all + # -infs (-inf - inf = NaN). We check for that here and store 0s + # where there are NaNs as these rows should've been zeroed out. 
+ end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102 + if o_descale_ptr is not None: + acc = quant_fp8(acc, o_descale) + + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if (causal_start_idx > start_m_idx + and causal_start_idx < end_m_idx): + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] + >= out_mask_boundary[None, :]) + z = tl.zeros((1, ), tl.float32) + acc = tl.where(out_ptrs_mask, acc, + z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + + off_h_q * MAX_SEQLENS_Q + offs_m) + # If seqlen_q not multiple of BLOCK_M, we need to mask out the + # last few rows. This is only true for the last M block. + # For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), + BLOCK_M - overflow_size, + dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = (o_offset + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if USE_PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] + < IS_ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + +def get_shape_from_layout(q, k, metadata): + assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." 
+ + if metadata.layout == 'thd': + nheads_q, nheads_k = q.shape[1], k.shape[1] + head_size = q.shape[-1] + batch = metadata.num_contexts + elif metadata.layout == 'bhsd': + batch, nheads_q, _, head_size = q.shape + nheads_k = k.shape[1] + elif metadata.layout == 'bshd': + batch, _, nheads_q, head_size = q.shape + nheads_k = k.shape[2] + return batch, nheads_q, nheads_k, head_size + + +def get_strides_from_layout(q, k, v, o, metadata): + assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." + + STRIDE_PERMUTATIONS = { + 'thd': (None, 1, 0, 2), + 'bhsd': (0, 1, 2, 3), + 'bshd': (0, 2, 1, 3), + } + + perm = STRIDE_PERMUTATIONS[metadata.layout] + stride = lambda x, p: (0 if p is None else x.stride(p)) + strides = lambda x: (stride(x, p) for p in perm) + + return tuple(strides(x) for x in [q, k, v, o]) class _attention(torch.autograd.Function): @staticmethod - def forward( - ctx, - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - causal=False, - sm_scale=1.0, - bias=None, - ): + def forward(ctx, q, k, v, o, metadata: MetaData): + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (metadata.bias is not None): + assert (metadata.bias.numel() < 2**31) + if o is None: - o = torch.empty_like(q, dtype=v.dtype) + if metadata.eight_bit: + o = torch.empty_like( + q, + dtype=metadata.output_dtype if metadata.output_dtype + is not None else metadata.eight_bit_dtype_torch) + else: + o = torch.empty_like(q, dtype=q.dtype) - check_args( - q, - k, - v, - o, - varlen=True, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if True: # varlen - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = len(cu_seqlens_q) - 1 - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, seqlen_q, 
nheads_q, head_size = q.shape - _, seqlen_k, nheads_k, _ = k.shape - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + metadata.check_args(q, k, v, o) + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout( + q, k, metadata) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout( + q, k, v, o, metadata) # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128, 256} - if head_size not in unpadded_head_dims: - padded_d_model = None - for i in unpadded_head_dims: - if i > head_size: - padded_d_model = i - break - assert padded_d_model is not None - else: - padded_d_model = head_size + padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) - grid = lambda META: ( - triton.cdiv(max_seqlens_q, META["BLOCK_M"]), - nheads_q, - batch, - ) + # encoded_softmax is used to validate dropout behavior vs the + # PyTorch SDPA math backend reference. We zero this out to give a + # consistent starting point and then populate it with the output of + # softmax with the sign bit set according to the dropout mask. + # The resulting return allows this mask to be fed into the reference + # implementation for testing only. This return holds no useful output + # aside from debugging. 
+ if metadata.return_encoded_softmax: + encoded_softmax = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2], k.shape[2]), + device=q.device, + dtype=torch.float32) + else: + encoded_softmax = None - encoded_softmax = None + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), + device=q.device, + dtype=torch.float32) # Seed the RNG so we get reproducible results for testing. philox_seed = 0x1BF52 philox_offset = 0x1D4B42 - if bias is not None: - bias_strides = ( - bias.stride(0), - bias.stride(1), - bias.stride(2), - bias.stride(3), - ) + if metadata.bias is not None: + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), + metadata.bias.stride(2), metadata.bias.stride(3)) else: bias_strides = (0, 0, 0, 0) + if metadata.alibi_slopes is not None: + alibi_strides = (metadata.alibi_slopes.stride(0), + metadata.alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + if metadata.eight_bit: + q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = ( + metadata.q_descale, metadata.k_descale, metadata.p_scale, + metadata.p_descale, metadata.v_descale, metadata.o_scale) + o_descale = 1.0 / o_scale if o_scale is not None else None + else: + q_descale = k_descale = p_scale = None + p_descale = v_descale = o_descale = None + + # number of compute units available + NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count + + grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[ + 'BLOCK_M']), nheads_q, batch) + attn_fwd[grid]( q, k, v, - bias, - sm_scale, - None, + metadata.bias, + metadata.sm_scale, + M, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, - cu_seqlens_q, - cu_seqlens_k, - dropout_p=0.0, + *alibi_strides, + q_descale, + k_descale, + p_scale, + p_descale, + o_descale, + v_descale, + q_descale.numel() == 1 if q_descale is not None else False, + k_descale.numel() == 1 if k_descale is not None else False, + p_descale.numel() == 1 if p_descale is not None else False, + v_descale.numel() == 1 if 
v_descale is not None else False, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, + alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, - ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, - IS_CAUSAL=causal, - VARLEN=True, + IS_ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=metadata.max_seqlens_q, + MAX_SEQLENS_K=metadata.max_seqlens_k, + IS_CAUSAL=metadata.causal, + VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, - ENABLE_DROPOUT=False, - RETURN_ENCODED_SOFTMAX=False, - ) + USE_BIAS=metadata.bias is not None, + USE_ALIBI=metadata.alibi_slopes is not None, + SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, + IS_EIGHT_BIT=metadata.eight_bit, + USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale, + IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv, + NUM_CU=NUM_CU, + B=batch, + QUANT_DTYPE=metadata.eight_bit_dtype_triton) ctx.grid = grid - ctx.sm_scale = sm_scale + ctx.sm_scale = metadata.sm_scale ctx.BLOCK_DMODEL = head_size - ctx.causal = causal - ctx.dropout_p = 0.0 + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes ctx.philox_seed = philox_seed ctx.philox_offset = philox_offset ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = False + ctx.return_encoded_softmax = metadata.return_encoded_softmax return o, encoded_softmax -triton_attention = _attention.apply +triton_attention_rocm = _attention.apply + + +def scale_fp8(t, scale=None): + t_scaled, scale_out = ops.scaled_fp8_quant(t.reshape(-1, t.shape[-1]), + scale) + return t_scaled.reshape(t.shape), scale_out + + +def maybe_quantize_fp8(t, scale): + eight_bit_dtype = current_platform.fp8_dtype() + if t.dtype != eight_bit_dtype: + t, _ = scale_fp8(t, scale) + return t + + +def check_and_maybe_quantize_qkv(q, k, v, fp8_scales): + (q_scale, k_scale, v_scale, 
p_scale) = fp8_scales + + q = maybe_quantize_fp8(q, q_scale) + k = maybe_quantize_fp8(k, k_scale) + v = maybe_quantize_fp8(v, v_scale) + + return q, k, v + + +# query - [num_tokens, num_heads, head_size] +# key - [num_tokens, num_kv_heads, head_size] +# value - [num_tokens, num_kv_heads, head_size +# output - [num_tokens, num_heads, head_size] +def triton_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlens_q: int, + max_seqlens_k: int, + causal: bool = False, + sm_scale: float = 1.0, + bias: Optional[torch.Tensor] = None, + fp8_scales: Optional[tuple[float, ...]] = None, + input_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if fp8_scales is not None: + q_descale, k_descale, v_descale, p_scale = fp8_scales + else: + q_descale = k_descale = v_descale = p_scale = None + + attn_metadata = MetaData(sm_scale=sm_scale, + max_seqlens_q=max_seqlens_q, + max_seqlens_k=max_seqlens_k, + causal=causal, + bias=bias, + output_dtype=q.dtype, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + p_scale=p_scale, + o_scale=input_scale) + + if fp8_scales is not None: + q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales) + + return triton_attention_rocm(q, k, v, o, attn_metadata) diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py index 9671b933f47b9bbc9fb2800db46c29396e42e3c2..250426d9faa5bc2aafa7f844dc1dff3d6edee0ef 100644 --- a/vllm/attention/ops/triton_merge_attn_states.py +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -66,7 +66,10 @@ def merge_attn_states_kernel( max_lse = tl.maximum(p_lse, s_lse) p_lse = p_lse - max_lse s_lse = s_lse - max_lse - out_se = (tl.exp(p_lse) + tl.exp(s_lse)) + # Will reuse precomputed Exp values for scale factor computation. 
+ p_se = tl.exp(p_lse) + s_se = tl.exp(s_lse) + out_se = (p_se + s_se) if OUTPUT_LSE: out_lse = tl.log(out_se) + max_lse @@ -84,8 +87,8 @@ def merge_attn_states_kernel( # NOTE(woosuk): Be careful with the numerical stability. # We should compute the scale first, and then multiply it with the output. # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. - p_scale = tl.exp(p_lse) / out_se - s_scale = tl.exp(s_lse) / out_se + p_scale = p_se / out_se + s_scale = s_se / out_se out = p_out * p_scale + s_out * s_scale tl.store(output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange, diff --git a/vllm/vllm_flash_attn/fa_utils.py b/vllm/attention/utils/fa_utils.py similarity index 100% rename from vllm/vllm_flash_attn/fa_utils.py rename to vllm/attention/utils/fa_utils.py diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 5d4ebdb7acbcfe89c3fa7fe2dc2fcb7a706d2b38..967510abaeb9b8507b7f94dc9ea9dc5bd159ab8e 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -38,9 +38,18 @@ class BeamSearchOutput: class BeamSearchInstance: - def __init__(self, prompt_tokens: list[int]): + def __init__( + self, + prompt_tokens: list[int], + logprobs: Optional[list[dict[int, Logprob]]] = None, + **kwargs, + ): self.beams: list[BeamSearchSequence] = [ - BeamSearchSequence(tokens=prompt_tokens, logprobs=[]) + BeamSearchSequence( + tokens=prompt_tokens, + logprobs=[] if logprobs is None else list(logprobs), + **kwargs, + ) ] self.completed: list[BeamSearchSequence] = [] diff --git a/vllm/benchmarks/__init__.py b/vllm/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..299c888c2e7b993a450e6650b2e9a832b4b50366 --- /dev/null +++ b/vllm/benchmarks/datasets.py @@ -0,0 +1,831 @@ +# SPDX-License-Identifier: Apache-2.0 +""" 
+This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation. Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena + +TODO: Implement CustomDataset to parse a JSON file and convert its contents into +SampleRequest instances, similar to the approach used in ShareGPT. +""" + +import base64 +import io +import json +import logging +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cache +from io import BytesIO +from typing import Any, Callable, Optional, Union + +import numpy as np +from PIL import Image +from transformers import PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.multimodal import MultiModalDataDict +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Data Classes +# ----------------------------------------------------------------------------- + + +@dataclass +class SampleRequest: + """ + Represents a single inference request for benchmarking. 
+ """ + + prompt: Union[str, Any] + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + lora_request: Optional[LoRARequest] = None + + +# ----------------------------------------------------------------------------- +# Benchmark Dataset Base Class +# ----------------------------------------------------------------------------- + + +class BenchmarkDataset(ABC): + DEFAULT_SEED = 0 + + def __init__( + self, + dataset_path: Optional[str] = None, + random_seed: int = DEFAULT_SEED, + ) -> None: + """ + Initialize the BenchmarkDataset with an optional dataset path and random + seed. + + Args: + dataset_path (Optional[str]): Path to the dataset. If None, it + indicates that a default or random dataset might be used. + random_seed (int): Seed value for reproducible shuffling or + sampling. Defaults to DEFAULT_SEED. + """ + self.dataset_path = dataset_path + # Set the random seed, ensuring that a None value is replaced with the + # default seed. + self.random_seed = (random_seed + if random_seed is not None else self.DEFAULT_SEED) + self.data = None + + def apply_multimodal_chat_transformation( + self, + prompt: str, + mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + """ + Transform a prompt and optional multimodal content into a chat format. + This method is used for chat models that expect a specific conversation + format. + """ + content = [{"text": prompt, "type": "text"}] + if mm_content is not None: + content.append(mm_content) + return [{"role": "user", "content": content}] + + def load_data(self) -> None: + """ + Load data from the dataset path into self.data. + + This method must be overridden by subclasses since the method to load + data will vary depending on the dataset format and source. + + Raises: + NotImplementedError: If a subclass does not implement this method. 
+ """ + # TODO (jenniferzhao): add support for downloading data + raise NotImplementedError( + "load_data must be implemented in subclasses.") + + def get_random_lora_request( + self, + tokenizer: PreTrainedTokenizerBase, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + """ + Optionally select a random LoRA request and return its associated + tokenizer. + + This method is used when LoRA parameters are provided. It randomly + selects a LoRA based on max_loras and retrieves a cached tokenizer for + that LoRA if available. Otherwise, it returns the base tokenizer. + + Args: + tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no + LoRA is selected. max_loras (Optional[int]): The maximum number of + LoRAs available. If None, LoRA is not used. lora_path + (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA + is not used. + + Returns: + tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first + element is a LoRARequest (or None if not applicable) and the second + element is the tokenizer associated with the LoRA request (or the + base tokenizer). + """ + if max_loras is None or lora_path is None: + return None, tokenizer + + # Generate a random LoRA ID in the range [1, max_loras]. + lora_id = random.randint(1, max_loras) + lora_request = LoRARequest( + lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(lora_path), + ) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + # Return lora_request and the cached tokenizer if available; otherwise, + # return the base tokenizer + return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + + @abstractmethod + def sample(self, tokenizer: PreTrainedTokenizerBase, + num_requests: int) -> list[SampleRequest]: + """ + Abstract method to generate sample requests from the dataset. 
+ + Subclasses must override this method to implement dataset-specific logic + for generating a list of SampleRequest objects. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used + for processing the dataset's text. + num_requests (int): The number of sample requests to generate. + + Returns: + list[SampleRequest]: A list of sample requests generated from the + dataset. + """ + raise NotImplementedError("sample must be implemented in subclasses.") + + def maybe_oversample_requests(self, requests: list[SampleRequest], + num_requests: int) -> None: + """ + Oversamples the list of requests if its size is less than the desired + number. + + Args: + requests (List[SampleRequest]): The current list of sampled + requests. num_requests (int): The target number of requests. + """ + if len(requests) < num_requests: + random.seed(self.random_seed) + additional = random.choices(requests, + k=num_requests - len(requests)) + requests.extend(additional) + logger.info("Oversampled requests to reach %d total samples.", + num_requests) + + +# ----------------------------------------------------------------------------- +# Utility Functions and Global Caches +# ----------------------------------------------------------------------------- + + +def is_valid_sequence( + prompt_len: int, + output_len: int, + min_len: int = 4, + max_prompt_len: int = 1024, + max_total_len: int = 2048, + skip_min_output_len_check: bool = False, +) -> bool: + """ + Validate a sequence based on prompt and output lengths. + + Default pruning criteria are copied from the original `sample_hf_requests` + and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as + from `sample_requests` in benchmark_throughput.py. 
+ """ + # Check for invalid conditions + prompt_too_short = prompt_len < min_len + output_too_short = (not skip_min_output_len_check) and (output_len + < min_len) + prompt_too_long = prompt_len > max_prompt_len + combined_too_long = (prompt_len + output_len) > max_total_len + + # Return True if none of the invalid conditions are met + return not (prompt_too_short or output_too_short or prompt_too_long + or combined_too_long) + + +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +# Global cache for LoRA tokenizers. +lora_tokenizer_cache: dict[int, AnyTokenizer] = {} + + +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + Supports three input types: + + 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key + containing raw image data. - Loads the bytes as a PIL.Image.Image. + + 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as + a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns + a dictionary with the image as a base64 data URL. + + 3. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. 
+ """ + if isinstance(image, dict) and 'bytes' in image: + image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, Image.Image): + image = image.convert("RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes.") + + +# ----------------------------------------------------------------------------- +# Random Dataset Implementation (Synthetic Data) +# ----------------------------------------------------------------------------- + + +class RandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the random dataset. 
+ DEFAULT_PREFIX_LEN = 0 + DEFAULT_RANGE_RATIO = 0.0 + DEFAULT_INPUT_LEN = 1024 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + range_ratio: float = DEFAULT_RANGE_RATIO, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs, + ) -> list[SampleRequest]: + # Enforce range_ratio < 1 + assert range_ratio < 1.0, ( + "random_range_ratio must be < 1.0 to ensure a valid sampling range" + ) + + vocab_size = tokenizer.vocab_size + + prefix_token_ids = (np.random.randint( + 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + + # New sampling logic: [X * (1 - b), X * (1 + b)] + input_low = int(input_len * (1 - range_ratio)) + input_high = int(input_len * (1 + range_ratio)) + output_low = int(output_len * (1 - range_ratio)) + output_high = int(output_len * (1 + range_ratio)) + + # Add logging for debugging + logger.info("Sampling input_len from [%s, %s]", input_low, input_high) + logger.info("Sampling output_len from [%s, %s]", output_low, + output_high) + + input_lens = np.random.randint(input_low, + input_high + 1, + size=num_requests) + output_lens = np.random.randint(output_low, + output_high + 1, + size=num_requests) + offsets = np.random.randint(0, vocab_size, size=num_requests) + + requests = [] + for i in range(num_requests): + inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % + vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_lens[i]) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + )) + return requests + + +# ----------------------------------------------------------------------------- +# ShareGPT Dataset Implementation +# 
----------------------------------------------------------------------------- + + +class ShareGPTDataset(BenchmarkDataset): + """ + Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = json.load(f) + # Filter entries with at least two conversation turns. + self.data = [ + entry for entry in self.data + if "conversations" in entry and len(entry["conversations"]) >= 2 + ] + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + prompt, completion = ( + entry["conversations"][0]["value"], + entry["conversations"][1]["value"], + ) + + lora_request, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + new_output_len = (len(completion_ids) + if output_len is None else output_len) + if not is_valid_sequence(prompt_len, + new_output_len, + skip_min_output_len_check=output_len + is not None): + continue + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, None) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=new_output_len, + lora_request=lora_request, + )) + self.maybe_oversample_requests(samples, 
num_requests) + return samples + + +# ----------------------------------------------------------------------------- +# Sonnet Dataset Implementation +# ----------------------------------------------------------------------------- + + +class SonnetDataset(BenchmarkDataset): + """ + Simplified implementation of the Sonnet dataset. Loads poem lines from a + text file and generates sample requests. Default values here copied from + `benchmark_serving.py` for the sonnet dataset. + """ + + DEFAULT_PREFIX_LEN = 200 + DEFAULT_INPUT_LEN = 550 + DEFAULT_OUTPUT_LEN = 150 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided.") + with open(self.dataset_path, encoding="utf-8") as f: + self.data = f.readlines() + + def sample( + self, + tokenizer, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + return_prompt_formatted: bool = False, + **kwargs, + ) -> list: + # Calculate average token length for a poem line. + tokenized_lines = [tokenizer(line).input_ids for line in self.data] + avg_len = sum(len(tokens) + for tokens in tokenized_lines) / len(tokenized_lines) + + # Build the base prompt. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_msg = [{"role": "user", "content": base_prompt}] + base_fmt = tokenizer.apply_chat_template(base_msg, + add_generation_prompt=True, + tokenize=False) + base_offset = len(tokenizer(base_fmt).input_ids) + if input_len <= base_offset: + raise ValueError( + f"'input_len' must be higher than the base prompt length " + f"({base_offset}).") + + # Determine how many poem lines to use. 
+ num_input_lines = round((input_len - base_offset) / avg_len) + num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0) + prefix_lines = self.data[:num_prefix_lines] + + samples = [] + while len(samples) < num_requests: + extra_lines = random.choices(self.data, + k=num_input_lines - num_prefix_lines) + prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" + msg = [{"role": "user", "content": prompt}] + prompt_formatted = tokenizer.apply_chat_template( + msg, add_generation_prompt=True, tokenize=False) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + if prompt_len <= input_len: + samples.append( + SampleRequest( + prompt=prompt_formatted + if return_prompt_formatted else prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + return samples + + +# ----------------------------------------------------------------------------- +# BurstGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BurstGPTDataset(BenchmarkDataset): + """ + Implements the BurstGPT dataset. Loads data from a CSV file and generates + sample requests based on synthetic prompt generation. Only rows with Model + "GPT-4" and positive response tokens are used. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self, ): + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + try: + import pandas as pd + except ImportError as e: + raise ImportError( + "Pandas is required for BurstGPTDataset. Please install it " + "using `pip install pandas`.") from e + + df = pd.read_csv(self.dataset_path) + # Filter to keep only GPT-4 rows. + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove failed requests (where Response tokens is 0 or less). + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Sample the desired number of rows. 
+ self.data = gpt4_df + + def _sample_loaded_data(self, num_requests: int) -> list: + if num_requests <= len(self.data): + data = self.data.sample(n=num_requests, + random_state=self.random_seed) + else: + data = self.data.sample( + n=num_requests, + random_state=self.random_seed, + replace=True, + ) + # Convert the dataframe to a list of lists. + return data.values.tolist() + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + **kwargs, + ) -> list[SampleRequest]: + samples = [] + data = self._sample_loaded_data(num_requests=num_requests) + for i in range(num_requests): + input_len = int(data[i][2]) + output_len = int(data[i][3]) + lora_req, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + vocab_size = tokenizer.vocab_size + # Generate a synthetic prompt: a list of token IDs computed as (i + + # j) modulo vocab_size. + token_ids = [(i + j) % vocab_size for j in range(input_len)] + prompt = tokenizer.decode(token_ids) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=output_len, + lora_request=lora_req, + )) + return samples + + +# ----------------------------------------------------------------------------- +# HuggingFace Dataset Base Implementation +# ----------------------------------------------------------------------------- +class HuggingFaceDataset(BenchmarkDataset): + """Base class for datasets hosted on HuggingFace.""" + + SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set() + + def __init__( + self, + dataset_path: str, + dataset_split: str, + dataset_subset: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(dataset_path=dataset_path, **kwargs) + + self.dataset_split = dataset_split + self.dataset_subset = dataset_subset + self.load_data() + + def load_data(self) -> None: + """Load data from HuggingFace datasets.""" + 
try: + from datasets import load_dataset + except ImportError as e: + raise ImportError( + "Hugging Face datasets library is required for this dataset. " + "Please install it using `pip install datasets`.") from e + + self.data = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=True, + ) + self.data = self.data.shuffle(seed=self.random_seed) + + +# ----------------------------------------------------------------------------- +# Conversation Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ConversationDataset(HuggingFaceDataset): + """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { + 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + # Filter examples with at least 2 conversations + filtered_data = self.data.filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests = [] + dynamic_output = output_len is None + + for item in filtered_data: + if len(sampled_requests) >= num_requests: + break + conv = item["conversations"] + prompt, completion = conv[0]["value"], conv[1]["value"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len): + continue + mm_content = process_image( + item["image"]) if "image" in item else None + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len and 
output len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Vision Arena Dataset Implementation +# ----------------------------------------------------------------------------- + + +class VisionArenaDataset(HuggingFaceDataset): + """ + Vision Arena Dataset. + """ + + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = { + "lmarena-ai/VisionArena-Chat": + lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": + lambda x: x["turns"][0][0]["content"] + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) + if parser_fn is None: + raise ValueError( + f"Unsupported dataset path: {self.dataset_path}") + prompt = parser_fn(item) + mm_content = process_image(item["images"][0]) + prompt_len = len(tokenizer(prompt).input_ids) + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return 
sampled_requests + + +# ----------------------------------------------------------------------------- +# Instruct Coder Dataset Implementation +# ----------------------------------------------------------------------------- + + +class InstructCoderDataset(HuggingFaceDataset): + """ + InstructCoder Dataset. + https://huggingface.co/datasets/likaixin/InstructCoder + + InstructCoder is the dataset designed for general code editing. It consists + of 114,239 instruction-input-output triplets, and covers multiple distinct + code editing scenario. + """ + + DEFAULT_OUTPUT_LEN = 200 # this is the average default output length + SUPPORTED_DATASET_PATHS = { + "likaixin/InstructCoder", + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = f"{item['instruction']}:\n{item['input']}" + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# AIMO Dataset Implementation +# ----------------------------------------------------------------------------- + + +class AIMODataset(HuggingFaceDataset): + """ + Dataset class for processing a AIMO dataset with reasoning questions. 
+ """ + SUPPORTED_DATASET_PATHS = { + "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT" + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs) -> list: + sampled_requests = [] + dynamic_output = output_len is None + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt, completion = item['problem'], item["solution"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence(prompt_len, + completion_len, + max_prompt_len=2048, + max_total_len=32000): + continue + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=None, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py new file mode 100644 index 0000000000000000000000000000000000000000..06f6848f50cb4c2e171237877ada83fc548a9d23 --- /dev/null +++ b/vllm/benchmarks/latency.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark the latency of processing a single batch of requests.""" + +import argparse +import dataclasses +import json +import os +import time +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import torch +from tqdm import tqdm + +from vllm import LLM, SamplingParams +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import EngineArgs +from vllm.inputs import PromptType +from vllm.sampling_params import BeamSearchParams + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + 
results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={"latency": results["latencies"]}, + extra_info={k: results[k] + for k in ["avg_latency", "percentiles"]}) + if pt_records: + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--input-len", type=int, default=32) + parser.add_argument("--output-len", type=int, default=128) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument( + "--n", + type=int, + default=1, + help="Number of generated sequences per prompt.", + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--num-iters-warmup", + type=int, + default=10, + help="Number of iterations to run for warmup.", + ) + parser.add_argument("--num-iters", + type=int, + default=30, + help="Number of iterations to run.") + parser.add_argument( + "--profile", + action="store_true", + help="profile the generation process of a single batch", + ) + parser.add_argument( + "--profile-result-dir", + type=str, + default=None, + help=("path to save the pytorch profiler output. Can be visualized " + "with ui.perfetto.dev or Tensorboard."), + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save the latency results in JSON format.", + ) + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)"), + ) + + parser = EngineArgs.add_cli_args(parser) + + +def main(args: argparse.Namespace): + print(args) + + engine_args = EngineArgs.from_cli_args(args) + + # NOTE(woosuk): If the request cannot be processed in a single batch, + # the engine will automatically process the request in multiple batches. 
+ llm = LLM(**dataclasses.asdict(engine_args)) + assert llm.llm_engine.model_config.max_model_len >= ( + args.input_len + + args.output_len), ("Please ensure that max_model_len is greater than" + " the sum of input_len and output_len.") + + sampling_params = SamplingParams( + n=args.n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize, + ) + print(sampling_params) + dummy_prompt_token_ids = np.random.randint(10000, + size=(args.batch_size, + args.input_len)) + dummy_prompts: list[PromptType] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] + + def llm_generate(): + if not args.use_beam_search: + llm.generate(dummy_prompts, + sampling_params=sampling_params, + use_tqdm=False) + else: + llm.beam_search( + dummy_prompts, + BeamSearchParams( + beam_width=args.n, + max_tokens=args.output_len, + ignore_eos=True, + ), + ) + + def run_to_completion(profile_dir: Optional[str] = None): + if profile_dir: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_dir)), + ) as p: + llm_generate() + print(p.key_averages().table(sort_by="self_cuda_time_total")) + else: + start_time = time.perf_counter() + llm_generate() + end_time = time.perf_counter() + latency = end_time - start_time + return latency + + print("Warming up...") + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + run_to_completion(profile_dir=None) + + if args.profile: + profile_dir = args.profile_result_dir + if not profile_dir: + profile_dir = (Path(".") / "vllm_benchmark_result" / + f"latency_result_{time.time()}") + print(f"Profiling (results will be saved to '{profile_dir}')...") + run_to_completion(profile_dir=profile_dir) + return + + # Benchmark. 
+ latencies = [] + for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): + latencies.append(run_to_completion(profile_dir=None)) + latencies = np.array(latencies) + percentages = [10, 25, 50, 75, 90, 99] + percentiles = np.percentile(latencies, percentages) + print(f"Avg latency: {np.mean(latencies)} seconds") + for percentage, percentile in zip(percentages, percentiles): + print(f"{percentage}% percentile latency: {percentile} seconds") + + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e24911cc98202c96061c4ae5daa5693517a6b3 --- /dev/null +++ b/vllm/benchmarks/throughput.py @@ -0,0 +1,608 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark offline inference throughput.""" +import argparse +import dataclasses +import json +import os +import random +import time +import warnings +from typing import Any, Optional, Union + +import torch +import uvloop +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, RandomDataset, + SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.outputs import 
RequestOutput +from vllm.sampling_params import BeamSearchParams +from vllm.utils import merge_async_iterators + + +def run_vllm( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> tuple[float, Optional[list[RequestOutput]]]: + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + # Add the requests to the engine. + prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests: Optional[list[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] + + use_beam_search = False + + outputs = None + if not use_beam_search: + start = time.perf_counter() + outputs = llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) + end = time.perf_counter() + else: + assert lora_requests is None, "BeamSearch API does not support LoRA" + prompts = [request.prompt for request in requests] + # output_len should be the same for all requests. 
+ output_len = requests[0][2] + for request in requests: + assert request.expected_output_len == output_len + start = time.perf_counter() + llm.beam_search( + prompts, + BeamSearchParams( + beam_width=n, + max_tokens=output_len, + ignore_eos=True, + )) + end = time.perf_counter() + return end - start, outputs + + +def run_vllm_chat( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + """ + Run vLLM chat benchmark. This function is recommended ONLY for benchmarking + multimodal models as it properly handles multimodal inputs and chat + formatting. For non-multimodal models, use run_vllm() instead. + """ + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests.") + + prompts = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + start = time.perf_counter() + outputs = llm.chat(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + return end - start, outputs + + +async def run_vllm_async( + requests: list[SampleRequest], + n: int, + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, + disable_detokenize: bool = False, +) -> float: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + assert all( + llm.model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) + for request in 
requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + + # Add the requests to the engine. + prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + lora_requests: list[Optional[LoRARequest]] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests.append(request.lora_request) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp, + lr) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, + sp, + lora_request=lr, + request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + +def run_hf( + requests: list[SampleRequest], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + max_batch_size: int, + trust_remote_code: bool, + disable_detokenize: bool = False, +) -> float: + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. 
+ tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: list[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt = requests[i].prompt + prompt_len = requests[i].prompt_len + output_len = requests[i].expected_output_len + # Add the prompt to the batch. + batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + next_prompt_len = requests[i + 1].prompt_len + next_output_len = requests[i + 1].expected_output_len + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=True, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + if not disable_detokenize: + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. 
+ batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "requests_per_second": [results["requests_per_second"]], + "tokens_per_second": [results["tokens_per_second"]], + }, + extra_info={ + k: results[k] + for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def get_requests(args, tokenizer): + # Common parameters for all dataset types. + common_kwargs = { + "dataset_path": args.dataset_path, + "random_seed": args.seed, + } + sample_kwargs = { + "tokenizer": tokenizer, + "lora_path": args.lora_path, + "max_loras": args.max_loras, + "num_requests": args.num_prompts, + "input_len": args.input_len, + "output_len": args.output_len, + } + + if args.dataset_path is None or args.dataset_name == "random": + sample_kwargs["range_ratio"] = args.random_range_ratio + sample_kwargs["prefix_len"] = args.prefix_len + dataset_cls = RandomDataset + elif args.dataset_name == "sharegpt": + dataset_cls = ShareGPTDataset + if args.backend == "vllm-chat": + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_name == "sonnet": + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + dataset_cls = SonnetDataset + sample_kwargs["prefix_len"] = args.prefix_len + sample_kwargs["return_prompt_formatted"] = True + elif args.dataset_name == "burstgpt": + dataset_cls = BurstGPTDataset + elif args.dataset_name == "hf": + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = VisionArenaDataset + common_kwargs['dataset_subset'] = None + 
common_kwargs['dataset_split'] = "train" + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = InstructCoderDataset + common_kwargs['dataset_split'] = "train" + elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = ConversationDataset + common_kwargs['dataset_subset'] = args.hf_subset + common_kwargs['dataset_split'] = args.hf_split + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_cls = AIMODataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + else: + raise ValueError(f"Unknown dataset name: {args.dataset_name}") + # Remove None values + sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} + return dataset_cls(**common_kwargs).sample(**sample_kwargs) + + +def validate_args(args): + """ + Validate command-line arguments. + """ + + # === Deprecation and Defaulting === + if args.dataset is not None: + warnings.warn( + "The '--dataset' argument will be deprecated in the next release. 
" + "Please use '--dataset-name' and '--dataset-path' instead.", + stacklevel=2) + args.dataset_path = args.dataset + + if not getattr(args, "tokenizer", None): + args.tokenizer = args.model + + # === Backend Validation === + valid_backends = {"vllm", "hf", "mii", "vllm-chat"} + if args.backend not in valid_backends: + raise ValueError(f"Unsupported backend: {args.backend}") + + # === Dataset Configuration === + if not args.dataset and not args.dataset_path: + print( + "When dataset path is not set, it will default to random dataset") + args.dataset_name = 'random' + if args.input_len is None: + raise ValueError("input_len must be provided for a random dataset") + + # === Dataset Name Specific Checks === + # --hf-subset and --hf-split: only used + # when dataset_name is 'hf' + if args.dataset_name != "hf" and ( + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None): + warnings.warn("--hf-subset and --hf-split will be ignored \ + since --dataset-name is not 'hf'.", + stacklevel=2) + elif args.dataset_name == "hf": + if args.dataset_path in ( + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 + elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + else: + raise ValueError( + f"{args.dataset_path} is not supported by hf dataset.") + + # --random-range-ratio: only used when dataset_name is 'random' + if args.dataset_name != 'random' and args.random_range_ratio is not None: + warnings.warn("--random-range-ratio will be ignored since \ + --dataset-name is not 'random'.", + stacklevel=2) + + # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not + # set. 
+ if args.dataset_name not in {"random", "sonnet", None + } and args.prefix_len is not None: + warnings.warn("--prefix-len will be ignored since --dataset-name\ + is not 'random', 'sonnet', or not set.", + stacklevel=2) + + # === LoRA Settings === + if getattr(args, "enable_lora", False) and args.backend != "vllm": + raise ValueError( + "LoRA benchmarking is only supported for vLLM backend") + if getattr(args, "enable_lora", False) and args.lora_path is None: + raise ValueError("LoRA path must be provided when enable_lora is True") + + # === Backend-specific Validations === + if args.backend == "hf" and args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend") + if args.backend != "hf" and args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + + if args.backend in {"hf", "mii"} and getattr(args, "quantization", + None) is not None: + raise ValueError("Quantization is only for vLLM backend.") + + if args.backend == "mii" and args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.backend == "mii" and args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.backend == "mii" and args.tokenizer != args.model: + raise ValueError( + "Tokenizer must be the same as the model for MII backend.") + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm") + parser.add_argument( + "--dataset-name", + type=str, + choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], + help="Name of the dataset to benchmark on.", + default="sharegpt") + parser.add_argument( + "--dataset", + type=str, + default=None, + help="Path to the ShareGPT dataset, will be deprecated in\ + the next release. 
The dataset is expected to " + "be a json in form of list[dict[..., conversations: " + "list[dict[..., value: ]]]]") + parser.add_argument("--dataset-path", + type=str, + default=None, + help="Path to the dataset") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)")) + # LoRA + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to the lora adapters to use. This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.") + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before the random " + "context in a request (default: 0).", + ) + # random dataset + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for RandomDataset. 
Must be in the range [0, 1) to define " + "a symmetric sampling range " + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + + # hf dtaset + parser.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + parser.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + +def main(args: argparse.Namespace): + if args.tokenizer is None: + args.tokenizer = args.model + validate_args(args) + if args.seed is None: + args.seed = 0 + print(args) + random.seed(args.seed) + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + requests = get_requests(args, tokenizer) + is_multi_modal = any(request.multi_modal_data is not None + for request in requests) + request_outputs: Optional[list[RequestOutput]] = None + if args.backend == "vllm": + if args.async_engine: + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + args.disable_detokenize, + )) + else: + elapsed_time, request_outputs = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.hf_max_batch_size, args.trust_remote_code, + args.disable_detokenize) + elif args.backend == "vllm-chat": + elapsed_time, request_outputs = run_vllm_chat( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + else: + raise ValueError(f"Unknown backend: {args.backend}") + + if request_outputs: + # Note: with the vllm and vllm-chat backends, + # we have request_outputs, which we use to count tokens. 
+ total_prompt_tokens = 0 + total_output_tokens = 0 + for ro in request_outputs: + if not isinstance(ro, RequestOutput): + continue + total_prompt_tokens += len( + ro.prompt_token_ids) if ro.prompt_token_ids else 0 + total_output_tokens += sum( + len(o.token_ids) for o in ro.outputs if o) + total_num_tokens = total_prompt_tokens + total_output_tokens + else: + total_num_tokens = sum(r.prompt_len + r.expected_output_len + for r in requests) + total_output_tokens = sum(r.expected_output_len for r in requests) + total_prompt_tokens = total_num_tokens - total_output_tokens + + if is_multi_modal and args.backend != "vllm-chat": + print("\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details.") + # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. + # vllm-chat backend counts the image tokens now + + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print(f"Total num prompt tokens: {total_prompt_tokens}") + print(f"Total num output tokens: {total_output_tokens}") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/collect_env.py b/vllm/collect_env.py similarity index 96% rename from collect_env.py rename to vllm/collect_env.py index e11271a13640ac4208299f60cdeb137a5508a1d7..9cfceb7c45cc53f49729aab7951b57929acdfb6f 100644 --- a/collect_env.py +++ b/vllm/collect_env.py @@ 
-282,13 +282,21 @@ def get_vllm_version(): if __version__ == "dev": return "N/A (dev)" - - if len(__version_tuple__) == 4: # dev build - git_sha = __version_tuple__[-1][1:] # type: ignore - return f"{__version__} (git sha: {git_sha}" - + version_str = __version_tuple__[-1] + if isinstance(version_str, str) and version_str.startswith('g'): + # it's a dev build + if '.' in version_str: + # it's a dev build containing local changes + git_sha = version_str.split('.')[0][1:] + date = version_str.split('.')[-1][1:] + return f"{__version__} (git sha: {git_sha}, date: {date})" + else: + # it's a dev build without local changes + git_sha = version_str[1:] # type: ignore + return f"{__version__} (git sha: {git_sha})" return __version__ + def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( @@ -502,7 +510,9 @@ def get_pip_packages(run_lambda, patterns=None): print("uv is set") cmd = ["uv", "pip", "list", "--format=freeze"] else: - raise RuntimeError("Could not collect pip list output (pip or uv module not available)") + raise RuntimeError( + "Could not collect pip list output (pip or uv module not available)" + ) out = run_and_read_all(run_lambda, cmd) return "\n".join(line for line in out.splitlines() @@ -535,13 +545,12 @@ def is_xnnpack_available(): else: return "N/A" + def get_env_vars(): env_vars = '' - secret_terms=('secret', 'token', 'api', 'access', 'password') - report_prefix = ("TORCH", "NCCL", "PYTORCH", - "CUDA", "CUBLAS", "CUDNN", - "OMP_", "MKL_", - "NVIDIA") + secret_terms = ('secret', 'token', 'api', 'access', 'password') + report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN", + "OMP_", "MKL_", "NVIDIA") for k, v in os.environ.items(): if any(term in k.lower() for term in secret_terms): continue @@ -552,6 +561,7 @@ def get_env_vars(): return env_vars + def get_env_info(): run_lambda = run 
pip_version, pip_list_output = get_pip_packages(run_lambda) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 45988c2e9b0d462b83c61ba491293b86826b25f5..a1d12b5175504afe8e4e8eea5114305e9d79fa86 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -110,10 +110,14 @@ class CompilerManager: compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) if compiled_graph is not None: - if graph_index == 0: - # adds some info logging for the first graph - logger.info("Directly load the compiled graph for shape %s " - "from the cache", str(runtime_shape)) # noqa + if graph_index == num_graphs - 1: + # after loading the last graph for this shape, record the time. + # there can be multiple graphs due to piecewise compilation. + now = time.time() + elapsed = now - compilation_start_time + logger.info( + "Directly load the compiled graph(s) for shape %s " + "from the cache, took %.3f s", str(runtime_shape), elapsed) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -335,7 +339,7 @@ class VllmBackend: def configure_post_pass(self): config = self.compilation_config - self.post_grad_pass_manager.configure(config.pass_config) + self.post_grad_pass_manager.configure(self.vllm_config) # Post-grad custom passes are run using the post_grad_custom_post_pass # hook. If a pass for that hook exists, add it to the pass manager. 
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 6c8875916efc3417274d0b506aaae32656d41f7b..c5454ccdcbf7e1abf7c2d96cbfd2557a2dcd0d71 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -11,9 +11,12 @@ import torch import torch._inductor.compile_fx import torch.fx as fx +import vllm.envs as envs from vllm.config import VllmConfig from vllm.utils import is_torch_equal_or_newer +from .inductor_pass import pass_context + class CompilerInterface: """ @@ -167,8 +170,7 @@ class InductorAdaptor(CompilerInterface): compiler_config: Dict[str, Any], runtime_shape: Optional[int] = None ) -> Tuple[Optional[Callable], Optional[Any]]: - from torch._inductor import config - current_config = config.get_config_copy() + current_config = {} from torch._inductor.compile_fx import compile_fx # disable remote cache @@ -196,7 +198,6 @@ class InductorAdaptor(CompilerInterface): hash_str, file_path = None, None from torch._inductor.codecache import (FxGraphCache, compiled_fx_graph_hash) - if torch.__version__.startswith("2.5"): original_load = FxGraphCache.load original_load_name = "torch._inductor.codecache.FxGraphCache.load" @@ -281,6 +282,16 @@ class InductorAdaptor(CompilerInterface): patch("torch._inductor.codecache.FxGraphCache._get_shape_env", _get_shape_env)) + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) + + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + _get_shape_env)) + # for forcing the graph to be cached stack.enter_context( patch( @@ -290,16 +301,34 @@ class InductorAdaptor(CompilerInterface): # Dynamo metrics context, see method for more details. 
stack.enter_context(self.metrics_context()) - compiled_graph = compile_fx( - graph, - example_inputs, - inner_compile=hijacked_compile_fx_inner, - config_patches=current_config) - - assert hash_str is not None, ( - "failed to get the hash of the compiled graph") - assert file_path is not None, ( - "failed to get the file path of the compiled graph") + # Disable remote caching. When these are on, on remote cache-hit, + # the monkey-patched functions never actually get called. + # vLLM today assumes and requires the monkey-patched functions to + # get hit. + # TODO(zou3519): we're going to replace this all with + # standalone_compile sometime. + if is_torch_equal_or_newer("2.6"): + stack.enter_context( + torch._inductor.config.patch(fx_graph_remote_cache=False)) + stack.enter_context( + torch._functorch.config.patch( + enable_remote_autograd_cache=False)) + + with pass_context(runtime_shape): + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config) + + # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch + # compilation cache. So turn off the checks if we disable the + # compilation cache. 
+ if not envs.VLLM_DISABLE_COMPILE_CACHE: + assert hash_str is not None, ( + "failed to get the hash of the compiled graph") + assert file_path is not None, ( + "failed to get the file path of the compiled graph") return compiled_graph, (hash_str, file_path) def load(self, @@ -313,11 +342,19 @@ class InductorAdaptor(CompilerInterface): assert isinstance(handle[1], str) hash_str = handle[0] + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) from torch._inductor.codecache import FxGraphCache with ExitStack() as exit_stack: exit_stack.enter_context( patch("torch._inductor.codecache.FxGraphCache._get_shape_env", lambda *args, **kwargs: AlwaysHitShapeEnv())) + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + exit_stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv())) # Dynamo metrics context, see method for more details. exit_stack.enter_context(self.metrics_context()) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index bed2c465fc7fd5cad1f19ffa75d5d005da094e05..6eaf20821c58549ea5313151ce7477343f5c0210 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import CompilationConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass): _instance: 'Optional[FusionPass]' = None @classmethod - def instance(cls, config: CompilationConfig.PassConfig): + def instance(cls, config: VllmConfig): """ Get the singleton instance of the FusionPass. 
If the instance exists, the config is updated but @@ -540,10 +540,10 @@ class FusionPass(VllmInductorPass): if cls._instance is None: cls._instance = FusionPass(config) else: - cls._instance.config = config + cls._instance.pass_config = config.compilation_config.pass_config return cls._instance - def __init__(self, config: CompilationConfig.PassConfig): + def __init__(self, config: VllmConfig): assert self.__class__._instance is None, \ "FusionPass singleton instance already exists" super().__init__(config) diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index b9a8d3112e7758fe71756699b57542bb31c34954..f9427e48ac315db81d7d6e809900331898e5d609 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -12,6 +12,22 @@ def is_func(node: fx.Node, target) -> bool: return node.op == "call_function" and node.target == target +# Returns the first specified node with the given op (if it exists) +def find_specified_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if node.target == op: + return node + return None + + +# Returns the first specified node with the given op +def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_specified_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + # Returns the first auto_functionalized node with the given op (if it exists) def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]: diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 00a2e89f21aebd089ba9170f337af10e14554400..6cd7720fca2f91456ec2c0ac308aee69dea5a859 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -4,6 +4,7 @@ import hashlib import inspect import json import types +from contextlib import contextmanager from typing import Any, Callable, Dict, Optional, Union import torch @@ -18,6 +19,34 @@ 
else: from .torch25_custom_graph_pass import ( # noqa: yapf Torch25CustomGraphPass as CustomGraphPass) +_pass_context = None + + +class PassContext: + + def __init__(self, runtime_shape: Optional[int]): + self.runtime_shape = runtime_shape + + +def get_pass_context() -> PassContext: + """Get the current pass context.""" + assert _pass_context is not None + return _pass_context + + +@contextmanager +def pass_context(runtime_shape: Optional[int]): + """A context manager that stores the current pass context, + usually it is a list of sizes to specialize. + """ + global _pass_context + prev_context = _pass_context + _pass_context = PassContext(runtime_shape) + try: + yield + finally: + _pass_context = prev_context + class InductorPass(CustomGraphPass): """ @@ -62,6 +91,9 @@ class InductorPass(CustomGraphPass): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() + def is_applicable_for_shape(self, shape: Optional[int]): + return True + class CallableInductorPass(InductorPass): """ diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 530a88b2b09aeaf1805f010b980477facda92e10..f8e8c4971cbb6c5e3b34cd27cfe30eca930c21fc 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -4,13 +4,15 @@ from typing import List from torch import fx as fx -from vllm.config import CompilationConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass -from .inductor_pass import CustomGraphPass, InductorPass +from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass +from .sequence_parallelism import SequenceParallelismPass +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) @@ -31,24 +33,29 @@ class PostGradPassManager(CustomGraphPass): """ def __init__(self): - self.passes: 
List[InductorPass] = [] + self.passes: List[VllmInductorPass] = [] def __call__(self, graph: fx.Graph): + shape = get_pass_context().runtime_shape for pass_ in self.passes: - pass_(graph) + if pass_.is_applicable_for_shape(shape): + pass_(graph) # always run fix_functionalization last self.fix_functionalization(graph) - def configure(self, pass_config: CompilationConfig.PassConfig): - self.pass_config = pass_config - if pass_config.enable_noop: - self.passes += [NoOpEliminationPass(pass_config)] + def configure(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + if self.pass_config.enable_noop: + self.passes += [NoOpEliminationPass(config)] - if pass_config.enable_fusion: - self.passes += [FusionPass.instance(pass_config)] + if self.pass_config.enable_fusion: + self.passes += [FusionPass.instance(config)] - self.fix_functionalization = FixFunctionalizationPass(pass_config) + if self.pass_config.enable_sequence_parallelism: + self.passes += [SequenceParallelismPass(config)] + + self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py new file mode 100644 index 0000000000000000000000000000000000000000..95db63d34f7eab4befe5335f0ec624c1ba85deb3 --- /dev/null +++ b/vllm/compilation/sequence_parallelism.py @@ -0,0 +1,266 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.config import VllmConfig +from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) 
+ + +class AllReduceRMSNormPattern: + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + self.epsilon = epsilon + self.dtype = dtype + self.device = device + + +class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], + device=self.device, + dtype=torch.long) + unsqueeze = torch.rand([1, 8, 1], device=self.device, \ + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], device=self.device, \ + dtype=self.dtype) + permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) + arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) + + return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + all_reduce = tensor_model_parallel_all_reduce(where) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=permute, + input=all_reduce, + weight=arg3_1, + epsilon=self.epsilon, + ) + + return rmsnorm[1], all_reduce + + def replacement( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + where, dim=0, world_size=tp_size, group_name=tp.unique_name) + + rmsnorm_result = 
torch.empty_like(reduce_scatter) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=rmsnorm_result, + input=reduce_scatter, + weight=arg3_1, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + return rmsnorm[1], rmsnorm[2] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + # TODO is it possible to extract epsilon from somewhere + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + 
world_size=tp_size, + group_name=tp.unique_name) + return all_gather, rmsnorm[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + return rmsnorm[1] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + # TODO is it possible to extract epsilon from somewhere + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + normalized = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class SequenceParallelismPass(VllmInductorPass): + + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: 
PatternMatcherPass = PatternMatcherPass( + pass_name="sequence_parallelism_pass") + for epsilon in [1e-5, 1e-6]: + EmbeddingAllReduceRMSNormPattern( + epsilon, self.dtype, self.device).register(self.patterns) + + MiddleAllReduceRMSNormPattern(epsilon, self.dtype, + self.device).register(self.patterns) + + LastAllReduceRMSNormPattern(epsilon, self.dtype, + self.device).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + # only do replace for specific shapes + tp_size = get_tensor_model_parallel_world_size() + return shape is not None and shape % tp_size == 0 + + def __call__(self, graph: fx.Graph): + self.dump_graph(graph, "before_sequence_parallelism_pass") + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_sequence_parallelism_pass") diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 98ed6f1472a4570801e1b0b96420a3bd772d3102..e8bffb406f148bb09134a5f1f44e01d5ded1603e 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -4,7 +4,7 @@ import time import torch -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, VllmConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( @@ -24,16 +24,19 @@ class VllmInductorPass(InductorPass): It provides timing, logging, and dumping utilities. 
""" - def __init__(self, config: CompilationConfig.PassConfig): - self.config = config + def __init__(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + self.dtype = config.model_config.dtype if config.model_config else None + self.device = config.device_config.device if config.device_config \ + else None self.pass_name = self.__class__.__name__ def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): - if stage in self.config.dump_graph_stages or always: + if stage in self.pass_config.dump_graph_stages or always: # Make sure filename includes rank in the distributed setting parallel = p_is_init() and get_tp_world_size() > 1 rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" + filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py" logger.info("%s printing graph to %s", self.pass_name, filepath) with open(filepath, "w") as f: diff --git a/vllm/config.py b/vllm/config.py index 39859c965d2b57ad99d3fa5c1eb0502bf4bc8c3b..ac311743ae0da791498eb78ad35f7803d6b9aaaf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -6,18 +6,18 @@ import enum import hashlib import inspect import json +import re import sys import textwrap import warnings from collections import Counter -from collections.abc import Mapping from contextlib import contextmanager from dataclasses import (MISSING, dataclass, field, fields, is_dataclass, replace) from importlib.util import find_spec from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, - Optional, Protocol, TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Callable, List, ClassVar, Final, Literal, + Optional, Protocol, TypeVar, Union, get_args) import torch from pydantic import BaseModel, Field, PrivateAttr @@ -28,6 +28,7 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.logger import init_logger from 
vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + QuantizationMethods, get_quantization_config) from vllm.model_executor.models import ModelRegistry from vllm.platforms import CpuArchEnum, current_platform @@ -52,16 +53,16 @@ if TYPE_CHECKING: from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.loader import BaseModelLoader - from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) - Config = TypeVar("Config", bound=DataclassInstance) + ConfigType = type[DataclassInstance] else: QuantizationConfig = None - Config = TypeVar("Config") + ConfigType = type logger = init_logger(__name__) +ConfigT = TypeVar("ConfigT", bound=ConfigType) + # This value is chosen to have a balance between ITL and TTFT. Note it is # not optimized for throughput. _DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 @@ -121,7 +122,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: def pairwise(iterable): """ Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise - + Can be removed when Python 3.9 support is dropped. """ iterator = iter(iterable) @@ -163,7 +164,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: return out -def config(cls: type[Config]) -> type[Config]: +def config(cls: ConfigT) -> ConfigT: """ A decorator that ensures all fields in a dataclass have default values and that each field has a docstring. @@ -182,6 +183,23 @@ def config(cls: type[Config]) -> type[Config]: return cls +def get_field(cls: ConfigType, name: str) -> Field: + """Get the default factory field of a dataclass by name. 
Used for getting + default factory fields in `EngineArgs`.""" + if not is_dataclass(cls): + raise TypeError("The given class is not a dataclass.") + cls_fields = {f.name: f for f in fields(cls)} + if name not in cls_fields: + raise ValueError(f"Field '{name}' not found in {cls.__name__}.") + named_field: Field = cls_fields.get(name) + if (default_factory := named_field.default_factory) is not MISSING: + return field(default_factory=default_factory) + if (default := named_field.default) is not MISSING: + return field(default=default) + raise ValueError( + f"{cls.__name__}.{name} must have a default value or default factory.") + + class ModelConfig: """Configuration for the model. @@ -250,7 +268,7 @@ class ModelConfig: config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. hf_token: The token to use as HTTP bearer authorization for remote files - . If `True`, will use the token generated when running + . If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). hf_overrides: If a dictionary, contains arguments to be forwarded to the HuggingFace config. If a callable, it is called to update the @@ -298,12 +316,20 @@ class ModelConfig: factors.append(self.quantization) factors.append(self.revision) factors.append(self.code_revision) + factors.append(self.max_model_len) + factors.append(self.max_logprobs) + factors.append(self.disable_sliding_window) factors.append(self.trust_remote_code) + factors.append(self.mm_processor_kwargs) + factors.append(self.generation_config) + factors.append(self.model_impl) + factors.append(self.override_generation_config) factors.append(self.rope_scaling) factors.append(self.rope_theta) - # rope cos/sin cache depends on the max_position_embeddings - factors.append( - getattr(self.hf_config, "max_position_embeddings", "None")) + # hf_config can control how the model looks! 
+ factors.append(self.hf_config.to_json_string()) + str_factors = str(factors) + assert_hashable(str_factors) return hashlib.sha256(str(factors).encode()).hexdigest() def __init__( @@ -332,7 +358,7 @@ class ModelConfig: disable_cascade_attn: bool = False, skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, list[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + limit_mm_per_prompt: Optional[dict[str, int]] = None, use_async_output_proc: bool = True, config_format: ConfigFormat = ConfigFormat.AUTO, hf_token: Optional[Union[bool, str]] = None, @@ -417,8 +443,10 @@ class ModelConfig: from vllm.platforms import current_platform - if self.enable_sleep_mode and not current_platform.is_cuda(): - raise ValueError("Sleep mode is only supported on CUDA devices.") + if (self.enable_sleep_mode + and not current_platform.is_sleep_mode_available()): + raise ValueError( + "Sleep mode is not supported on current platform.") hf_config = get_config(self.hf_config_path or self.model, trust_remote_code, revision, code_revision, @@ -553,7 +581,7 @@ class ModelConfig: self.tokenizer = s3_tokenizer.dir def _init_multimodal_config( - self, limit_mm_per_prompt: Optional[Mapping[str, int]] + self, limit_mm_per_prompt: Optional[dict[str, int]] ) -> Optional["MultiModalConfig"]: if self.registry.is_multimodal_model(self.architectures): return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) @@ -725,8 +753,8 @@ class ModelConfig: supported_quantization = QUANTIZATION_METHODS optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", - "awq_marlin", "fbgemm_fp8", "compressed_tensors", - "compressed-tensors", "experts_int8", "quark", "nvfp4" + "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", + "quark", "nvfp4", "bitblas", "gptq_bitblas" ] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -736,13 +764,47 @@ class ModelConfig: if quant_cfg is not None: 
quant_method = quant_cfg.get("quant_method", "").lower() + quant_method = quant_method.replace("compressed_tensors", + "compressed-tensors") + quant_cfg["quant_method"] = quant_method + + # Quantization methods which are overrides (i.e. they have a + # `override_quantization_method` method) must be checked in order + # of preference (this is particularly important for GPTQ). + overrides = [ + "marlin", + "bitblas", + "gptq_marlin_24", + "gptq_marlin", + "gptq_bitblas", + "awq_marlin", + "ipex", + "moe_wna16", + ] + quantization_methods = [ + q for q in supported_quantization if q not in overrides + ] + # Any custom overrides will be in quantization_methods so we place + # them at the start of the list so custom overrides have preference + # over the built in ones. + quantization_methods = quantization_methods + overrides # Detect which checkpoint is it - for name in QUANTIZATION_METHODS: + for name in quantization_methods: method = get_quantization_config(name) quantization_override = method.override_quantization_method( quant_cfg, self.quantization) - if quantization_override: + if quantization_override is not None: + # Raise error if the override is not custom (custom would + # be in QUANTIZATION_METHODS but not QuantizationMethods) + # and hasn't been added to the overrides list. + if (name in get_args(QuantizationMethods) + and name not in overrides): + raise ValueError( + f"Quantization method {name} is an override but " + "is has not been added to the `overrides` list " + "above. 
This is necessary to ensure that the " + "overrides are checked in order of preference.") quant_method = quantization_override self.quantization = quantization_override break @@ -1220,23 +1282,78 @@ class ModelConfig: return (hasattr(self.hf_config, "matryoshka_dimensions") or getattr(self.hf_config, "is_matryoshka", False)) + @property + def matryoshka_dimensions(self): + return getattr(self.hf_config, "matryoshka_dimensions", None) -class CacheConfig: - """Configuration for the KV cache. - Args: - block_size: Size of a cache block in number of tokens. - gpu_memory_utilization: Fraction of GPU memory to use for the - vLLM execution. - swap_space: Size of the CPU swap space per GPU (in GiB). - cache_dtype: Data type for kv cache storage. - is_attention_free: Whether the model is attention-free. - num_gpu_blocks_override: Number of GPU blocks to use. This overrides the - profiled num_gpu_blocks if specified. Does nothing if None. - sliding_window: Sliding window size for the KV cache. - enable_prefix_caching: Whether to enable prefix caching. - cpu_offload_gb: Size of the CPU offload buffer in GiB. +BlockSize = Literal[1, 8, 16, 32, 64, 128] +CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] +PrefixCachingHashAlgo = Literal["builtin", "sha256"] + + +@config +@dataclass +class CacheConfig: + """Configuration for the KV cache.""" + + block_size: BlockSize = None # type: ignore + """Size of a contiguous cache block in number of tokens. This is ignored on + neuron devices and set to `--max-model-len`. On CUDA devices, only block + sizes up to 32 are supported. On HPU devices, block size defaults to 128. + + This config has no static default. If left unspecified by the user, it will + be set in `Platform.check_and_update_configs()` based on the current + platform.""" + gpu_memory_utilization: float = 0.9 + """The fraction of GPU memory to be used for the model executor, which can + range from 0 to 1. 
For example, a value of 0.5 would imply 50% GPU memory + utilization. If unspecified, will use the default value of 0.9. This is a + per-instance limit, and only applies to the current vLLM instance. It does + not matter if you have another vLLM instance running on the same GPU. For + example, if you have two vLLM instances running on the same GPU, you can + set the GPU memory utilization to 0.5 for each instance.""" + swap_space: float = 4 + """Size of the CPU swap space per GPU (in GiB).""" + cache_dtype: CacheDType = "auto" + """Data type for kv cache storage. If "auto", will use model data type. + CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports + fp8 (=fp8_e4m3).""" + is_attention_free: bool = False + """Whether the model is attention-free. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + num_gpu_blocks_override: Optional[int] = None + """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` + if specified. Does nothing if `None`. Used for testing preemption.""" + sliding_window: Optional[int] = None + """Sliding window size for the KV cache. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + enable_prefix_caching: Optional[bool] = None + """Whether to enable prefix caching. Disabled by default for V0. Enabled by + default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Set the hash algorithm for prefix caching:\n + - "builtin" is Python's built-in hash.\n + - "sha256" is collision resistant but with certain overheads.""" + cpu_offload_gb: float = 0 + """The space in GiB to offload to CPU, per GPU. Default is 0, which means + no offloading. Intuitively, this argument can be seen as a virtual way to + increase the GPU memory size. For example, if you have one 24 GB GPU and + set this to 10, virtually you can think of it as a 34 GB GPU. 
Then you can + load a 13B model with BF16 weight, which requires at least 26GB GPU memory. + Note that this requires fast CPU-GPU interconnect, as part of the model is + loaded from CPU memory to GPU memory on the fly in each model forward pass. """ + calculate_kv_scales: bool = False + """This enables dynamic calculation of `k_scale` and `v_scale` when + kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model + checkpoint if available. Otherwise, the scales will default to 1.0.""" + + # Will be set after profiling. + num_gpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for GPU memory.""" + num_cpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for CPU memory.""" def compute_hash(self) -> str: """ @@ -1257,43 +1374,13 @@ class CacheConfig: usedforsecurity=False).hexdigest() return hash_str - def __init__( - self, - block_size: int, - gpu_memory_utilization: float, - swap_space: float, - cache_dtype: str, - is_attention_free: bool = False, - num_gpu_blocks_override: Optional[int] = None, - sliding_window: Optional[int] = None, - enable_prefix_caching: bool = False, - prefix_caching_hash_algo: str = "builtin", - cpu_offload_gb: float = 0, - calculate_kv_scales: Optional[bool] = None, - ) -> None: - self.block_size = block_size - self.gpu_memory_utilization = gpu_memory_utilization - self.swap_space_bytes = swap_space * GiB_bytes - self.num_gpu_blocks_override = num_gpu_blocks_override - self.cache_dtype = cache_dtype - self.is_attention_free = is_attention_free - self.sliding_window = sliding_window - self.enable_prefix_caching = enable_prefix_caching - self.prefix_caching_hash_algo = prefix_caching_hash_algo - self.cpu_offload_gb = cpu_offload_gb - self.calculate_kv_scales = calculate_kv_scales + def __post_init__(self) -> None: + self.swap_space_bytes = self.swap_space * GiB_bytes + self._verify_args() self._verify_cache_dtype() 
self._verify_prefix_caching() - # Will be set after profiling. - self.num_gpu_blocks: Optional[int] = None - self.num_cpu_blocks: Optional[int] = None - - # Set calculate_kv_scales to False if the value is unset. - if self.calculate_kv_scales is None: - self.calculate_kv_scales = False - def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info @@ -1312,7 +1399,7 @@ class CacheConfig: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + elif self.cache_dtype in get_args(CacheDType): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " @@ -1330,12 +1417,12 @@ class CacheConfig: "Prefix caching is not supported with sliding window. " "Run with --disable-sliding-window to use prefix caching.") - if self.enable_prefix_caching and self.prefix_caching_hash_algo not in ( - "builtin", "sha256"): + if (self.enable_prefix_caching and self.prefix_caching_hash_algo + not in get_args(PrefixCachingHashAlgo)): raise ValueError( "Unknown prefix caching hash algorithm: " - f"{self.prefix_caching_hash_algo}. Must be either " - "'builtin' or 'sha256'.") + f"{self.prefix_caching_hash_algo}. Must be one of " + f"{get_args(PrefixCachingHashAlgo)}.") def verify_with_parallel_config( self, @@ -1356,77 +1443,33 @@ class CacheConfig: logger.warning("Possibly too large swap space. %s", msg) +@config @dataclass class TokenizerPoolConfig: - """Configuration for the tokenizer pool. + """This config is deprecated and will be removed in a future release. - Args: - pool_size: Number of tokenizer workers in the pool. - pool_type: Type of the pool. - extra_config: Additional config for the pool. - The way the config will be used depends on the - pool type. + Passing these parameters will have no effect. Please remove them from your + configurations. 
""" - pool_size: int - pool_type: Union[str, type["BaseTokenizerGroup"]] - extra_config: dict - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. + pool_size: int = 0 + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + pool_type: str = "ray" + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + extra_config: dict = field(default_factory=dict) + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - if self.pool_type not in ("ray", ) and not isinstance( - self.pool_type, type): - raise ValueError(f"Unknown pool type: {self.pool_type}") - if not isinstance(self.extra_config, dict): - raise ValueError("extra_config must be a dictionary.") - - @classmethod - def create_config( - cls, tokenizer_pool_size: int, - tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]], - tokenizer_pool_extra_config: Optional[Union[str, dict]] - ) -> Optional["TokenizerPoolConfig"]: - """Create a TokenizerPoolConfig from the given parameters. - - If tokenizer_pool_size is 0, return None. 
- - Args: - tokenizer_pool_size: Number of tokenizer workers in the pool. - tokenizer_pool_type: Type of the pool. - tokenizer_pool_extra_config: Additional config for the pool. - The way the config will be used depends on the - pool type. This can be a JSON string (will be parsed). - """ - if tokenizer_pool_size: - if isinstance(tokenizer_pool_extra_config, str): - tokenizer_pool_extra_config_parsed = json.loads( - tokenizer_pool_extra_config) - else: - tokenizer_pool_extra_config_parsed = ( - tokenizer_pool_extra_config or {}) - tokenizer_pool_config = cls(tokenizer_pool_size, - tokenizer_pool_type, - tokenizer_pool_extra_config_parsed) - else: - tokenizer_pool_config = None - return tokenizer_pool_config + def __post_init__(self) -> None: + logger.warning_once( + "TokenizerPoolConfig is deprecated and will be removed in a " + "future release. Passing this parameter will have no effect. " + "Please remove it from your configurations.") class LoadFormat(str, enum.Enum): @@ -1441,6 +1484,7 @@ class LoadFormat(str, enum.Enum): BITSANDBYTES = "bitsandbytes" MISTRAL = "mistral" RUNAI_STREAMER = "runai_streamer" + RUNAI_STREAMER_SHARDED = "runai_streamer_sharded" FASTSAFETENSORS = "fastsafetensors" @@ -1475,7 +1519,7 @@ class LoadConfig: download_dir: Optional[str] = None """Directory to download and load the weights, default to the default cache directory of Hugging Face.""" - model_loader_extra_config: Optional[Union[str, dict]] = None + model_loader_extra_config: dict = field(default_factory=dict) """Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format. 
This should be a JSON string that will be parsed into a dictionary.""" @@ -1506,10 +1550,6 @@ class LoadConfig: return hash_str def __post_init__(self): - model_loader_extra_config = self.model_loader_extra_config or {} - if isinstance(model_loader_extra_config, str): - self.model_loader_extra_config = json.loads( - model_loader_extra_config) if isinstance(self.load_format, str): load_format = self.load_format.lower() self.load_format = LoadFormat(load_format) @@ -1522,6 +1562,9 @@ class LoadConfig: self.ignore_patterns = ["original/**/*"] +DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] + + @config @dataclass class ParallelConfig: @@ -1536,8 +1579,21 @@ class ParallelConfig: the product of the tensor parallel size and data parallel size.""" data_parallel_rank: int = 0 """Rank of the data parallel group.""" - data_parallel_rank_local: Optional[int] = None - """Local rank of the data parallel group, defaults to global rank.""" + _data_parallel_rank_local: Optional[int] = field(default=None, init=False) + """Private field to store the local rank of the data parallel group.""" + + @property + def data_parallel_rank_local(self) -> int: + """Local rank of the data parallel group, defaults to global rank.""" + if self._data_parallel_rank_local is None: + return self.data_parallel_rank + return self._data_parallel_rank_local + + @data_parallel_rank_local.setter + def data_parallel_rank_local(self, value: int) -> None: + """Set the local rank of the data parallel group.""" + self._data_parallel_rank_local = value + data_parallel_master_ip: str = "127.0.0.1" """IP of the data parallel master.""" data_parallel_master_port: int = 29500 @@ -1554,8 +1610,8 @@ class ParallelConfig: """Disable the custom all-reduce kernel and fall back to NCCL.""" tokenizer_pool_config: Optional[TokenizerPoolConfig] = None - """Config for the tokenizer pool. 
If None, will use synchronous - tokenization.""" + """This parameter is deprecated and will be removed in a future release. + Please remove it from your configs""" ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" @@ -1563,7 +1619,7 @@ class ParallelConfig: placement_group: Optional["PlacementGroup"] = None """ray distributed model workers placement group.""" - distributed_executor_backend: Optional[Union[str, + distributed_executor_backend: Optional[Union[DistributedExecutorBackend, type["ExecutorBase"]]] = None """Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If the product @@ -1577,7 +1633,7 @@ class ParallelConfig: """The full name of the worker class to use. If "auto", the worker class will be determined based on the platform.""" sd_worker_cls: str = "auto" - """The full name of the worker class to use for speculative decofing. + """The full name of the worker class to use for speculative decofing. If "auto", the worker class will be determined based on the platform.""" worker_extension_cls: str = "" """The full name of the worker extension class to use. The worker extension @@ -1646,6 +1702,7 @@ class ParallelConfig: factors: list[Any] = [] factors.append(self.pipeline_parallel_size) factors.append(self.tensor_parallel_size) + factors.append(self.enable_expert_parallel) return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: @@ -1687,7 +1744,7 @@ class ParallelConfig: # current node and we aren't in a ray placement group. 
from vllm.executor import ray_utils - backend = "mp" + backend: DistributedExecutorBackend = "mp" ray_found = ray_utils.ray_is_available() if current_platform.is_neuron(): # neuron uses single process to control multiple devices @@ -1754,92 +1811,125 @@ class ParallelConfig: "worker_extension_cls must be a string (qualified class name).") +PreemptionMode = Literal["swap", "recompute"] +SchedulerPolicy = Literal["fcfs", "priority"] + + +@config @dataclass class SchedulerConfig: """Scheduler configuration.""" - runner_type: str = "generate" # The runner type to launch for the model. + runner_type: RunnerType = "generate" + """The runner type to launch for the model.""" + + max_num_batched_tokens: int = None # type: ignore + """Maximum number of tokens to be processed in a single iteration. + + This config has no static default. If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" - # Maximum number of tokens to be processed in a single iteration. - max_num_batched_tokens: int = field(default=None) # type: ignore + max_num_seqs: int = None # type: ignore + """Maximum number of sequences to be processed in a single iteration. - # Maximum number of sequences to be processed in a single iteration. - max_num_seqs: int = 128 + This config has no static default. If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" - # Maximum length of a sequence (including prompt and generated text). - max_model_len: int = 8192 + max_model_len: int = None # type: ignore + """Maximum length of a sequence (including prompt and generated text). 
This + is primarily set in `ModelConfig` and that value should be manually + duplicated here.""" - # Maximum number of sequences that can be partially prefilled concurrently max_num_partial_prefills: int = 1 + """For chunked prefill, the maximum number of sequences that can be + partially prefilled concurrently.""" - # Maximum number of "very long prompt" sequences that can be prefilled - # concurrently (long is defined by long_prefill_threshold) max_long_partial_prefills: int = 1 + """For chunked prefill, the maximum number of prompts longer than + long_prefill_token_threshold that will be prefilled concurrently. Setting + this less than max_num_partial_prefills will allow shorter prompts to jump + the queue in front of longer prompts in some cases, improving latency.""" - # calculate context length that determines which sequences are - # considered "long" long_prefill_token_threshold: int = 0 + """For chunked prefill, a request is considered long if the prompt is + longer than this number of tokens.""" - # The number of slots to allocate per sequence per - # step, beyond the known token ids. This is used in speculative - # decoding to store KV activations of tokens which may or may not be - # accepted. num_lookahead_slots: int = 0 + """The number of slots to allocate per sequence per + step, beyond the known token ids. This is used in speculative + decoding to store KV activations of tokens which may or may not be + accepted. + + NOTE: This will be replaced by speculative config in the future; it is + present to enable correctness tests until then.""" - # Apply a delay (of delay factor multiplied by previous - # prompt latency) before scheduling next prompt. delay_factor: float = 0.0 + """Apply a delay (of delay factor multiplied by previous + prompt latency) before scheduling next prompt.""" - # If True, prefill requests can be chunked based - # on the remaining max_num_batched_tokens. 
- enable_chunked_prefill: bool = False + enable_chunked_prefill: bool = None # type: ignore + """If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens.""" is_multimodal_model: bool = False + """True if the model is multimodal.""" + + # TODO (ywang96): Make this configurable. + max_num_encoder_input_tokens: int = field(init=False) + """Multimodal encoder compute budget, only used in V1. + + NOTE: This is not currently configurable. It will be overridden by + max_num_batched_tokens in case max multimodal embedding size is larger.""" - # NOTE: The following multimodal encoder budget will be initialized to - # max_num_batched_tokens and overridden in case max multimodal embedding - # size is larger. - # TODO (ywang96): Make these configurable. - # Multimodal encoder compute budget, only used in V1 - max_num_encoder_input_tokens: int = field(default=None) # type: ignore + # TODO (ywang96): Make this configurable. + encoder_cache_size: int = field(init=False) + """Multimodal encoder cache size, only used in V1. - # Multimodal encoder cache size, only used in V1 - encoder_cache_size: int = field(default=None) # type: ignore + NOTE: This is not currently configurable. It will be overridden by + max_num_batched_tokens in case max multimodal embedding size is larger.""" - # Whether to perform preemption by swapping or - # recomputation. If not specified, we determine the mode as follows: - # We use recomputation by default since it incurs lower overhead than - # swapping. However, when the sequence group has multiple sequences - # (e.g., beam search), recomputation is not currently supported. In - # such a case, we use swapping instead. - preemption_mode: Optional[str] = None + preemption_mode: Optional[PreemptionMode] = None + """Whether to perform preemption by swapping or + recomputation. If not specified, we determine the mode as follows: + We use recomputation by default since it incurs lower overhead than + swapping. 
However, when the sequence group has multiple sequences + (e.g., beam search), recomputation is not currently supported. In + such a case, we use swapping instead.""" num_scheduler_steps: int = 1 + """Maximum number of forward steps per scheduler call.""" - multi_step_stream_outputs: bool = False + multi_step_stream_outputs: bool = True + """If False, then multi-step will stream outputs at the end of all steps""" - # Private API. If used, scheduler sends delta data to - # workers instead of an entire data. It should be enabled only - # when SPMD worker architecture is enabled. I.e., - # VLLM_USE_RAY_SPMD_WORKER=1 send_delta_data: bool = False - - # The scheduling policy to use. "fcfs" (default) or "priority". - policy: str = "fcfs" + """Private API. If used, scheduler sends delta data to + workers instead of an entire data. It should be enabled only + when SPMD worker architecture is enabled. I.e., + VLLM_USE_RAY_SPMD_WORKER=1""" + + policy: SchedulerPolicy = "fcfs" + """The scheduling policy to use:\n + - "fcfs" means first come first served, i.e. requests are handled in order + of arrival.\n + - "priority" means requests are handled based on given priority (lower + value means earlier handling) and time of arrival deciding any ties).""" chunked_prefill_enabled: bool = field(init=False) + """True if chunked prefill is enabled.""" - # If set to true and chunked prefill is enabled, we do not want to - # partially schedule a multimodal item. Only used in V1 - # This ensures that if a request has a mixed prompt - # (like text tokens TTTT followed by image tokens IIIIIIIIII) where only - # some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), - # it will be scheduled as TTTT in one step and IIIIIIIIII in the next. disable_chunked_mm_input: bool = False + """If set to true and chunked prefill is enabled, we do not want to + partially schedule a multimodal item. 
Only used in V1 + This ensures that if a request has a mixed prompt + (like text tokens TTTT followed by image tokens IIIIIIIIII) where only + some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), + it will be scheduled as TTTT in one step and IIIIIIIIII in the next.""" - # scheduler class or path. "vllm.core.scheduler.Scheduler" (default) - # or "mod.custom_class". scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler" + """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the + default scheduler. Can be a class directly or the path to a class of form + "mod.custom_class".""" def compute_hash(self) -> str: """ @@ -1861,6 +1951,18 @@ class SchedulerConfig: return hash_str def __post_init__(self) -> None: + if self.max_model_len is None: + self.max_model_len = 8192 + logger.warning( + "max_model_len was is not set. Defaulting to arbitrary value " + "of %d.", self.max_model_len) + + if self.max_num_seqs is None: + self.max_num_seqs = 128 + logger.warning( + "max_num_seqs was is not set. Defaulting to arbitrary value " + "of %d.", self.max_num_seqs) + if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: if self.num_scheduler_steps > 1: @@ -1973,9 +2075,19 @@ class SchedulerConfig: return self.num_scheduler_steps > 1 +Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"] + + +@config +@dataclass class DeviceConfig: - device: Optional[torch.device] - device_type: str + """Configuration for the device to use for vLLM execution.""" + + device: Union[Device, torch.device] = "auto" + """Device type for vLLM execution.""" + device_type: str = field(init=False) + """Device type from the current platform. 
This is set in + `__post_init__`.""" def compute_hash(self) -> str: """ @@ -1997,8 +2109,8 @@ class DeviceConfig: usedforsecurity=False).hexdigest() return hash_str - def __init__(self, device: str = "auto") -> None: - if device == "auto": + def __post_init__(self): + if self.device == "auto": # Automated device type detection from vllm.platforms import current_platform self.device_type = current_platform.device_type @@ -2009,7 +2121,7 @@ class DeviceConfig: "to turn on verbose logging to help debug the issue.") else: # Device type is assigned explicitly - self.device_type = device + self.device_type = self.device # Some device types require processing inputs on CPU if self.device_type in ["neuron"]: @@ -2021,139 +2133,113 @@ class DeviceConfig: self.device = torch.device(self.device_type) +SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator", + "draft_model"] +SpeculativeAcceptanceMethod = Literal["rejection_sampler", + "typical_acceptance_sampler"] + + +@config @dataclass class SpeculativeConfig: - """ - Configuration for speculative decoding. - Configurable parameters include: - - General Speculative Decoding Control: - - num_speculative_tokens (int): The number of speculative - tokens, if provided. It will default to the number in the draft - model config if present, otherwise, it is required. - - model (Optional[str]): The name of the draft model, eagle head, - or additional weights, if provided. - - method (Optional[str]): The name of the speculative method to use. - If users provide and set the `model` param, the speculative method - type will be detected automatically if possible, if `model` param - is not provided, the method name must be provided. - - Possible values: - - ngram - Related additional configuration: - - prompt_lookup_max (Optional[int]): - Maximum size of ngram token window when using Ngram - proposer, required when method is set to ngram. 
- - prompt_lookup_min (Optional[int]): - Minimum size of ngram token window when using Ngram - proposer, if provided. Defaults to 1. - - eagle - - medusa - - mlp_speculator - - draft_model - - acceptance_method (str): The method to use for accepting draft - tokens. This can take two possible values: 'rejection_sampler' and - 'typical_acceptance_sampler' for RejectionSampler and - TypicalAcceptanceSampler respectively. If not specified, it - defaults to 'rejection_sampler'. - - Possible values: - - rejection_sampler - - typical_acceptance_sampler - Related additional configuration: - - posterior_threshold (Optional[float]): - A threshold value that sets a lower bound on the - posterior probability of a token in the target model - for it to be accepted. This threshold is used only - when we use the TypicalAcceptanceSampler for token - acceptance. - - posterior_alpha (Optional[float]): - Scaling factor for entropy-based threshold, applied - when using TypicalAcceptanceSampler. - - draft_tensor_parallel_size (Optional[int]): The degree of the tensor - parallelism for the draft model. Can only be 1 or the same as the - target model's tensor parallel size. - - disable_logprobs (bool): If set to True, token log probabilities are - not returned during speculative decoding. If set to False, token - log probabilities are returned according to the log probability - settings in SamplingParams. If not specified, it defaults to True. - - - Draft Model Configuration: - - quantization (Optional[str]): Quantization method that was used to - quantize the draft model weights. If None, we assume the - model weights are not quantized. Note that it only takes effect - when using the draft model-based speculative method. - - max_model_len (Optional[int]): The maximum model length of the - draft model. Used when testing the ability to skip - speculation for some sequences. - - revision: The specific model version to use for the draft model. 
It - can be a branch name, a tag name, or a commit id. If unspecified, - will use the default version. - - code_revision: The specific revision to use for the draft model code - on Hugging Face Hub. It can be a branch name, a tag name, or a - commit id. If unspecified, will use the default version. + """Configuration for speculative decoding.""" - - Advanced Control: - - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to - batch expansion for scoring proposals. If not specified, it - defaults to False. - - disable_by_batch_size (Optional[int]): Disable speculative decoding - for new incoming requests when the number of enqueued requests is - larger than this value, if provided. - - Although the parameters above are structured hierarchically, there is no - need to nest them during configuration. - - Non-configurable internal parameters include: - - Model Configuration: - - target_model_config (ModelConfig): The configuration of the target - model. - - draft_model_config (ModelConfig): The configuration of the draft - model initialized internal. - - Parallelism Configuration: - - target_parallel_config (ParallelConfig): The parallel configuration - for the target model. - - draft_parallel_config (ParallelConfig): The parallel configuration - for the draft model initialized internal. - - Execution Control: - - enable_chunked_prefill (bool): Whether vLLM is configured to use - chunked prefill or not. Used for raising an error since it's not - yet compatible with speculative decode. - - disable_log_stats (bool): Whether to disable the periodic printing of - stage times in speculative decoding. - """ - # speculative configs from cli args + # General speculative decoding control num_speculative_tokens: int = field(default=None, init=True) # type: ignore - method: Optional[str] = None - acceptance_method: str = "rejection_sampler" + """The number of speculative tokens, if provided. 
It will default to the + number in the draft model config if present, otherwise, it is required.""" + model: Optional[str] = None + """The name of the draft model, eagle head, or additional weights, if + provided.""" + method: Optional[SpeculativeMethod] = None + """The name of the speculative method to use. If users provide and set the + `model` param, the speculative method type will be detected automatically + if possible, if `model` param is not provided, the method name must be + provided. + + If using `ngram` method, the related configuration `prompt_lookup_max` and + `prompt_lookup_min` should be considered.""" + acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler" + """The method to use for accepting draft tokens:\n + - "rejection_sampler" maps to `RejectionSampler`.\n + - "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`. + + If using `typical_acceptance_sampler`, the related configuration + `posterior_threshold` and `posterior_alpha` should be considered.""" draft_tensor_parallel_size: Optional[int] = None + """The degree of the tensor parallelism for the draft model. Can only be 1 + or the same as the target model's tensor parallel size.""" disable_logprobs: bool = True + """If set to True, token log probabilities are not returned during + speculative decoding. If set to False, token log probabilities are returned + according to the log probability settings in SamplingParams.""" - model: Optional[str] = None + # Draft model configuration quantization: Optional[str] = None + """Quantization method that was used to quantize the draft model weights. + If `None`, we assume the model weights are not quantized. Note that it only + takes effect when using the draft model-based speculative method.""" max_model_len: Optional[int] = None + """The maximum model length of the draft model. 
Used when testing the + ability to skip speculation for some sequences.""" revision: Optional[str] = None + """The specific model version to use for the draft model. It can be a + branch name, a tag name, or a commit id. If unspecified, will use the + default version.""" code_revision: Optional[str] = None + """The specific revision to use for the draft model code on Hugging Face + Hub. It can be a branch name, a tag name, or a commit id. If unspecified, + will use the default version.""" + # Advanced control disable_mqa_scorer: bool = False + """Disable the MQA scorer and fall back to batch expansion for scoring + proposals.""" disable_by_batch_size: Optional[int] = None + """Disable speculative decoding for new incoming requests when the number + of enqueued requests is larger than this value, if provided.""" + + # Ngram proposer configuration prompt_lookup_max: Optional[int] = None + """Maximum size of ngram token window when using Ngram proposer, required + when method is set to ngram.""" prompt_lookup_min: Optional[int] = None + """Minimum size of ngram token window when using Ngram proposer, if + provided. Defaults to 1.""" + + # Typical acceptance sampler configuration posterior_threshold: Optional[float] = None + """A threshold value that sets a lower bound on the posterior probability + of a token in the target model for it to be accepted. This threshold is + used only when we use the `TypicalAcceptanceSampler` for token acceptance. 
+ """ posterior_alpha: Optional[float] = None + """Scaling factor for entropy-based threshold, applied when using + `TypicalAcceptanceSampler`.""" # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, init=True) # type: ignore + """The configuration of the target model.""" target_parallel_config: ParallelConfig = field(default=None, init=True) # type: ignore + """The parallel configuration for the target model.""" enable_chunked_prefill: bool = field(default=None, init=True) # type: ignore + """Whether vLLM is configured to use chunked prefill or not. Used for + raising an error since it's not yet compatible with speculative decode.""" disable_log_stats: bool = field(default=None, init=True) # type: ignore + """Whether to disable the periodic printing of stage times in speculative + decoding.""" # params generated in the post-init stage draft_model_config: ModelConfig = field(default=None, init=True) # type: ignore + """The configuration of the draft model initialized internal.""" draft_parallel_config: ParallelConfig = field(default=None, init=True) # type: ignore + """The parallel configuration for the draft model initialized internal.""" def compute_hash(self) -> str: """ @@ -2167,9 +2253,10 @@ class SpeculativeConfig: excluding anything before input ids/embeddings and after the final hidden states. """ - # no factors to consider. - # spec decode does not use `torch.compile` yet. factors: list[Any] = [] + # Eagle3 affects the computation graph because it returns intermediate + # hidden states in addition to the final hidden state. 
+ factors.append(self.method == "eagle3") hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @@ -2204,7 +2291,8 @@ class SpeculativeConfig: if self.model is None and self.num_speculative_tokens is not None: # TODO(Shangming): Refactor mtp configuration logic when supporting # mtp acceleration for more models besides deepseek_v3 - if self.target_model_config.hf_text_config.model_type \ + if self.target_model_config and \ + self.target_model_config.hf_text_config.model_type \ == "deepseek_v3": # use the draft model from the same model: self.model = self.target_model_config.model @@ -2285,7 +2373,10 @@ class SpeculativeConfig: ) # Automatically detect the method - if "eagle-" in self.draft_model_config.model.lower(): + if self.method in ('eagle', 'eagle3'): + pass + elif "eagle-" in self.draft_model_config.model.lower() or \ + "eagle3-" in self.draft_model_config.model.lower(): self.method = "eagle" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" @@ -2296,7 +2387,7 @@ class SpeculativeConfig: self.method = "draft_model" # Replace hf_config for EAGLE draft_model - if self.method == "eagle": + if self.method in ("eagle", "eagle3"): if self.enable_chunked_prefill and not envs.VLLM_USE_V1: raise ValueError( "Chunked prefill and EAGLE are not compatible " @@ -2445,7 +2536,6 @@ class SpeculativeConfig: max_parallel_loading_workers, disable_custom_all_reduce=target_parallel_config. disable_custom_all_reduce, - tokenizer_pool_config=target_parallel_config.tokenizer_pool_config, ray_workers_use_nsight=target_parallel_config. 
ray_workers_use_nsight, placement_group=target_parallel_config.placement_group, @@ -2498,6 +2588,12 @@ class SpeculativeConfig: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") + if self.method == "eagle3" and self.target_model_config and \ + "llama" not in self.target_model_config.hf_text_config.model_type: + raise ValueError( + "Eagle3 is only supported for Llama models. " + f"Got {self.target_model_config.hf_text_config.model_type=}") + @property def num_lookahead_slots(self) -> int: """The number of additional slots the scheduler should allocate per @@ -2508,6 +2604,9 @@ class SpeculativeConfig: """ return self.num_speculative_tokens + def use_eagle(self) -> bool: + return self.method in ("eagle", "eagle3") + def __repr__(self) -> str: method = self.method model = None if method == "ngram" else self.draft_model_config.model @@ -2515,18 +2614,45 @@ class SpeculativeConfig: return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" +LoRADType = Literal["auto", "float16", "bfloat16"] + + +@config @dataclass class LoRAConfig: - max_lora_rank: int - max_loras: int + """Configuration for LoRA.""" + + max_lora_rank: int = 16 + """Max LoRA rank.""" + max_loras: int = 1 + """Max number of LoRAs in a single batch.""" fully_sharded_loras: bool = False + """By default, only half of the LoRA computation is sharded with tensor + parallelism. Enabling this will use the fully sharded layers. At high + sequence length, max rank or tensor parallel size, this is likely faster. + """ max_cpu_loras: Optional[int] = None - lora_dtype: Optional[Union[torch.dtype, str]] = None + """Maximum number of LoRAs to store in CPU memory. Must be >= than + `max_loras`.""" + lora_target_modules: Optional[List[str]] = None + """List of lora module name, If not specified, + modules will be chosen according to the model architecture. + """ + lora_dtype: Union[torch.dtype, LoRADType] = "auto" + """Data type for LoRA. 
If auto, will default to base model dtype.""" lora_extra_vocab_size: int = 256 + """Maximum size of extra vocabulary that can be present in a LoRA adapter + (added to the base model vocabulary).""" # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 - long_lora_scaling_factors: Optional[tuple[float]] = None + long_lora_scaling_factors: Optional[tuple[float, ...]] = None + """Specify multiple scaling factors (which can be different from base model + scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters + trained with those scaling factors to be used at the same time. If not + specified, only adapters trained with the base model scaling factor are + allowed.""" bias_enabled: bool = False + """Enable bias for LoRA adapters.""" def compute_hash(self) -> str: """ @@ -2589,25 +2715,27 @@ class LoRAConfig: elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) - def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): - # Reminder: Please update docs/source/features/compatibility_matrix.md - # If the feature combo become valid - if scheduler_config.chunked_prefill_enabled: - logger.warning("LoRA with chunked prefill is still experimental " - "and may be unstable.") - def verify_lora_support(self): if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1: raise ValueError( "V1 LoRA does not support long LoRA, please use V0.") +@config @dataclass class PromptAdapterConfig: - max_prompt_adapters: int - max_prompt_adapter_token: int + """Configuration for PromptAdapters.""" + + max_prompt_adapters: int = 1 + """Max number of PromptAdapters in a batch.""" + max_prompt_adapter_token: int = 0 + """Max number of PromptAdapters tokens.""" max_cpu_prompt_adapters: Optional[int] = None - prompt_adapter_dtype: Optional[torch.dtype] = None + """Maximum number of PromptAdapters to store in CPU memory. 
Must be >= than + `max_prompt_adapters`.""" + prompt_adapter_dtype: Union[torch.dtype, str] = "auto" + """Data type for PromptAdapter. If auto, will default to base model dtype. + """ def compute_hash(self) -> str: """ @@ -2639,20 +2767,26 @@ class PromptAdapterConfig: self.max_cpu_prompt_adapters = self.max_prompt_adapters def verify_with_model_config(self, model_config: ModelConfig): - if self.prompt_adapter_dtype in (None, "auto"): + if self.prompt_adapter_dtype == "auto": self.prompt_adapter_dtype = model_config.dtype elif isinstance(self.prompt_adapter_dtype, str): self.prompt_adapter_dtype = getattr(torch, self.prompt_adapter_dtype) +@config @dataclass class MultiModalConfig: """Controls the behavior of multimodal models.""" - limit_per_prompt: Mapping[str, int] = field(default_factory=dict) + limit_per_prompt: dict[str, int] = field(default_factory=dict) """ The maximum number of input items allowed per prompt for each modality. + This should be a JSON string that will be parsed into a dictionary. + Defaults to 1 (V0) or 999 (V1) for each modality. + + For example, to allow up to 16 images and 2 videos per prompt: + ``{"images": 16, "videos": 2}`` """ def compute_hash(self) -> str: @@ -2674,24 +2808,20 @@ class MultiModalConfig: usedforsecurity=False).hexdigest() return hash_str - def get_default_limit_per_prompt(self) -> int: - """ - Return the default number of input items allowed per prompt - for any modality if not specified by the user. - """ - return 999 if envs.VLLM_USE_V1 else 1 - def get_limit_per_prompt(self, modality: str) -> int: """ Get the maximum number of input items allowed per prompt for the given modality. """ - default = self.get_default_limit_per_prompt() - return self.limit_per_prompt.get(modality, default) + return self.limit_per_prompt.get( + modality, + 999 if envs.VLLM_USE_V1 else 1, + ) # TODO: Add configs to init vision tower or not. 
+@config @dataclass class PoolerConfig: """Controls the behavior of output pooling in pooling models.""" @@ -2769,12 +2899,10 @@ def _get_and_verify_dtype( ) -> torch.dtype: # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. - config_dtype = getattr(config, "torch_dtype", None) + config_dtype = getattr(config.get_text_config(), "torch_dtype", None) - # Fallbacks for multi-modal models if the root config + # Fallback for multi-modal models if the root config # does not define torch_dtype - if config_dtype is None and hasattr(config, "text_config"): - config_dtype = getattr(config.text_config, "torch_dtype", None) if config_dtype is None and hasattr(config, "vision_config"): config_dtype = getattr(config.vision_config, "torch_dtype", None) @@ -2790,6 +2918,13 @@ def _get_and_verify_dtype( else: torch_dtype = config_dtype + if config.model_type == "plamo2": + logger.info( + "For PLaMo2, we cast models to bfloat16 instead of using " + "float16 by default. This is because float16 does not work." + ) + torch_dtype = torch.bfloat16 + from vllm.platforms import current_platform if (current_platform.is_cpu() and current_platform.get_cpu_architecture() @@ -2819,6 +2954,11 @@ def _get_and_verify_dtype( "using float16 by default. Please specify `dtype` if you " "want to use float16.") torch_dtype = torch.bfloat16 + elif dtype == "float16" and config.model_type == "plamo2": + logger.warning( + "For PLaMo2, using float16 is unstable and might cause " + "unexpected behavior. 
Please use bfloat16 or float32 instead.") + torch_dtype = torch.float16 else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: raise ValueError(f"Unknown dtype: {dtype}") @@ -3004,15 +3144,28 @@ def get_served_model_name(model: str, return served_model_name +GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", + "xgrammar", "guidance"] +GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] + + +@config @dataclass class DecodingConfig: - """Dataclass which contains the decoding strategy of the engine""" + """Dataclass which contains the decoding strategy of the engine.""" - # Which guided decoding algo to use. - # 'outlines' / 'lm-format-enforcer' / 'xgrammar' - guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar" + guided_decoding_backend: Union[ + GuidedDecodingBackendV0, + GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar" + """Which engine will be used for guided decoding (JSON schema / regex etc) + by default. With "auto", we will make opinionated choices based on request + contents and what the backend libraries currently support, so the behavior + is subject to change in each release.""" reasoning_backend: Optional[str] = None + """Select the reasoning parser depending on the model that you're using. + This is used to parse the reasoning content into OpenAI API format. 
+ Required for `--enable-reasoning`.""" def compute_hash(self) -> str: """ @@ -3034,17 +3187,12 @@ class DecodingConfig: return hash_str def __post_init__(self): - v0_valid_guided_backends = [ - 'outlines', 'lm-format-enforcer', 'xgrammar', 'auto' - ] - v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto'] - backend = GuidedDecodingParams( backend=self.guided_decoding_backend).backend_name if envs.VLLM_USE_V1: - valid_guided_backends = v1_valid_guided_backends + valid_guided_backends = get_args(GuidedDecodingBackendV1) else: - valid_guided_backends = v0_valid_guided_backends + valid_guided_backends = get_args(GuidedDecodingBackendV0) if backend not in valid_guided_backends: raise ValueError(f"Invalid guided_decoding_backend '{backend}'," f" must be one of {valid_guided_backends}") @@ -3304,11 +3452,13 @@ class CompilationConfig(BaseModel): - enable_fusion: whether to enable the custom fusion pass. - enable_noop: whether to enable the custom no-op elimination pass. TODO(luka) better pass enabling system. + - enable_sequence_parallelism: whether to enable sequence parallelism. """ dump_graph_stages: list[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) enable_fusion: bool = True enable_noop: bool = True + enable_sequence_parallelism: bool = False def uuid(self): """ @@ -3317,7 +3467,8 @@ class CompilationConfig(BaseModel): Do not include dump_graph_* in the hash - they don't affect compilation. 
""" - dict_ = self.model_dump(include={"enable_fusion", "enable_noop"}) + dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \ + "enable_sequence_parallelism"}) return InductorPass.hash_dict(dict_) def model_post_init(self, __context: Any) -> None: @@ -3344,7 +3495,8 @@ class CompilationConfig(BaseModel): compilation_time: float = PrivateAttr # Per-model forward context - # Map from layer name to the attention cls + # Map from layer name to layer objects that need to be accessed outside + # model code, e.g., Attention, FusedMOE when dp_size>1. static_forward_context: dict[str, Any] = PrivateAttr def compute_hash(self) -> str: @@ -3675,6 +3827,17 @@ class VllmConfig: return quant_config return None + @staticmethod + def get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + import copy + + # For some reason, the _ version of this modifies the model_config + # object, so using deepcopy to avoid this problem. + return VllmConfig._get_quantization_config(copy.deepcopy(model_config), + load_config) + def with_hf_config( self, hf_config: PretrainedConfig, @@ -3704,8 +3867,6 @@ class VllmConfig: if self.lora_config: self.lora_config.verify_with_cache_config(self.cache_config) self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_with_scheduler_config( - self.scheduler_config) self.lora_config.verify_lora_support() if self.prompt_adapter_config: self.prompt_adapter_config.verify_with_model_config( @@ -3729,6 +3890,8 @@ class VllmConfig: if self.compilation_config is None: self.compilation_config = CompilationConfig() + if self.compilation_config.pass_config.enable_sequence_parallelism: + self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ not self.model_config.enforce_eager: # NOTE(woosuk): Currently, we use inductor because the piecewise @@ -3736,7 +3899,8 @@ class VllmConfig: # FIXME(woosuk): Disable 
inductor to reduce the compilation time # and avoid any potential issues with the inductor. # FIXME(rob): Add function to set all of these. - self.compilation_config.custom_ops = ["none"] + if not self.compilation_config.custom_ops: + self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.use_inductor = True self.compilation_config.cudagraph_num_of_warmups = 1 @@ -3747,6 +3911,18 @@ class VllmConfig: self.compilation_config.level = CompilationLevel.NO_COMPILATION self.compilation_config.set_splitting_ops_for_v1() + if self.parallel_config is not None and \ + self.parallel_config.tensor_parallel_size > 1 and \ + self.parallel_config.pipeline_parallel_size > 1 and \ + self.compilation_config is not None and \ + self.compilation_config.pass_config is not None and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + logger.warning_once( + "Sequence parallelism is not supported with pipeline " + "parallelism. Disabling sequence parallelism.") + self.compilation_config.pass_config.\ + enable_sequence_parallelism = False + self._set_cudagraph_sizes() if self.cache_config is not None and \ @@ -3786,6 +3962,26 @@ class VllmConfig: if not self.instance_id: self.instance_id = random_uuid()[:5] + def update_sizes_for_sequence_parallelism(self, + possible_sizes: list) -> list: + # remove the sizes that not multiple of tp_size when + # enable sequence parallelism + removed_sizes = [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size != 0 + ] + if removed_sizes: + logger.warning( + "Batch sizes %s are removed because they are not " + "multiple of tp_size %d when " + "sequence parallelism is enabled", removed_sizes, + self.parallel_config.tensor_parallel_size) + + return [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size == 0 + ] + def _set_cudagraph_sizes(self): """ cudagraph batchsize padding logic: @@ -3823,6 +4019,11 @@ class 
VllmConfig: not self.model_config.enforce_eager: possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + possible_sizes = self.update_sizes_for_sequence_parallelism( + possible_sizes) + # find the minimum size that is larger than max_num_seqs, # which then becomes the max_batchsize_to_capture larger_sizes = [ @@ -3846,6 +4047,11 @@ class VllmConfig: not self.model_config.enforce_eager: batch_size_capture_list = [1, 2, 4 ] + [i for i in range(8, 513, 8)] + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + batch_size_capture_list = \ + self.update_sizes_for_sequence_parallelism(batch_size_capture_list) + max_num_tokens = self.scheduler_config.max_num_batched_tokens batch_size_capture_list = [ size for size in batch_size_capture_list @@ -3944,3 +4150,43 @@ def get_current_vllm_config() -> VllmConfig: from vllm.config import VllmConfig return VllmConfig() return _current_vllm_config + + +def contains_object_print(text): + """ + Check if the text looks like a printed Python object, e.g. + contains any substring matching the pattern: "at 0xFFFFFFF>" + We match against 0x followed by 2-16 hex chars (there's + a max of 16 on a 64 bit system). + + Args: + text (str): The text to check + + Returns: + bool: True if a match is found, False otherwise + """ + pattern = r'at 0x[a-fA-F0-9]{2,16}>' + match = re.search(pattern, text) + return match is not None + + +def assert_hashable(text): + if not contains_object_print(text): + return True + raise AssertionError( + f"vLLM tried to hash some configs that may have Python objects ids " + f"in them. This is a bug, please file an issue. 
" + f"Text being hashed: {text}") + + +T = TypeVar("T") + + +def get_layers_from_vllm_config(vllm_config: VllmConfig, + layer_type: type[T]) -> dict[str, T]: + return { + layer_name: layer + for layer_name, layer in + vllm_config.compilation_config.static_forward_context.items() + if isinstance(layer, layer_type) + } diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index cf85a2135c817d384787643242f70e5de4827f8d..97d03d5e3b40a636ef3f03191d431478a2e10334 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1596,7 +1596,6 @@ class Scheduler: multi_modal_placeholders=( seq_group.multi_modal_placeholders if scheduler_outputs.num_prefill_groups > 0 else None), - mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) else: diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 0228264f91f9a8688cc885e9c8d9c090eb954321..894a0fafb64034fc73e47e60c340eff04ed92ee4 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -19,6 +19,12 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor, return get_tp_group().all_gather(input_, dim) +def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """Reduce-Scatter the input tensor across model parallel group.""" + return get_tp_group().reduce_scatter(input_, dim) + + def tensor_model_parallel_gather(input_: torch.Tensor, dst: int = 0, dim: int = -1) -> Optional[torch.Tensor]: diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index eb12f8834b4191c537e268dbf1bfc46080b59cb1..240313b98c88b6ac7aee1ec4bb59089624fa0e5f 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -61,6 +61,40 @@ class DeviceCommunicatorBase: input_size[dim + 1:]) 
return output_tensor + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output_tensor = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + # Perform reduce-scatter operation + torch.distributed.reduce_scatter_tensor(output_tensor, + input_tensor, + group=self.device_group) + + # Reshape before returning + return output_tensor.movedim(0, dim).contiguous() + def gather(self, input_: torch.Tensor, dst: int = 0, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 07c9ff5060924da7821ac71b0fc29e6b06b35089..8bca278f3888b03c944dde5ffe83b487585eb788 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase): torch.distributed.all_reduce(out, group=self.device_group) return out + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. 
Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + pynccl_comm.reduce_scatter(output, input_) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 11ed7c0843779f4b9258b6eb4792696bc56bfd67..723719c79e9c2c3533cb79277d74f8b6c277a12e 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -7,11 +7,13 @@ import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory -from typing import List, Optional, Tuple, Union +from threading import Event +from typing import Any, List, Optional, Tuple, Union from unittest.mock import patch import torch import torch.distributed as dist +import zmq from torch.distributed import ProcessGroup from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore @@ -239,7 +241,7 @@ class MessageQueue: self.remote_socket.setsockopt(IPV6, 1) remote_addr_ipv6 = True connect_ip = f"[{connect_ip}]" - socket_addr = f"tcp://*:{remote_subscribe_port}" + socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" self.remote_socket.bind(socket_addr) remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" else: @@ -400,7 +402,9 @@ class MessageQueue: break @contextmanager - def acquire_read(self, timeout: 
Optional[float] = None): + def acquire_read(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() n_warning = 1 @@ -430,6 +434,9 @@ class MessageQueue: ) n_warning += 1 + if cancel is not None and cancel.is_set(): + raise RuntimeError("cancelled") + # if we time out, raise an exception if (timeout is not None and time.monotonic() - start_time > timeout): @@ -464,10 +471,12 @@ class MessageQueue: if self.n_remote_reader > 0: self.remote_socket.send(serialized_obj) - def dequeue(self, timeout: Optional[float] = None): + def dequeue(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): """ Read from message queue with optional timeout (in seconds) """ if self._is_local_reader: - with self.acquire_read(timeout) as buf: + with self.acquire_read(timeout, cancel) as buf: overflow = buf[0] == 1 if not overflow: # no need to know the size of serialized object @@ -475,15 +484,21 @@ class MessageQueue: # see https://docs.python.org/3/library/pickle.html obj = pickle.loads(buf[1:]) if overflow: - recv = self.local_socket.recv() - obj = pickle.loads(recv) + obj = MessageQueue.recv(self.local_socket, timeout) elif self._is_remote_reader: - recv = self.remote_socket.recv() - obj = pickle.loads(recv) + obj = MessageQueue.recv(self.remote_socket, timeout) else: raise RuntimeError("Only readers can dequeue") return obj + @staticmethod + def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any: + timeout_ms = None if timeout is None else int(timeout * 1000) + if not socket.poll(timeout=timeout_ms): + raise TimeoutError + recv = socket.recv(copy=False) + return pickle.loads(recv.buffer) + def broadcast_object(self, obj=None): if self._is_writer: self.enqueue(obj) diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..ec07c6fe0d12d5fb9f3eb0429ad7ca1922fe61e1 
100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.distributed.kv_transfer.kv_transfer_state import ( + ensure_kv_transfer_initialized, get_kv_transfer_group, + has_kv_transfer_group, is_v1_kv_transfer_group) + +__all__ = [ + "get_kv_transfer_group", "has_kv_transfer_group", + "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", + "KVConnectorBaseType" +] diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index 57c764b481c29f30ed0daf18db5968ade5b8c3c3..0d1a3d40af413911e2b4e66199664b0737bd431e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Tuple, Union import torch +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.sequence import IntermediateTensors if TYPE_CHECKING: @@ -121,3 +122,6 @@ class KVConnectorBase(ABC): """ raise NotImplementedError + + +KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index e37ce6dc75b031fea0a147b9fbe4ac89815d8d88..6532c101a4f6a4008f5c343eb7e78bb536f83307 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -3,14 +3,22 @@ import importlib from typing import TYPE_CHECKING, Callable, Dict, Type +import vllm.envs as envs +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.logger import init_logger + from .base import KVConnectorBase if TYPE_CHECKING: from vllm.config import VllmConfig +logger = init_logger(__name__) + class KVConnectorFactory: - 
_registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {} + _registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {} @classmethod def register_connector(cls, name: str, module_path: str, @@ -19,22 +27,51 @@ class KVConnectorFactory: if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") - def loader() -> Type[KVConnectorBase]: + def loader() -> Type[KVConnectorBaseType]: module = importlib.import_module(module_path) return getattr(module, class_name) cls._registry[name] = loader @classmethod - def create_connector(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: + def create_connector_v0(cls, rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + if envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V0 Connector, " + f"but found {envs.VLLM_USE_V1=}") + connector_name = config.kv_transfer_config.kv_connector if connector_name not in cls._registry: raise ValueError(f"Unsupported connector type: {connector_name}") connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, KVConnectorBase) return connector_cls(rank, local_rank, config) + @classmethod + def create_connector_v1( + cls, + config: "VllmConfig", + role: KVConnectorRole, + ) -> KVConnectorBase_V1: + if not envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V1 Connector, " + f"but found {envs.VLLM_USE_V1=}") + + connector_name = config.kv_transfer_config.kv_connector + connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, KVConnectorBase_V1) + logger.info("Creating v1 connector with name: %s", connector_name) + # NOTE(Kuntai): v1 connector is explicitly separated into two roles. 
+ # Scheduler connector: + # - Co-locate with scheduler process + # - Should only be used inside the Scheduler class + # Worker connector: + # - Co-locate with worker process + # - Should only be used inside the forward context & attention layer + # We build separately to enforce strict separation + return connector_cls(config, role) + # Register various connectors here. # The registration should not be done in each individual file, as we want to @@ -57,4 +94,14 @@ KVConnectorFactory.register_connector( KVConnectorFactory.register_connector( "MooncakeStoreConnector", "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", - "MooncakeStoreConnector") \ No newline at end of file + "MooncakeStoreConnector") + +KVConnectorFactory.register_connector( + "SharedStorageConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", + "SharedStorageConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnectorV1", + "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", + "LMCacheConnectorV1") diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py index c5135dab23ebab99295e1e08fae4461a743e5f58..7b26aec23239cb91db3a24400d2aab1ca3d7d8e2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ MooncakeStore Connector for Distributed Machine Learning Inference - The MooncakeStoreConnector transfers KV caches between prefill vLLM workers (KV cache producer) and decode vLLM workers (KV cache consumer) using a database-style KVStore. 
@@ -11,9 +10,10 @@ from typing import TYPE_CHECKING, List, Tuple, Union import torch -from vllm import _custom_ops as ops from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.utils import ( + model_aware_kv_ops_helper as kv_helper) from vllm.logger import init_logger from vllm.sequence import IntermediateTensors @@ -32,8 +32,7 @@ class MooncakeStoreConnector(KVConnectorBase): config: VllmConfig, ): self.config = config.kv_transfer_config - self.tp_size = config.parallel_config.tensor_parallel_size - + self.kv_helper = kv_helper(config) self.local_tp_rank = local_rank # Init kv_store @@ -80,12 +79,7 @@ class MooncakeStoreConnector(KVConnectorBase): slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer - - model_config = model_executable.model.config - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - head_size = int(hidden_size / num_attention_heads) + num_heads, head_size = self.kv_helper.get_model_args(model_executable) for idx, slen in enumerate(seq_lens): start_pos = sum(seq_lens[:idx]) @@ -97,10 +91,8 @@ class MooncakeStoreConnector(KVConnectorBase): for layer_id in range(start_layer, end_layer): kv_cache = kv_caches[layer_id - start_layer] - - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - + key_cache, value_cache = self.kv_helper.get_kv_from_cache( + kv_cache, num_heads, head_size) current_slot_mapping = slot_mapping_flat[start_pos:end_pos] keys.append(key_cache[current_slot_mapping].unsqueeze(0)) @@ -173,22 +165,15 @@ class MooncakeStoreConnector(KVConnectorBase): layer = model_executable.model.layers[layer_id] # get kvcache object kv_cache = kv_caches[layer_id 
- start_layer] - key_cache, value_cache = kv_cache[0], kv_cache[1] - # get remote kvcache + # get remote kvcache remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][ layer_id] - # use ops.reshape_and_cache_flash to put kv into kvcache - ops.reshape_and_cache_flash( - remote_k.to(key_cache.device), - remote_v.to(value_cache.device), - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) + + self.kv_helper.put_kv_to_cache(model_executable, remote_k, + remote_v, layer, kv_cache, + slot_mapping, start_pos, + end_pos) hidden_or_intermediate_states_for_one_req.append(hidden) diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 49b97d7b588978928d16a815d9b5f182fb0d59ce..0464a7585138f9ebb69ac25501850f553761a0b3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch -import vllm.envs as envs -from vllm import _custom_ops as ops from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.utils import ( + model_aware_kv_ops_helper as kv_helper) from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( SimpleBuffer) from vllm.logger import init_logger @@ -37,9 +37,7 @@ class SimpleConnector(KVConnectorBase): ): self.config = config.kv_transfer_config - self.tp_size = config.parallel_config.tensor_parallel_size - self.is_deepseek_mla = config.model_config.is_deepseek_mla - self.use_mla_opt = not envs.VLLM_MLA_DISABLE + self.kv_helper = kv_helper(config) if self.config.kv_connector == "PyNcclConnector": from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( @@ -165,31 
+163,7 @@ class SimpleConnector(KVConnectorBase): num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer - - model_config = model_executable.model.config - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - - # Deepseek's MLA (Multi-head Latent Attention) uses two different - # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. - # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, - # resulting in a kv_cache shape of [num_blks, blk_size, 1, - # kv_lora_rank + qk_rope_head_dim]. - # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading - # to a kv_cache shape of [2, num_blks, blk_size, - # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. - # For more details, see vllm/attention/backends/mla/common.py. - if self.is_deepseek_mla and self.use_mla_opt: - head_size = model_config.kv_lora_rank + \ - model_config.qk_rope_head_dim - num_heads = 1 - elif self.is_deepseek_mla and not self.use_mla_opt: - head_size = model_config.qk_nope_head_dim + \ - model_config.qk_rope_head_dim - else: - head_size = getattr(model_config, "head_dim", - int(hidden_size // num_attention_heads)) + num_heads, head_size = self.kv_helper.get_model_args(model_executable) # query_lens contains new KV caches that are added to vLLM. 
# so we will send them to decode instance @@ -212,13 +186,8 @@ class SimpleConnector(KVConnectorBase): for layer_id in range(start_layer, end_layer): kv_cache = kv_caches[layer_id - start_layer] - - if self.is_deepseek_mla and self.use_mla_opt: - key_cache = kv_cache.reshape(-1, num_heads, head_size) - value_cache = kv_cache.reshape(-1, num_heads, head_size) - else: - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + key_cache, value_cache = self.kv_helper.get_kv_from_cache( + kv_cache, num_heads, head_size) current_slot_mapping = slot_mapping_flat[start_pos:end_pos] @@ -248,12 +217,12 @@ class SimpleConnector(KVConnectorBase): # and hidden states. bypass_model_exec = True - model_config = model_executable.model.config - input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer hidden_or_intermediate_states_for_one_req = [] @@ -312,41 +281,19 @@ class SimpleConnector(KVConnectorBase): end_pos = start_pos + num_computed_tokens # put received KV caches into paged memory - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - - kv_cache = kv_caches[i - model_executable.model.start_layer] - layer = model_executable.model.layers[i] - - if self.is_deepseek_mla and self.use_mla_opt: - layer.self_attn.attn = layer.self_attn.mla_attn - k_c_normed_k_pe = keys[ - i - model_executable.model.start_layer].to( - kv_cache.device).squeeze(1) - k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] - k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] - ops.concat_and_cache_mla( - k_c_normed, - k_pe, - kv_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - 
layer.self_attn.attn._k_scale, - ) - else: - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - keys[i - model_executable.model.start_layer].to( - key_cache.device), - values[i - model_executable.model.start_layer].to( - value_cache.device), - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) + for cur_layer in range(start_layer, end_layer): + + layer_id = cur_layer - start_layer + kv_cache = kv_caches[layer_id] + layer = model_executable.model.layers[cur_layer] + + # get remote kvcache + remote_k, remote_v = keys[layer_id], values[layer_id] + + self.kv_helper.put_kv_to_cache(model_executable, remote_k, + remote_v, layer, kv_cache, + slot_mapping, start_pos, + end_pos) hidden_or_intermediate_states_for_one_req.append(hidden) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b0ce9828a74d594136596ca82c8a053da7a21f5 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +KV cache helper for store. 
+""" +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class model_aware_kv_ops_helper: + + def __init__(self, config: VllmConfig): + self.is_deepseek_mla = config.model_config.is_deepseek_mla + self.use_mla_opt = not envs.VLLM_MLA_DISABLE + self.tp_size = config.parallel_config.tensor_parallel_size + + def get_model_args(self, model_executable: torch.nn.Module): + + model_config = model_executable.model.config + self.model_executable = model_executable + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + + # Deepseek's MLA (Multi-head Latent Attention) uses two different + # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. + # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, + # resulting in a kv_cache shape of [num_blks, blk_size, 1, + # kv_lora_rank + qk_rope_head_dim]. + # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading + # to a kv_cache shape of [2, num_blks, blk_size, + # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. + # For more details, see vllm/attention/backends/mla/common.py. 
+ if self.is_deepseek_mla and self.use_mla_opt: + head_size = model_config.kv_lora_rank + \ + model_config.qk_rope_head_dim + num_heads = 1 + elif self.is_deepseek_mla and not self.use_mla_opt: + head_size = model_config.qk_nope_head_dim + \ + model_config.qk_rope_head_dim + else: + head_size = getattr(model_config, "head_dim", + int(hidden_size // num_attention_heads)) + + return num_heads, head_size + + def get_kv_from_cache(self, kv_cache, num_heads, head_size): + if self.is_deepseek_mla and self.use_mla_opt: + key_cache = kv_cache.reshape(-1, num_heads, head_size) + value_cache = kv_cache.reshape(-1, num_heads, head_size) + else: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + return key_cache, value_cache + + def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values, + layer, kv_cache, slot_mapping, start_pos, end_pos): + + model_config = model_executable.model.config + + if self.is_deepseek_mla and self.use_mla_opt: + layer.self_attn.attn = layer.self_attn.mla_attn + k_c_normed_k_pe = keys.squeeze(1) + k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + ops.concat_and_cache_mla( + k_c_normed.to(kv_cache.device), + k_pe.to(kv_cache.device), + kv_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + ) + else: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys.to(key_cache.device), + values.to(value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..a017b140e090284d18471f81503df40fbf986f08 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorRole) + +__all__ = [ + "KVConnectorRole", + "KVConnectorBase_V1", +] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py new file mode 100644 index 0000000000000000000000000000000000000000..95967d2ca91933afa7702723be7d0612de146e31 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State +communication in vLLM v1 + +The class provides the following primitives: + Scheduler-side: runs in the scheduler, binds metadata, which + is used by the worker-side to load/save KV cache. + get_num_new_matched_tokens() - get number of new tokens + that exist in the remote KV cache + update_state_after_alloc() - update KVConnector state after + temporary buffer alloc by the CacheManager. + + Worker-side: runs in each worker, loads/saves KV cache to/from + the Connector based on the metadata. 
+ start_load_kv() - starts loading all KVs (maybe async) + wait_for_layer_load() - blocks until layer i load is done + + save_kv_layer() - starts saving KV for layer i (maybe async) + wait_for_save() - blocks until all saves are done +""" + +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class KVConnectorRole(enum.Enum): + # Connector running in the scheduler process + SCHEDULER = 0 + + # Connector running in the worker process + WORKER = 1 + + +@dataclass +class KVConnectorMetadata: + pass + + +class KVConnectorBase_V1(ABC): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + logger.warning( + "Initializing KVConnectorBase_V1. This API is experimental and " + "subject to change in the future as we iterate the design.") + self._connector_metadata = KVConnectorMetadata() + self._vllm_config = vllm_config + self._role = role + + @property + def role(self) -> KVConnectorRole: + return self._role + + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + KV cache loading and saving. + + Args: + connector_metadata (dict): the connector metadata. + """ + self._connector_metadata = connector_metadata + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. 
+ """ + self._connector_metadata = KVConnectorMetadata() + + def _get_connector_metadata(self) -> KVConnectorMetadata: + """Get the connector metadata. + + This function should only be called inside the connector. + + Returns: + ConnectorMetadata: the connector metadata. + """ + return self._connector_metadata + + # ============================== + # Worker-side methods + # ============================== + + @abstractmethod + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + pass + + @abstractmethod + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + pass + + @abstractmethod + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. 
+ """ + pass + + @abstractmethod + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + pass + + # ============================== + # Scheduler-side methods + # ============================== + @abstractmethod + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + pass + + @abstractmethod + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + pass + + @abstractmethod + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..e07f185f0dd8129ad4b7ef19ec2e72f1a242956c --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING + +import torch +from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class LMCacheConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. 
+ + """ + self._lmcache_engine.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + self._lmcache_engine.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, + **kwargs) + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + self._lmcache_engine.wait_for_save() + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. 
+ num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self._lmcache_engine.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + self._lmcache_engine.update_state_after_alloc(request, + num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + return self._lmcache_engine.build_connector_meta(scheduler_output) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..f91ffbc720e753d0bfb8a9e76a1282153f99192a --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -0,0 +1,383 @@ +# SPDX-License-Identifier: Apache-2.0 +import hashlib +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import safetensors +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import MLACommonMetadata +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from 
vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class ReqMeta: + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + # Is store or load + is_store: bool + + @staticmethod + def make_meta(token_ids: list[int], block_ids: list[int], block_size: int, + is_store: bool) -> "ReqMeta": + valid_num_tokens = align_to_block_size(len(token_ids), block_size) + token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens] + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = block_offsets.reshape((1, block_size)) + \ + block_ids_tensor.reshape((num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten()[:valid_num_tokens] + return ReqMeta( + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + is_store=is_store, + ) + + +@dataclass +class SharedStorageConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request( + self, + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: bool, + ) -> None: + self.requests.append( + ReqMeta.make_meta(token_ids, block_ids, block_size, is_store)) + + +class SharedStorageConnector(KVConnectorBase_V1): + # NOTE: This is Simple debug implementation of the KV connector. + # It save / load the KV cache to / from the disk. 
+ # It does extra work which will overwrite the existing prefix-cache in GPU + # - to remove the overhead, need to add some "mask" in the ReqMeta class + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._block_size = vllm_config.cache_config.block_size + self._requests_need_load: dict[str, Request] = {} + transfer_config = vllm_config.kv_transfer_config + self._storage_path = transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp") + logger.info(vllm_config.kv_transfer_config) + logger.info("Shared storage path is %s", self._storage_path) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + """ + attn_metadata = forward_context.attn_metadata + + def inject_kv_into_layer( + dst_kv_cache_layer: torch.Tensor, + src_kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Inject the KV cache into the layer. + + Args: + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not + using MLA, [num_pages, page_size, xxx] otherwise. + src_kv_cache (torch.Tensor): the source KV cache. In shape + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + otherwise. + slot_mapping (torch.Tensor): the slot mapping. In shape + [num_tokens]. + """ + dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + if isinstance(attn_metadata, MLACommonMetadata): + num_pages = dst_kv_cache_layer_shape[0] + page_size = dst_kv_cache_layer_shape[1] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + num_pages * page_size, -1) + dst_kv_cache_layer[slot_mapping, ...] 
= src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + else: + num_pages = dst_kv_cache_layer_shape[1] + page_size = dst_kv_cache_layer_shape[2] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + 2, num_pages * page_size, -1) + dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + + # Get the metadata + metadata: KVConnectorMetadata = \ + self._get_connector_metadata() + assert isinstance(metadata, SharedStorageConnectorMetadata) + + if metadata is None: + logger.warning( + "In connector.start_load_kv, but the connector metadata is None" + ) + return + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.warning( + "In connector.start_load_kv, but the attn_metadata is None") + return + + # Load the KV for each request each layer + for request in metadata.requests: + if request.is_store: + continue + logger.info("Inject KV cache of %d tokens to the paged memory", + len(request.slot_mapping)) + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache_layer = attn_layer.kv_cache[\ + forward_context.virtual_engine] + + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = safetensors.torch.load_file( + filename)["kv_cache"].cuda() + inject_kv_into_layer(kv_cache_layer, kv_cache, + request.slot_mapping) + + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. 
+ kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + + def extract_kv_from_layer( + layer: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + + Assume the shape of the layer is (2, num_pages, page_size, xxx) + if MLA is not used, and (num_pages, page_size, xxx) otherwise. + """ + if isinstance(attn_metadata, MLACommonMetadata): + num_pages, page_size = layer.shape[0], layer.shape[1] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, + ...] + num_pages, page_size = layer.shape[1], layer.shape[2] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, + ...] + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, SharedStorageConnectorMetadata) + for request in connector_metadata.requests: + if request.is_store: + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = extract_kv_from_layer(kv_layer, + request.slot_mapping) + tensors = {"kv_cache": kv_cache.detach().cpu()} + safetensors.torch.save_file(tensors, filename) + + def wait_for_save(self): + return + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. 
+ """ + + # NOTE: in this debug implementation, we assume that the prompt is + # cached_prompt + newly_generated_single_token + # Therefore, we use prompt_token_ids[:-1] to determine the folder name + + # NOTE: in current v1 scheduler, the num_computed_tokens is aligned + # with the block granularity. And it expects the returned blocks and + # num_computed_tokens to also be aligned with the block granularity. + if not self._found_match_for_request(request): + return 0 + + logger.info("External Cache Hit!") + + # Now, first num_tokens_to_check tokens are hit, we need to prepare + # the metadata for the worker connector to correctly load the KV + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + + return num_tokens_to_check - num_computed_tokens + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + + If blocks were allocated, add to _requests_need_load, + such that we load the KVs in the next forward pass. + """ + if num_external_tokens > 0: + self._requests_need_load[request.request_id] = request + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + meta = SharedStorageConnectorMetadata() + + total_need_load = 0 + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id in self._requests_need_load: + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size, + is_store=False) + total_need_load += 1 + else: + # NOTE: here, we set the store and load being exclusive, + # but a single request can have both store and load. 
+ # NOTE(rob): for this debug implementation, we only cache + # the original prompt tokens. + if not self._found_match_for_request(new_req): + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size, + is_store=True) + + for cached_req in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): here we rely on the resumed requests being + # the first N requests in the list scheduled_cache_reqs. + if not cached_req.resumed_from_preemption: + break + if cached_req.req_id in self._requests_need_load: + # NOTE(rob): cached_req_data does not have the full + # list of token ids (only new tokens). So we look it + # up in the actual request object. + request = self._requests_need_load[cached_req.req_id] + total_tokens = (len(cached_req.new_token_ids) + + cached_req.num_computed_tokens) + token_ids = request.all_token_ids[:total_tokens] + + # NOTE(rob): For resumed req, new_block_ids is all + # of the block_ids for the request. + block_ids = cached_req.new_block_ids + + meta.add_request(token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + is_store=False) + total_need_load += 1 + + assert total_need_load == len(self._requests_need_load) + self._requests_need_load.clear() + return meta + + # ============================== + # Helper functions + # ============================== + + def _found_match_for_request( + self, + request: "Request", + ) -> bool: + """Check if the cache is hit for the request. + """ + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + foldername = self._generate_foldername_debug(torch.tensor( + request.prompt_token_ids)[:num_tokens_to_check], + create_folder=False) + return os.path.exists(foldername) + + def _generate_foldername_debug( + self, + input_ids: torch.Tensor, + create_folder=False, + ) -> str: + """Generate a folder name based on the hash of the bytes of the input + ids. 
+ """ + input_ids_bytes = input_ids.numpy().tobytes() + input_ids_hash = hashlib.md5(input_ids_bytes, + usedforsecurity=False).hexdigest() + foldername = os.path.join(self._storage_path, input_ids_hash) + if create_folder: + os.makedirs(foldername, exist_ok=True) + return foldername + + def _generate_filename_debug( + self, + layer_name: str, + input_ids: torch.Tensor, + ) -> str: + """Generate a file name based on the layer name and the hash + of the bytes of the input ids. + """ + foldername = self._generate_foldername_debug(input_ids, + create_folder=True) + return os.path.join(foldername, f"{layer_name}.safetensors") + + +def align_to_block_size(num_tokens: int, block_size) -> int: + """Align the number of tokens to the block size. + """ + return (num_tokens - 1) // block_size * block_size diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_connector_agent.py similarity index 97% rename from vllm/distributed/kv_transfer/kv_transfer_agent.py rename to vllm/distributed/kv_transfer/kv_connector_agent.py index 1e80e0bd7de865b2c87828002d1be55bf373fb54..9d7145098105e5266b5bebefe429af3efe4e3b98 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_connector_agent.py @@ -46,7 +46,7 @@ class KVTransferAgent: assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ "TransferAgent should only be used when kv_connector is set." 
- self.connector = KVConnectorFactory.create_connector( + self.connector = KVConnectorFactory.create_connector_v0( rank, local_rank, config) def send_kv_caches_and_hidden_states( diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py index 7fd5967293f2622941d06bb439b0fc341a092cf0..5bb7110216768dc0e55eac6e92bedb718f5f1707 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py @@ -70,7 +70,7 @@ class MooncakeStore(KVStoreBufferBase): ): try: - from mooncake_vllm_adaptor import MooncakeDistributedStore + from mooncake.store import MooncakeDistributedStore except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index ec46d4045447259918258c2f11e8437fbabfca42..aa4b1ba71492cf49351fe22594d97a7160ae6db7 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -2,6 +2,7 @@ import json import os +import struct from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import Optional, Union @@ -57,14 +58,14 @@ class MooncakeTransferEngine: def __init__(self, kv_rank: int, local_rank: int): try: - import mooncake_vllm_adaptor as mva + from mooncake.engine import TransferEngine except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 "to run vLLM with MooncakeConnector.") from e - self.engine = mva.mooncake_vllm_adaptor() + self.engine = TransferEngine() self.local_rank = local_rank try: @@ -115,14 +116,14 @@ class MooncakeTransferEngine: p_rank_offset = int(p_port) + 8 + self.local_rank * 2 
d_rank_offset = int(d_port) + 8 + self.local_rank * 2 if kv_rank == 0: - self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}") + self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}") self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}") self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}") - self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}") + self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}") else: self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}") - self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}") - self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}") + self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}") + self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}") self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}") def initialize(self, local_hostname: str, metadata_server: str, @@ -140,12 +141,12 @@ class MooncakeTransferEngine: "Mooncake Configuration error. `metadata_backend`" f" should be one of {supported_backend}.") - self.engine.initializeExt(local_hostname, metadata_server, - protocol, device_name, metadata_backend) + self.engine.initialize_ext(local_hostname, metadata_server, + protocol, device_name, metadata_backend) def allocate_managed_buffer(self, length: int) -> int: """Allocate a managed buffer of the specified length.""" - ret = self.engine.allocateManagedBuffer(length) + ret = self.engine.allocate_managed_buffer(length) if ret <= 0: logger.error("Allocation Return Error") raise Exception("Allocation Return Error") @@ -153,13 +154,13 @@ class MooncakeTransferEngine: def free_managed_buffer(self, buffer: int, length: int) -> int: """Free a previously allocated managed buffer.""" - return self.engine.freeManagedBuffer(buffer, length) + return self.engine.free_managed_buffer(buffer, length) def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int: """Synchronously transfer data to the specified address.""" - ret = 
self.engine.transferSync(self.remote_url, buffer, - peer_buffer_address, length) + ret = self.engine.transfer_sync_read(self.remote_url, buffer, + peer_buffer_address, length) if ret < 0: logger.error("Transfer Return Error") raise Exception("Transfer Return Error") @@ -168,15 +169,15 @@ class MooncakeTransferEngine: def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int: """Write bytes to the allocated buffer.""" - return self.engine.writeBytesToBuffer(buffer, user_data, length) + return self.engine.write_bytes_to_buffer(buffer, user_data, length) def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: """Read bytes from the allocated buffer.""" - return self.engine.readBytesFromBuffer(buffer, length) + return self.engine.read_bytes_from_buffer(buffer, length) def wait_for_ack(self, src_ptr: int, length: int) -> None: """Asynchronously wait for ACK from the receiver.""" - ack = self.sender_ack.recv_pyobj() + ack = self.sender_ack.recv() if ack != b'ACK': logger.error("Failed to receive ACK from the receiver") @@ -187,18 +188,22 @@ class MooncakeTransferEngine: length = len(user_data) src_ptr = self.allocate_managed_buffer(length) self.write_bytes_to_buffer(src_ptr, user_data, length) - self.sender_socket.send_pyobj((src_ptr, length)) + self.sender_socket.send_multipart( + [struct.pack("!Q", src_ptr), + struct.pack("!Q", length)]) self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) def recv_bytes(self) -> bytes: """Receive bytes from the remote process.""" - src_ptr, length = self.receiver_socket.recv_pyobj() + data = self.receiver_socket.recv_multipart() + src_ptr = struct.unpack("!Q", data[0])[0] + length = struct.unpack("!Q", data[1])[0] dst_ptr = self.allocate_managed_buffer(length) self.transfer_sync(dst_ptr, src_ptr, length) ret = self.read_bytes_from_buffer(dst_ptr, length) # Buffer cleanup - self.receiver_ack.send_pyobj(b'ACK') + self.receiver_ack.send(b'ACK') self.free_managed_buffer(dst_ptr, length) 
return ret diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py new file mode 100644 index 0000000000000000000000000000000000000000..25d2f2cf5c6e6413a87eee48766a8b5dc4801852 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING, Optional + +from vllm import envs +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.distributed.parallel_state import get_world_group + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None + + +def get_kv_transfer_group() -> KVConnectorBaseType: + assert _KV_CONNECTOR_AGENT is not None, ( + "disaggregated KV cache transfer parallel group is not initialized") + return _KV_CONNECTOR_AGENT + + +def has_kv_transfer_group() -> bool: + return _KV_CONNECTOR_AGENT is not None + + +def is_v1_kv_transfer_group( + connector: Optional[KVConnectorBaseType] = None) -> bool: + """Check if the KV connector is the v1 connector. + If the argument is None, it will check the global KV connector + + Args: + connector: The KV connector to check. If None, it will check the + global KV connector. + + Note: + This function will no-longer be needed after the v1 KV connector + becomes the default. + """ + if connector is None: + connector = _KV_CONNECTOR_AGENT + + if connector is None: + return False + + return isinstance(connector, KVConnectorBase_V1) + + +def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: + """ + Initialize KV cache transfer parallel group. 
+ """ + + global _KV_CONNECTOR_AGENT + + if vllm_config.kv_transfer_config is None: + return + + if (vllm_config.kv_transfer_config.is_kv_transfer_instance + and _KV_CONNECTOR_AGENT is None): + if envs.VLLM_USE_V1: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( + config=vllm_config, role=KVConnectorRole.WORKER) + else: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( + rank=get_world_group().rank, + local_rank=get_world_group().local_rank, + config=vllm_config, + ) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e0eeeffb88a70eabca952ddc71a723629e910b3b..cb9658ce10043077fc46b3bbaa50db6c46c58c64 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -29,15 +29,13 @@ from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Union) +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup -import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) @@ -46,9 +44,6 @@ from vllm.logger import init_logger from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, supports_custom_op) -if TYPE_CHECKING: - from vllm.config import VllmConfig - @dataclass class GraphCaptureContext: @@ -118,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: return torch.empty_like(tensor) +def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." 
+ group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group.reduce_scatter(tensor, dim) + + +def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] // world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + +def all_gather(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group.all_gather(tensor, dim) + + +def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] * world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + if supports_custom_op(): from vllm.platforms import current_platform direct_register_custom_op( @@ -128,6 +155,20 @@ if supports_custom_op(): dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="reduce_scatter", + op_func=reduce_scatter, + mutates_args=[], + fake_impl=reduce_scatter_fake, + ) + + direct_register_custom_op( + op_name="all_gather", + op_func=all_gather, + mutates_args=[], + fake_impl=all_gather_fake, + ) + class GroupCoordinator: """ @@ -327,6 +368,18 @@ class GroupCoordinator: return self.device_communicator.all_gather(input_, dim) + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. 
+ if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + return self.device_communicator.reduce_scatter(input_, dim) + def gather(self, input_: torch.Tensor, dst: int = 0, @@ -772,14 +825,6 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group -_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None - - -def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: - assert _KV_TRANSFER is not None, ( - "disaggregated KV cache transfer parallel group is not initialized") - return _KV_TRANSFER - @contextmanager def graph_capture(device: torch.device): @@ -962,26 +1007,6 @@ def initialize_model_parallel( _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) -def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: - """ - Initialize KV cache transfer parallel group. - """ - - global _KV_TRANSFER - - if vllm_config.kv_transfer_config is None: - return - - if all([ - vllm_config.kv_transfer_config.is_kv_transfer_instance, - _KV_TRANSFER is None - ]): - _KV_TRANSFER = kv_transfer.KVTransferAgent( - rank=get_world_group().rank, - local_rank=get_world_group().local_rank, - config=vllm_config) - - def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 2cb57afd45664bcadfa238edddc4d1abce7c8e8d..e4d4008cd0a688cafa532ef897db1f96d8347c95 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -7,6 +7,7 @@ import dataclasses import datetime import pickle +import socket import time from collections import deque from typing import Any, Deque, Dict, Optional, Sequence, Tuple @@ -123,6 +124,10 @@ class StatelessProcessGroup: rank: int world_size: int store: torch._C._distributed_c10d.Store + + # stores a reference to the socket so that the file 
descriptor stays alive + socket: Optional[socket.socket] + data_expiration_seconds: int = 3600 # 1 hour # dst rank -> counter @@ -234,18 +239,33 @@ class StatelessProcessGroup: can call `StatelessProcessGroup.create` to form a group, and then process A, B, C, and D can call `StatelessProcessGroup.create` to form another group. """ # noqa + launch_server = rank == 0 + if launch_server: + # listen on the specified interface (instead of 0.0.0.0) + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listen_socket.bind((host, port)) + listen_socket.listen() + listen_fd = listen_socket.fileno() + else: + listen_socket = None + listen_fd = None + store = TCPStore( host_name=host, port=port, world_size=world_size, - is_master=(rank == 0), + is_master=launch_server, timeout=datetime.timedelta(seconds=store_timeout), + use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 + master_listen_fd=listen_fd, ) return StatelessProcessGroup( rank=rank, world_size=world_size, store=store, + socket=listen_socket, data_expiration_seconds=data_expiration_seconds) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1297d5d17a1f89b657e4f52ced8f9f51d3ca8683..7d616c37d3d784f337b6aa61985ecef48dfa2d4d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,25 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 +# yapf: disable import argparse import dataclasses import json import re import threading from dataclasses import MISSING, dataclass, fields -from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, - Tuple, Type, Union, cast, get_args, get_origin) +from typing import (Any, Callable, Dict, List, Literal, Optional, Type, + TypeVar, Union, cast, get_args, get_origin) import torch +from typing_extensions import TypeIs, deprecated import vllm.envs as envs from vllm import version -from vllm.config import (CacheConfig, CompilationConfig, 
ConfigFormat, - DecodingConfig, DeviceConfig, HfOverrides, +from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, + ConfigFormat, ConfigType, DecodingConfig, Device, + DeviceConfig, DistributedExecutorBackend, + GuidedDecodingBackendV1, HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ModelImpl, ObservabilityConfig, - ParallelConfig, PoolerConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig, TaskOption, - TokenizerPoolConfig, VllmConfig, get_attr_docs) + ModelConfig, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, PromptAdapterConfig, + SchedulerConfig, SchedulerPolicy, SpeculativeConfig, + TaskOption, TokenizerPoolConfig, VllmConfig, + get_attr_docs, get_field) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -28,33 +34,42 @@ from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor +from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor -if TYPE_CHECKING: - from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +# yapf: enable logger = init_logger(__name__) ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] -DEVICE_OPTIONS = [ - "auto", - "cuda", - "neuron", - "cpu", - "tpu", - "xpu", - "hpu", -] +# object is used to allow for special typing forms +T = TypeVar("T") +TypeHint = Union[type[Any], object] +TypeHintT = Union[type[T], object] -def nullable_str(val: str): - if not val or val == "None": - return None - return val +def optional_type( + return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: + + def _optional_type(val: str) 
-> Optional[T]: + if val == "" or val == "None": + return None + try: + if return_type is json.loads and not re.match("^{.*}$", val): + return cast(T, nullable_kvs(val)) + return return_type(val) + except ValueError as e: + raise argparse.ArgumentTypeError( + f"Value {val} cannot be converted to {return_type}.") from e + + return _optional_type -def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: +@deprecated( + "Passing a JSON argument as a string containing comma separated key=value " + "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON " + "string instead.") +def nullable_kvs(val: str) -> dict[str, int]: """Parses a string containing comma separate key [str] to value [int] pairs into a dictionary. @@ -64,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: Returns: Dictionary with parsed values. """ - if len(val) == 0: - return None - - out_dict: Dict[str, int] = {} + out_dict: dict[str, int] = {} for item in val.split(","): kv_parts = [part.lower().strip() for part in item.split("=")] if len(kv_parts) != 2: @@ -89,6 +101,105 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: return out_dict +def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]: + """Check if the type hint is a specific type.""" + return type_hint is type or get_origin(type_hint) is type + + +def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool: + """Check if the type hints contain a specific type.""" + return any(is_type(type_hint, type) for type_hint in type_hints) + + +def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT: + """Get the specific type from the type hints.""" + return next((th for th in type_hints if is_type(th, type)), None) + + +def is_not_builtin(type_hint: TypeHint) -> bool: + """Check if the class is not a built-in type.""" + return type_hint.__module__ != "builtins" + + +def get_kwargs(cls: ConfigType) -> dict[str, Any]: + cls_docs = get_attr_docs(cls) + kwargs = {} 
+ for field in fields(cls): + # Get the default value of the field + default = field.default + if field.default_factory is not MISSING: + default = field.default_factory() + + # Get the help text for the field + name = field.name + help = cls_docs[name] + # Escape % for argparse + help = help.replace("%", "%%") + + # Initialise the kwargs dictionary for the field + kwargs[name] = {"default": default, "help": help} + + # Get the set of possible types for the field + type_hints: set[TypeHint] = set() + if get_origin(field.type) is Union: + type_hints.update(get_args(field.type)) + else: + type_hints.add(field.type) + + # Set other kwargs based on the type hints + if contains_type(type_hints, bool): + # Creates --no- and -- flags + kwargs[name]["action"] = argparse.BooleanOptionalAction + elif contains_type(type_hints, Literal): + # Creates choices from Literal arguments + type_hint = get_type(type_hints, Literal) + choices = sorted(get_args(type_hint)) + kwargs[name]["choices"] = choices + choice_type = type(choices[0]) + assert all(type(c) is choice_type for c in choices), ( + "All choices must be of the same type. " + f"Got {choices} with types {[type(c) for c in choices]}") + kwargs[name]["type"] = choice_type + elif contains_type(type_hints, tuple): + type_hint = get_type(type_hints, tuple) + types = get_args(type_hint) + tuple_type = types[0] + assert all(t is tuple_type for t in types if t is not Ellipsis), ( + "All non-Ellipsis tuple elements must be of the same " + f"type. Got {types}.") + kwargs[name]["type"] = tuple_type + kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) + elif contains_type(type_hints, list): + type_hint = get_type(type_hints, list) + types = get_args(type_hint) + assert len(types) == 1, ( + "List type must have exactly one type. 
Got " + f"{type_hint} with types {types}") + kwargs[name]["type"] = types[0] + kwargs[name]["nargs"] = "+" + elif contains_type(type_hints, int): + kwargs[name]["type"] = int + elif contains_type(type_hints, float): + kwargs[name]["type"] = float + elif contains_type(type_hints, dict): + # Dict arguments will always be optional + kwargs[name]["type"] = optional_type(json.loads) + elif (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints)): + kwargs[name]["type"] = str + else: + raise ValueError( + f"Unsupported type {type_hints} for argument {name}.") + + # If None is in type_hints, make the argument optional. + # But not if it's a bool, argparse will handle this better. + if type(None) in type_hints and not contains_type(type_hints, bool): + kwargs[name]["type"] = optional_type(kwargs[name]["type"]) + if kwargs[name].get("choices"): + kwargs[name]["choices"].append("None") + return kwargs + + @dataclass class EngineArgs: """Arguments for vLLM engine.""" @@ -105,14 +216,15 @@ class EngineArgs: load_format: str = LoadConfig.load_format config_format: ConfigFormat = ConfigFormat.AUTO dtype: str = 'auto' - kv_cache_dtype: str = 'auto' + kv_cache_dtype: CacheDType = CacheConfig.cache_dtype seed: Optional[int] = None max_model_len: Optional[int] = None # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. 
distributed_executor_backend: Optional[Union[ - str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend + DistributedExecutorBackend, + Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size @@ -120,20 +232,23 @@ class EngineArgs: enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers - block_size: Optional[int] = None - enable_prefix_caching: Optional[bool] = None - prefix_caching_hash_algo: str = "builtin" + block_size: Optional[BlockSize] = CacheConfig.block_size + enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching + prefix_caching_hash_algo: PrefixCachingHashAlgo = \ + CacheConfig.prefix_caching_hash_algo disable_sliding_window: bool = False disable_cascade_attn: bool = False use_v2_block_manager: bool = True - swap_space: float = 4 # GiB - cpu_offload_gb: float = 0 # GiB - gpu_memory_utilization: float = 0.90 - max_num_batched_tokens: Optional[int] = None - max_num_partial_prefills: Optional[int] = 1 - max_long_partial_prefills: Optional[int] = 1 - long_prefill_token_threshold: Optional[int] = 0 - max_num_seqs: Optional[int] = None + swap_space: float = CacheConfig.swap_space + cpu_offload_gb: float = CacheConfig.cpu_offload_gb + gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization + max_num_batched_tokens: Optional[ + int] = SchedulerConfig.max_num_batched_tokens + max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills + max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills + long_prefill_token_threshold: int = \ + SchedulerConfig.long_prefill_token_threshold + max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs max_logprobs: int = 20 # Default value for 
OpenAI Chat Completions API disable_log_stats: bool = False revision: Optional[str] = None @@ -147,44 +262,52 @@ class EngineArgs: enforce_eager: Optional[bool] = None max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce - tokenizer_pool_size: int = 0 - # Note: Specifying a tokenizer pool by passing a class - # is intended for expert use only. The API may change without - # notice. - tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" - tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None - limit_mm_per_prompt: Optional[Mapping[str, int]] = None + # The following three fields are deprecated and will be removed in a future + # release. Setting them will have no effect. Please remove them from your + # configurations. + tokenizer_pool_size: int = TokenizerPoolConfig.pool_size + tokenizer_pool_type: str = TokenizerPoolConfig.pool_type + tokenizer_pool_extra_config: dict = \ + get_field(TokenizerPoolConfig, "extra_config") + limit_mm_per_prompt: dict[str, int] = \ + get_field(MultiModalConfig, "limit_per_prompt") mm_processor_kwargs: Optional[Dict[str, Any]] = None disable_mm_preprocessor_cache: bool = False + # LoRA fields enable_lora: bool = False - enable_lora_bias: bool = False - max_loras: int = 1 - max_lora_rank: int = 16 - enable_prompt_adapter: bool = False - max_prompt_adapters: int = 1 - max_prompt_adapter_token: int = 0 - fully_sharded_loras: bool = False - lora_extra_vocab_size: int = 256 - long_lora_scaling_factors: Optional[Tuple[float]] = None - lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' - max_cpu_loras: Optional[int] = None + enable_lora_bias: bool = LoRAConfig.bias_enabled + max_loras: int = LoRAConfig.max_loras + max_lora_rank: int = LoRAConfig.max_lora_rank + fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras + max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras merge_lora: bool = False - lora_target_modules: Optional[List[str]] = None - device: 
str = 'auto' - num_scheduler_steps: int = 1 - multi_step_stream_outputs: bool = True + lora_target_modules: Optional[List[str]] = LoRAConfig.lora_target_modules + lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype + lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size + long_lora_scaling_factors: Optional[tuple[float, ...]] = \ + LoRAConfig.long_lora_scaling_factors + # PromptAdapter fields + enable_prompt_adapter: bool = False + max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters + max_prompt_adapter_token: int = \ + PromptAdapterConfig.max_prompt_adapter_token + device: Device = DeviceConfig.device + num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps + multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight - num_gpu_blocks_override: Optional[int] = None - num_lookahead_slots: int = 0 - model_loader_extra_config: Optional[ - dict] = LoadConfig.model_loader_extra_config + num_gpu_blocks_override: Optional[ + int] = CacheConfig.num_gpu_blocks_override + num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots + model_loader_extra_config: dict = \ + get_field(LoadConfig, "model_loader_extra_config") ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns - preemption_mode: Optional[str] = None + preemption_mode: Optional[str] = SchedulerConfig.preemption_mode - scheduler_delay_factor: float = 0.0 - enable_chunked_prefill: Optional[bool] = None - disable_chunked_mm_input: bool = False + scheduler_delay_factor: float = SchedulerConfig.delay_factor + enable_chunked_prefill: Optional[ + bool] = SchedulerConfig.enable_chunked_prefill + disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input guided_decoding_backend: str = DecodingConfig.guided_decoding_backend logits_processor_pattern: Optional[str] = None @@ -197,8 +320,8 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None 
collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False - scheduling_policy: Literal["fcfs", "priority"] = "fcfs" - scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler" + scheduling_policy: SchedulerPolicy = SchedulerConfig.policy + scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls override_neuron_config: Optional[Dict[str, Any]] = None override_pooler_config: Optional[PoolerConfig] = None @@ -213,11 +336,11 @@ class EngineArgs: enable_sleep_mode: bool = False model_impl: str = "auto" - calculate_kv_scales: Optional[bool] = None + calculate_kv_scales: bool = CacheConfig.calculate_kv_scales additional_config: Optional[Dict[str, Any]] = None enable_reasoning: Optional[bool] = None - reasoning_parser: Optional[str] = None + reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load @@ -240,38 +363,6 @@ class EngineArgs: def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: """Shared CLI arguments for vLLM engine.""" - def is_type_in_union(cls: type[Any], type: type[Any]) -> bool: - """Check if the class is a type in a union type.""" - return get_origin(cls) is Union and type in get_args(cls) - - def is_optional(cls: type[Any]) -> bool: - """Check if the class is an optional type.""" - return is_type_in_union(cls, type(None)) - - def get_kwargs(cls: type[Any]) -> Dict[str, Any]: - cls_docs = get_attr_docs(cls) - kwargs = {} - for field in fields(cls): - name = field.name - # One of these will always be present - default = (field.default_factory - if field.default is MISSING else field.default) - kwargs[name] = {"default": default, "help": cls_docs[name]} - # When using action="store_true" - # add_argument doesn't accept type - if field.type is bool: - continue - # Handle optional fields - if is_optional(field.type): - kwargs[name]["type"] = nullable_str - continue - # Handle str in union fields - if 
is_type_in_union(field.type, str): - kwargs[name]["type"] = str - continue - kwargs[name]["type"] = field.type - return kwargs - # Model arguments parser.add_argument( '--model', @@ -289,13 +380,13 @@ class EngineArgs: 'which task to use.') parser.add_argument( '--tokenizer', - type=nullable_str, + type=optional_type(str), default=EngineArgs.tokenizer, help='Name or path of the huggingface tokenizer to use. ' 'If unspecified, model name or path will be used.') parser.add_argument( "--hf-config-path", - type=nullable_str, + type=optional_type(str), default=EngineArgs.hf_config_path, help='Name or path of the huggingface config to use. ' 'If unspecified, model name or path will be used.') @@ -307,21 +398,21 @@ class EngineArgs: 'the input. The generated output will contain token ids.') parser.add_argument( '--revision', - type=nullable_str, + type=optional_type(str), default=None, help='The specific model version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') parser.add_argument( '--code-revision', - type=nullable_str, + type=optional_type(str), default=None, help='The specific revision to use for the model code on ' 'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'commit id. If unspecified, will use the default version.') parser.add_argument( '--tokenizer-revision', - type=nullable_str, + type=optional_type(str), default=None, help='Revision of the huggingface tokenizer to use. ' 'It can be a branch name, a tag name, or a commit id. 
' @@ -361,7 +452,6 @@ class EngineArgs: load_group.add_argument('--model-loader-extra-config', **load_kwargs["model_loader_extra_config"]) load_group.add_argument('--use-tqdm-on-load', - action=argparse.BooleanOptionalAction, **load_kwargs["use_tqdm_on_load"]) parser.add_argument( @@ -386,14 +476,6 @@ class EngineArgs: '* "bfloat16" for a balance between precision and range.\n' '* "float" is shorthand for FP32 precision.\n' '* "float32" for FP32 precision.') - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default=EngineArgs.kv_cache_dtype, - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument('--max-model-len', type=human_readable_int, default=EngineArgs.max_model_len, @@ -403,21 +485,25 @@ class EngineArgs: 'Examples:\n' '- 1k → 1000\n' '- 1K → 1024\n') - parser.add_argument( + + # Guided decoding arguments + guided_decoding_kwargs = get_kwargs(DecodingConfig) + guided_decoding_group = parser.add_argument_group( + title="DecodingConfig", + description=DecodingConfig.__doc__, + ) + guided_decoding_group.add_argument( '--guided-decoding-backend', - type=str, - default=DecodingConfig.guided_decoding_backend, - help='Which engine will be used for guided decoding' - ' (JSON schema / regex etc) by default. Currently support ' - 'https://github.com/mlc-ai/xgrammar and ' - 'https://github.com/guidance-ai/llguidance.' - 'Valid backend values are "xgrammar", "guidance", and "auto". 
' - 'With "auto", we will make opinionated choices based on request ' - 'contents and what the backend libraries currently support, so ' - 'the behavior is subject to change in each release.') + **guided_decoding_kwargs["guided_decoding_backend"]) + guided_decoding_group.add_argument( + "--reasoning-parser", + # This choices is a special case because it's not static + choices=list(ReasoningParserManager.reasoning_parsers), + **guided_decoding_kwargs["reasoning_backend"]) + parser.add_argument( '--logits-processor-pattern', - type=nullable_str, + type=optional_type(str), default=None, help='Optional regex pattern specifying valid logits processor ' 'qualified names that can be passed with the `logits_processors` ' @@ -443,7 +529,6 @@ class EngineArgs: ) parallel_group.add_argument( '--distributed-executor-backend', - choices=['ray', 'mp', 'uni', 'external_launcher'], **parallel_kwargs["distributed_executor_backend"]) parallel_group.add_argument( '--pipeline-parallel-size', '-pp', @@ -454,46 +539,40 @@ class EngineArgs: **parallel_kwargs["data_parallel_size"]) parallel_group.add_argument( '--enable-expert-parallel', - action='store_true', **parallel_kwargs["enable_expert_parallel"]) parallel_group.add_argument( '--max-parallel-loading-workers', **parallel_kwargs["max_parallel_loading_workers"]) parallel_group.add_argument( '--ray-workers-use-nsight', - action='store_true', **parallel_kwargs["ray_workers_use_nsight"]) parallel_group.add_argument( '--disable-custom-all-reduce', - action='store_true', **parallel_kwargs["disable_custom_all_reduce"]) - # KV cache arguments - parser.add_argument('--block-size', - type=int, - default=EngineArgs.block_size, - choices=[8, 16, 32, 64, 128], - help='Token block size for contiguous chunks of ' - 'tokens. This is ignored on neuron devices and ' - 'set to ``--max-model-len``. On CUDA devices, ' - 'only block sizes up to 32 are supported. 
' - 'On HPU devices, block size defaults to 128.') - parser.add_argument( - "--enable-prefix-caching", - action=argparse.BooleanOptionalAction, - default=EngineArgs.enable_prefix_caching, - help="Enables automatic prefix caching. " - "Use ``--no-enable-prefix-caching`` to disable explicitly.", - ) - parser.add_argument( - "--prefix-caching-hash-algo", - type=str, - choices=["builtin", "sha256"], - default=EngineArgs.prefix_caching_hash_algo, - help="Set the hash algorithm for prefix caching. " - "Options are 'builtin' (Python's built-in hash) or 'sha256' " - "(collision resistant but with certain overheads).", + # KV cache arguments + cache_kwargs = get_kwargs(CacheConfig) + cache_group = parser.add_argument_group( + title="CacheConfig", + description=CacheConfig.__doc__, ) + cache_group.add_argument('--block-size', **cache_kwargs["block_size"]) + cache_group.add_argument('--gpu-memory-utilization', + **cache_kwargs["gpu_memory_utilization"]) + cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"]) + cache_group.add_argument('--kv-cache-dtype', + **cache_kwargs["cache_dtype"]) + cache_group.add_argument('--num-gpu-blocks-override', + **cache_kwargs["num_gpu_blocks_override"]) + cache_group.add_argument("--enable-prefix-caching", + **cache_kwargs["enable_prefix_caching"]) + cache_group.add_argument("--prefix-caching-hash-algo", + **cache_kwargs["prefix_caching_hash_algo"]) + cache_group.add_argument('--cpu-offload-gb', + **cache_kwargs["cpu_offload_gb"]) + cache_group.add_argument('--calculate-kv-scales', + **cache_kwargs["calculate_kv_scales"]) + parser.add_argument('--disable-sliding-window', action='store_true', help='Disables sliding window, ' @@ -506,86 +585,11 @@ class EngineArgs: 'block manager v2) is now the default. 
' 'Setting this flag to True or False' ' has no effect on vLLM behavior.') - parser.add_argument( - '--num-lookahead-slots', - type=int, - default=EngineArgs.num_lookahead_slots, - help='Experimental scheduling config necessary for ' - 'speculative decoding. This will be replaced by ' - 'speculative config in the future; it is present ' - 'to enable correctness tests until then.') parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='Random seed for operations.') - parser.add_argument('--swap-space', - type=float, - default=EngineArgs.swap_space, - help='CPU swap space size (GiB) per GPU.') - parser.add_argument( - '--cpu-offload-gb', - type=float, - default=0, - help='The space in GiB to offload to CPU, per GPU. ' - 'Default is 0, which means no offloading. Intuitively, ' - 'this argument can be seen as a virtual way to increase ' - 'the GPU memory size. For example, if you have one 24 GB ' - 'GPU and set this to 10, virtually you can think of it as ' - 'a 34 GB GPU. Then you can load a 13B model with BF16 weight, ' - 'which requires at least 26GB GPU memory. Note that this ' - 'requires fast CPU-GPU interconnect, as part of the model is ' - 'loaded from CPU memory to GPU memory on the fly in each ' - 'model forward pass.') - parser.add_argument( - '--gpu-memory-utilization', - type=float, - default=EngineArgs.gpu_memory_utilization, - help='The fraction of GPU memory to be used for the model ' - 'executor, which can range from 0 to 1. For example, a value of ' - '0.5 would imply 50%% GPU memory utilization. If unspecified, ' - 'will use the default value of 0.9. This is a per-instance ' - 'limit, and only applies to the current vLLM instance.' - 'It does not matter if you have another vLLM instance running ' - 'on the same GPU. 
For example, if you have two vLLM instances ' - 'running on the same GPU, you can set the GPU memory utilization ' - 'to 0.5 for each instance.') - parser.add_argument( - '--num-gpu-blocks-override', - type=int, - default=None, - help='If specified, ignore GPU profiling result and use this number' - ' of GPU blocks. Used for testing preemption.') - parser.add_argument('--max-num-batched-tokens', - type=int, - default=EngineArgs.max_num_batched_tokens, - help='Maximum number of batched tokens per ' - 'iteration.') - parser.add_argument( - "--max-num-partial-prefills", - type=int, - default=EngineArgs.max_num_partial_prefills, - help="For chunked prefill, the max number of concurrent \ - partial prefills.") - parser.add_argument( - "--max-long-partial-prefills", - type=int, - default=EngineArgs.max_long_partial_prefills, - help="For chunked prefill, the maximum number of prompts longer " - "than --long-prefill-token-threshold that will be prefilled " - "concurrently. Setting this less than --max-num-partial-prefills " - "will allow shorter prompts to jump the queue in front of longer " - "prompts in some cases, improving latency.") - parser.add_argument( - "--long-prefill-token-threshold", - type=float, - default=EngineArgs.long_prefill_token_threshold, - help="For chunked prefill, a request is considered long if the " - "prompt is longer than this number of tokens.") - parser.add_argument('--max-num-seqs', - type=int, - default=EngineArgs.max_num_seqs, - help='Maximum number of sequences per iteration.') parser.add_argument( '--max-logprobs', type=int, @@ -598,7 +602,7 @@ class EngineArgs: # Quantization settings. parser.add_argument('--quantization', '-q', - type=nullable_str, + type=optional_type(str), choices=[*QUANTIZATION_METHODS, None], default=EngineArgs.quantization, help='Method used to quantize the weights. 
If ' @@ -649,162 +653,113 @@ class EngineArgs: 'Additionally for encoder-decoder models, if the ' 'sequence length of the encoder input is larger ' 'than this, we fall back to the eager mode.') - parser.add_argument('--tokenizer-pool-size', - type=int, - default=EngineArgs.tokenizer_pool_size, - help='Size of tokenizer pool to use for ' - 'asynchronous tokenization. If 0, will ' - 'use synchronous tokenization.') - parser.add_argument('--tokenizer-pool-type', - type=str, - default=EngineArgs.tokenizer_pool_type, - help='Type of tokenizer pool to use for ' - 'asynchronous tokenization. Ignored ' - 'if tokenizer_pool_size is 0.') - parser.add_argument('--tokenizer-pool-extra-config', - type=nullable_str, - default=EngineArgs.tokenizer_pool_extra_config, - help='Extra config for tokenizer pool. ' - 'This should be a JSON string that will be ' - 'parsed into a dictionary. Ignored if ' - 'tokenizer_pool_size is 0.') + + # Tokenizer arguments + tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) + tokenizer_group = parser.add_argument_group( + title="TokenizerPoolConfig", + description=TokenizerPoolConfig.__doc__, + ) + tokenizer_group.add_argument('--tokenizer-pool-size', + **tokenizer_kwargs["pool_size"]) + tokenizer_group.add_argument('--tokenizer-pool-type', + **tokenizer_kwargs["pool_type"]) + tokenizer_group.add_argument('--tokenizer-pool-extra-config', + **tokenizer_kwargs["extra_config"]) # Multimodal related configs - parser.add_argument( - '--limit-mm-per-prompt', - type=nullable_kvs, - default=EngineArgs.limit_mm_per_prompt, - # The default value is given in - # MultiModalConfig.get_default_limit_per_prompt - help=('For each multimodal plugin, limit how many ' - 'input instances to allow for each prompt. ' - 'Expects a comma-separated list of items, ' - 'e.g.: `image=16,video=2` allows a maximum of 16 ' - 'images and 2 videos per prompt. 
Defaults to ' - '1 (V0) or 999 (V1) for each modality.')) + multimodal_kwargs = get_kwargs(MultiModalConfig) + multimodal_group = parser.add_argument_group( + title="MultiModalConfig", + description=MultiModalConfig.__doc__, + ) + multimodal_group.add_argument('--limit-mm-per-prompt', + **multimodal_kwargs["limit_per_prompt"]) + parser.add_argument( '--mm-processor-kwargs', default=None, type=json.loads, - help=('Overrides for the multimodal input mapping/processing, ' - 'e.g., image processor. For example: ``{"num_crops": 4}``.')) + help=('Overrides for the multi-modal processor obtained from ' + '``AutoProcessor.from_pretrained``. The available overrides ' + 'depend on the model that is being run.' + 'For example, for Phi-3-Vision: ``{"num_crops": 4}``.')) parser.add_argument( '--disable-mm-preprocessor-cache', action='store_true', - help='If true, then disables caching of the multi-modal ' - 'preprocessor/mapper. (not recommended)') + help='If True, disable caching of the processed multi-modal ' + 'inputs.') # LoRA related configs - parser.add_argument('--enable-lora', - action='store_true', - help='If True, enable handling of LoRA adapters.') - parser.add_argument('--enable-lora-bias', - action='store_true', - help='If True, enable bias for LoRA adapters.') - parser.add_argument('--max-loras', - type=int, - default=EngineArgs.max_loras, - help='Max number of LoRAs in a single batch.') - parser.add_argument('--max-lora-rank', - type=int, - default=EngineArgs.max_lora_rank, - help='Max LoRA rank.') - parser.add_argument('--merge-lora', - type=bool, - default=False, + lora_kwargs = get_kwargs(LoRAConfig) + lora_group = parser.add_argument_group( + title="LoRAConfig", + description=LoRAConfig.__doc__, + ) + lora_group.add_argument( + '--enable-lora', + action=argparse.BooleanOptionalAction, + help='If True, enable handling of LoRA adapters.') + lora_group.add_argument('--enable-lora-bias', + **lora_kwargs["bias_enabled"]) + lora_group.add_argument('--max-loras', 
**lora_kwargs["max_loras"]) + lora_group.add_argument('--max-lora-rank', + **lora_kwargs["max_lora_rank"]) + lora_group.add_argument('--merge-lora', + action=argparse.BooleanOptionalAction, help='If set to True, the weights of the base layer will be merged with the weights of Lora.') - parser.add_argument('--lora-target-modules', - nargs='*', - default=None, - help='List of lora module name, If not specified, modules will be chosen according to the model architecture.') - parser.add_argument( - '--lora-extra-vocab-size', - type=int, - default=EngineArgs.lora_extra_vocab_size, - help=('Maximum size of extra vocabulary that can be ' - 'present in a LoRA adapter (added to the base ' - 'model vocabulary).')) - parser.add_argument( + lora_group.add_argument('--lora-target-modules', + **lora_kwargs["lora_target_modules"]) + lora_group.add_argument('--lora-extra-vocab-size', + **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument( '--lora-dtype', - type=str, - default=EngineArgs.lora_dtype, - choices=['auto', 'float16', 'bfloat16'], - help=('Data type for LoRA. If auto, will default to ' - 'base model dtype.')) - parser.add_argument( - '--long-lora-scaling-factors', - type=nullable_str, - default=EngineArgs.long_lora_scaling_factors, - help=('Specify multiple scaling factors (which can ' - 'be different from base model scaling factor ' - '- see eg. Long LoRA) to allow for multiple ' - 'LoRA adapters trained with those scaling ' - 'factors to be used at the same time. If not ' - 'specified, only adapters trained with the ' - 'base model scaling factor are allowed.')) - parser.add_argument( - '--max-cpu-loras', - type=int, - default=EngineArgs.max_cpu_loras, - help=('Maximum number of LoRAs to store in CPU memory. ' - 'Must be >= than max_loras.')) - parser.add_argument( - '--fully-sharded-loras', - action='store_true', - help=('By default, only half of the LoRA computation is ' - 'sharded with tensor parallelism. 
' - 'Enabling this will use the fully sharded layers. ' - 'At high sequence length, max rank or ' - 'tensor parallel size, this is likely faster.')) - parser.add_argument('--enable-prompt-adapter', - action='store_true', - help='If True, enable handling of PromptAdapters.') - parser.add_argument('--max-prompt-adapters', - type=int, - default=EngineArgs.max_prompt_adapters, - help='Max number of PromptAdapters in a batch.') - parser.add_argument('--max-prompt-adapter-token', - type=int, - default=EngineArgs.max_prompt_adapter_token, - help='Max number of PromptAdapters tokens') - parser.add_argument("--device", - type=str, - default=EngineArgs.device, - choices=DEVICE_OPTIONS, - help='Device type for vLLM execution.') - parser.add_argument('--num-scheduler-steps', - type=int, - default=1, - help=('Maximum number of forward steps per ' - 'scheduler call.')) + **lora_kwargs["lora_dtype"], + ) + lora_group.add_argument('--long-lora-scaling-factors', + **lora_kwargs["long_lora_scaling_factors"]) + lora_group.add_argument('--max-cpu-loras', + **lora_kwargs["max_cpu_loras"]) + lora_group.add_argument('--fully-sharded-loras', + **lora_kwargs["fully_sharded_loras"]) + + # PromptAdapter related configs + prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig) + prompt_adapter_group = parser.add_argument_group( + title="PromptAdapterConfig", + description=PromptAdapterConfig.__doc__, + ) + prompt_adapter_group.add_argument( + '--enable-prompt-adapter', + action=argparse.BooleanOptionalAction, + help='If True, enable handling of PromptAdapters.') + prompt_adapter_group.add_argument( + '--max-prompt-adapters', + **prompt_adapter_kwargs["max_prompt_adapters"]) + prompt_adapter_group.add_argument( + '--max-prompt-adapter-token', + **prompt_adapter_kwargs["max_prompt_adapter_token"]) + + # Device arguments + device_kwargs = get_kwargs(DeviceConfig) + device_group = parser.add_argument_group( + title="DeviceConfig", + description=DeviceConfig.__doc__, + ) + 
device_group.add_argument("--device", **device_kwargs["device"]) + + # Speculative arguments + speculative_group = parser.add_argument_group( + title="SpeculativeConfig", + description=SpeculativeConfig.__doc__, + ) + speculative_group.add_argument( + '--speculative-config', + type=json.loads, + default=None, + help='The configurations for speculative decoding.' + ' Should be a JSON string.') - parser.add_argument( - '--multi-step-stream-outputs', - action=StoreBoolean, - default=EngineArgs.multi_step_stream_outputs, - nargs="?", - const="True", - help='If False, then multi-step will stream outputs at the end ' - 'of all steps') - parser.add_argument( - '--scheduler-delay-factor', - type=float, - default=EngineArgs.scheduler_delay_factor, - help='Apply a delay (of delay factor multiplied by previous ' - 'prompt latency) before scheduling next prompt.') - parser.add_argument( - '--enable-chunked-prefill', - action=StoreBoolean, - default=EngineArgs.enable_chunked_prefill, - nargs="?", - const="True", - help='If set, the prefill requests can be chunked based on the ' - 'max_num_batched_tokens.') - parser.add_argument('--speculative-config', - type=json.loads, - default=None, - help='The configurations for speculative decoding.' - ' Should be a JSON string.') parser.add_argument( '--num-speculative-heads', type=int, @@ -819,13 +774,6 @@ class EngineArgs: help="The pattern(s) to ignore when loading the model." "Default to `original/**/*` to avoid repeated loading of llama's " "checkpoints.") - parser.add_argument( - '--preemption-mode', - type=str, - default=None, - help='If \'recompute\', the engine performs preemption by ' - 'recomputing; If \'swap\', the engine performs preemption by ' - 'block swapping.') parser.add_argument( "--served-model-name", @@ -881,22 +829,47 @@ class EngineArgs: help="Disable async output processing. 
This may result in " "lower performance.") - parser.add_argument( - '--scheduling-policy', - choices=['fcfs', 'priority'], - default="fcfs", - help='The scheduling policy to use. "fcfs" (first come first served' - ', i.e. requests are handled in order of arrival; default) ' - 'or "priority" (requests are handled based on given ' - 'priority (lower value means earlier handling) and time of ' - 'arrival deciding any ties).') - - parser.add_argument( - '--scheduler-cls', - default=EngineArgs.scheduler_cls, - help='The scheduler class to use. "vllm.core.scheduler.Scheduler" ' - 'is the default scheduler. Can be a class directly or the path to ' - 'a class of form "mod.custom_class".') + # Scheduler arguments + scheduler_kwargs = get_kwargs(SchedulerConfig) + scheduler_group = parser.add_argument_group( + title="SchedulerConfig", + description=SchedulerConfig.__doc__, + ) + scheduler_group.add_argument( + '--max-num-batched-tokens', + **scheduler_kwargs["max_num_batched_tokens"]) + scheduler_group.add_argument('--max-num-seqs', + **scheduler_kwargs["max_num_seqs"]) + scheduler_group.add_argument( + "--max-num-partial-prefills", + **scheduler_kwargs["max_num_partial_prefills"]) + scheduler_group.add_argument( + "--max-long-partial-prefills", + **scheduler_kwargs["max_long_partial_prefills"]) + scheduler_group.add_argument( + "--long-prefill-token-threshold", + **scheduler_kwargs["long_prefill_token_threshold"]) + scheduler_group.add_argument('--num-lookahead-slots', + **scheduler_kwargs["num_lookahead_slots"]) + scheduler_group.add_argument('--scheduler-delay-factor', + **scheduler_kwargs["delay_factor"]) + scheduler_group.add_argument('--preemption-mode', + **scheduler_kwargs["preemption_mode"]) + scheduler_group.add_argument('--num-scheduler-steps', + **scheduler_kwargs["num_scheduler_steps"]) + scheduler_group.add_argument( + '--multi-step-stream-outputs', + **scheduler_kwargs["multi_step_stream_outputs"]) + scheduler_group.add_argument('--scheduling-policy', + 
**scheduler_kwargs["policy"]) + scheduler_group.add_argument( + '--enable-chunked-prefill', + **scheduler_kwargs["enable_chunked_prefill"]) + scheduler_group.add_argument( + "--disable-chunked-mm-input", + **scheduler_kwargs["disable_chunked_mm_input"]) + parser.add_argument('--scheduler-cls', + **scheduler_kwargs["scheduler_cls"]) parser.add_argument( '--override-neuron-config', @@ -923,10 +896,11 @@ class EngineArgs: 'testing only. level 3 is the recommended level ' 'for production.\n' 'To specify the full compilation config, ' - 'use a JSON string.\n' + 'use a JSON string, e.g. ``{"level": 3, ' + '"cudagraph_capture_sizes": [1, 2, 4, 8]}``\n' 'Following the convention of traditional ' - 'compilers, using -O without space is also ' - 'supported. -O3 is equivalent to -O 3.') + 'compilers, using ``-O`` without space is also ' + 'supported. ``-O3`` is equivalent to ``-O 3``.') parser.add_argument('--kv-transfer-config', type=KVTransferConfig.from_cli, @@ -948,7 +922,7 @@ class EngineArgs: 'class without changing the existing functions.') parser.add_argument( "--generation-config", - type=nullable_str, + type=optional_type(str), default="auto", help="The folder path to the generation config. " "Defaults to 'auto', the generation config will be loaded from " @@ -975,15 +949,6 @@ class EngineArgs: help="Enable sleep mode for the engine. " "(only cuda platform is supported)") - parser.add_argument( - '--calculate-kv-scales', - action='store_true', - help='This enables dynamic calculation of ' - 'k_scale and v_scale when kv-cache-dtype is fp8. ' - 'If calculate-kv-scales is false, the scales will ' - 'be loaded from the model checkpoint if available. ' - 'Otherwise, the scales will default to 1.0.') - parser.add_argument( "--additional-config", type=json.loads, @@ -1001,16 +966,6 @@ class EngineArgs: "If enabled, the model will be able to generate reasoning content." 
) - parser.add_argument( - "--reasoning-parser", - type=str, - choices=list(ReasoningParserManager.reasoning_parsers), - default=None, - help= - "Select the reasoning parser depending on the model that you're " - "using. This is used to parse the reasoning content into OpenAI " - "API format. Required for ``--enable-reasoning``.") - parser.add_argument( "--disable-cascade-attn", action="store_true", @@ -1021,20 +976,6 @@ class EngineArgs: "Note that even if this is set to False, cascade attention will be " "only used when the heuristic tells that it's beneficial.") - parser.add_argument( - "--disable-chunked-mm-input", - action=StoreBoolean, - default=EngineArgs.disable_chunked_mm_input, - nargs="?", - const="True", - help="Disable multimodal input chunking attention for V1. " - "If set to true and chunked prefill is enabled, we do not want to" - " partially schedule a multimodal item. This ensures that if a " - "request has a mixed prompt (like text tokens TTTT followed by " - "image tokens IIIIIIIIII) where only some image tokens can be " - "scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled " - "as TTTT in one step and IIIIIIIIII in the next.") - return parser @classmethod @@ -1228,11 +1169,6 @@ class EngineArgs: enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, - tokenizer_pool_config=TokenizerPoolConfig.create_config( - self.tokenizer_pool_size, - self.tokenizer_pool_type, - self.tokenizer_pool_extra_config, - ), ray_workers_use_nsight=self.ray_workers_use_nsight, placement_group=placement_group, distributed_executor_backend=self.distributed_executor_backend, @@ -1308,8 +1244,6 @@ class EngineArgs: if self.qlora_adapter_name_or_path is not None and \ self.qlora_adapter_name_or_path != "": - if self.model_loader_extra_config is None: - self.model_loader_extra_config = {} self.model_loader_extra_config[ 
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path @@ -1390,7 +1324,7 @@ class EngineArgs: recommend_to_remove=False) return False - if self.preemption_mode != EngineArgs.preemption_mode: + if self.preemption_mode != SchedulerConfig.preemption_mode: _raise_or_fallback(feature_name="--preemption-mode", recommend_to_remove=True) return False @@ -1401,34 +1335,28 @@ class EngineArgs: recommend_to_remove=True) return False - if self.scheduling_policy != EngineArgs.scheduling_policy: + if self.scheduling_policy != SchedulerConfig.policy: _raise_or_fallback(feature_name="--scheduling-policy", recommend_to_remove=False) return False - if self.num_scheduler_steps != EngineArgs.num_scheduler_steps: + if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps: _raise_or_fallback(feature_name="--num-scheduler-steps", recommend_to_remove=True) return False - if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor: + if self.scheduler_delay_factor != SchedulerConfig.delay_factor: _raise_or_fallback(feature_name="--scheduler-delay-factor", recommend_to_remove=True) return False - if self.additional_config != EngineArgs.additional_config: - _raise_or_fallback(feature_name="--additional-config", - recommend_to_remove=False) - return False - - # Xgrammar and Guidance are supported. - SUPPORTED_GUIDED_DECODING = [ - "xgrammar", "xgrammar:disable-any-whitespace", "guidance", - "guidance:disable-any-whitespace", "auto" - ] - if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING: - _raise_or_fallback(feature_name="--guided-decoding-backend", - recommend_to_remove=False) + # remove backend options when doing this check + if self.guided_decoding_backend.split(':')[0] \ + not in get_args(GuidedDecodingBackendV1): + _raise_or_fallback( + feature_name= + f"--guided-decoding-backend={self.guided_decoding_backend}", + recommend_to_remove=False) return False # Need at least Ampere for now (FA support required). 
@@ -1452,7 +1380,7 @@ class EngineArgs: ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" supported = False if fp8_attention and will_use_fa: - from vllm.vllm_flash_attn.fa_utils import ( + from vllm.attention.utils.fa_utils import ( flash_attn_supports_fp8) supported = flash_attn_supports_fp8() if not supported: @@ -1495,9 +1423,9 @@ class EngineArgs: # No Concurrent Partial Prefills so far. if (self.max_num_partial_prefills - != EngineArgs.max_num_partial_prefills + != SchedulerConfig.max_num_partial_prefills or self.max_long_partial_prefills - != EngineArgs.max_long_partial_prefills): + != SchedulerConfig.max_long_partial_prefills): _raise_or_fallback(feature_name="Concurrent Partial Prefill", recommend_to_remove=False) return False @@ -1517,7 +1445,7 @@ class EngineArgs: if speculative_method: if speculative_method in ("ngram", "[ngram]"): is_ngram_enabled = True - elif speculative_method == "eagle": + elif speculative_method in ("eagle", "eagle3"): is_eagle_enabled = True else: speculative_model = self.speculative_config.get("model") @@ -1529,16 +1457,17 @@ class EngineArgs: recommend_to_remove=False) return False - # No Disaggregated Prefill so far. - if self.kv_transfer_config != EngineArgs.kv_transfer_config: - _raise_or_fallback(feature_name="--kv-transfer-config", - recommend_to_remove=False) - return False - - # No FlashInfer or XFormers so far. + # No XFormers so far. V1_BACKENDS = [ - "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", - "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA" + "FLASH_ATTN_VLLM_V1", + "FLASH_ATTN", + "PALLAS", + "PALLAS_VLLM_V1", + "TRITON_ATTN_VLLM_V1", + "TRITON_MLA", + "FLASHMLA", + "FLASHINFER", + "FLASHINFER_VLLM_V1", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): @@ -1640,9 +1569,7 @@ class EngineArgs: self.enable_prefix_caching = False # VLLM_V0 only supports builtin hash algo for prefix caching. 
- if self.prefix_caching_hash_algo is None: - self.prefix_caching_hash_algo = "builtin" - elif self.prefix_caching_hash_algo == "sha256": + if self.prefix_caching_hash_algo == "sha256": raise ValueError( "sha256 is not supported for prefix caching in V0 engine. " "Please use 'builtin'.") @@ -1661,10 +1588,6 @@ class EngineArgs: if self.enable_prefix_caching is None: self.enable_prefix_caching = True - # if using prefix caching, we must set a hash algo - if self.enable_prefix_caching and self.prefix_caching_hash_algo is None: - self.prefix_caching_hash_algo = "builtin" - # V1 should use the new scheduler by default. # Swap it only if this arg is set to the original V0 default if self.scheduler_cls == EngineArgs.scheduler_cls: @@ -1681,13 +1604,13 @@ class EngineArgs: # values for non-H100/H200 GPUs. try: from vllm.platforms import current_platform - device_name = current_platform.get_device_name().lower() + device_memory = current_platform.get_device_total_memory() except Exception: # This is only used to set default_max_num_batched_tokens - device_name = "no-device" + device_memory = 0 - if "h100" in device_name or "h200" in device_name: - # For H100 and H200, we use larger default values. + if device_memory >= 70 * GiB_bytes: + # For GPUs like H100 and MI300x, use larger default values. 
default_max_num_batched_tokens = { UsageContext.LLM_CLASS: 16384, UsageContext.OPENAI_API_SERVER: 8192, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7f9f85e1f93f28a6618290f7815b975869d48705..6cc9b881464e9864d7a7f30b2af7815a7fe2066b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -493,12 +493,11 @@ class _AsyncLLMEngine(LLMEngine): tokenizer = await self.get_tokenizer_async(lora_request) self._validate_token_prompt(prompt, tokenizer=tokenizer) - preprocessed_inputs = await self.input_preprocessor.preprocess_async( + processed_inputs = await self.input_preprocessor.preprocess_async( prompt, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) - processed_inputs = self.input_processor(preprocessed_inputs) if isinstance(params, SamplingParams) and \ params.guided_decoding is not None: @@ -526,10 +525,15 @@ class _AsyncLLMEngine(LLMEngine): ) async def check_health_async(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() self.model_executor.check_health() + async def collective_rpc_async(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + raise NotImplementedError + async def build_guided_decoding_logits_processor_async( sampling_params: SamplingParams, tokenizer: AnyTokenizer, @@ -1167,6 +1171,10 @@ class AsyncLLMEngine(EngineClient): exception=asyncio.CancelledError, verbose=self.log_requests) + async def get_vllm_config(self) -> VllmConfig: + """Get the vllm configuration of the vLLM engine.""" + return self.engine.get_vllm_config() + async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" return self.engine.get_model_config() @@ -1234,6 +1242,17 @@ class AsyncLLMEngine(EngineClient): async def add_lora(self, lora_request: LoRARequest) -> None: self.engine.add_lora(lora_request) + async def collective_rpc(self, + method: str, + timeout: 
Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + """ + Perform a collective RPC call to the given path. + """ + return await self.engine.collective_rpc_async(method, timeout, args, + kwargs) + # TODO(v1): Remove this class proxy when V1 goes default. if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f17f464baae7fb986a0d408eef397db333a3b5c..f0fcadd6c77778e23918bc587a13e82ce7dc048a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -30,8 +30,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.entrypoints.openai.logits_processors import ( get_logits_processors as get_openai_logits_processors) from vllm.executor.executor_base import ExecutorBase -from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, - PromptType, SingletonInputs) +from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -56,7 +55,7 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) + TokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import (Counter, Device, deprecate_kwargs, @@ -67,7 +66,6 @@ from vllm.worker.model_runner_base import InputProcessingError logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 -_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput) _R = TypeVar("_R", default=Any) @@ -206,7 +204,7 @@ class 
LLMEngine: return outputs_ - tokenizer: Optional[BaseTokenizerGroup] + tokenizer: Optional[TokenizerGroup] def __init__( self, @@ -215,7 +213,6 @@ class LLMEngine: log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, ) -> None: @@ -276,11 +273,7 @@ class LLMEngine: self.tokenizer, mm_registry) - self.input_registry = input_registry - self.input_processor = input_registry.create_input_processor( - self.model_config) - - self.model_executor = executor_class(vllm_config=vllm_config, ) + self.model_executor = executor_class(vllm_config=vllm_config) if self.model_config.runner_type != "pooling": self._initialize_kv_caches() @@ -322,11 +315,6 @@ class LLMEngine: self.parallel_config.disable_custom_all_reduce, }) - if self.tokenizer: - # Ping the tokenizer to ensure liveness if it runs in a - # different process. - self.tokenizer.ping() - self.cached_scheduler_outputs = [ SchedulerOutputState() for _ in range(self.parallel_config.pipeline_parallel_size) @@ -540,21 +528,12 @@ class LLMEngine: if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() - def get_tokenizer_group( - self, - group_type: Type[_G] = BaseTokenizerGroup, - ) -> _G: - tokenizer_group = self.tokenizer - - if tokenizer_group is None: + def get_tokenizer_group(self) -> TokenizerGroup: + if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " "skip_tokenizer_init is True") - if not isinstance(tokenizer_group, group_type): - raise TypeError("Invalid type of tokenizer group. 
" - f"Expected type: {group_type}, but " - f"found type: {type(tokenizer_group)}") - return tokenizer_group + return self.tokenizer def get_tokenizer( self, @@ -562,11 +541,10 @@ class LLMEngine: ) -> AnyTokenizer: return self.get_tokenizer_group().get_lora_tokenizer(lora_request) - def _init_tokenizer(self) -> BaseTokenizerGroup: + def _init_tokenizer(self) -> TokenizerGroup: return init_tokenizer_from_configs( model_config=self.model_config, scheduler_config=self.scheduler_config, - parallel_config=self.parallel_config, lora_config=self.lora_config) def _verify_args(self) -> None: @@ -781,12 +759,11 @@ class LLMEngine: prompt, tokenizer=self.get_tokenizer(lora_request=lora_request)) - preprocessed_inputs = self.input_preprocessor.preprocess( + processed_inputs = self.input_preprocessor.preprocess( prompt, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) - processed_inputs = self.input_processor(preprocessed_inputs) self._add_processed_request( request_id=request_id, @@ -917,6 +894,10 @@ class LLMEngine: scheduler.abort_seq_group( request_id, seq_id_to_seq_group=self.seq_id_to_seq_group) + def get_vllm_config(self) -> VllmConfig: + """Gets the vllm configuration.""" + return self.vllm_config + def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config @@ -1965,8 +1946,6 @@ class LLMEngine: return self.model_executor.is_sleeping def check_health(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() self.model_executor.check_health() def is_tracing_enabled(self) -> bool: @@ -2075,7 +2054,7 @@ class LLMEngine: raise ValueError(f"The {prompt_type} prompt cannot be empty") max_prompt_len = self.model_config.max_model_len - if len(prompt_ids) >= max_prompt_len: + if len(prompt_ids) > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: mm_registry = self.input_preprocessor.mm_registry mm_processor = mm_registry.create_processor( diff --git 
a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 7c4265fac20b083b26d8ecc79dad4cdf78e6e8dd..033551d07c39fed432b3444031963126bef2cdd2 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -140,16 +140,13 @@ class Metrics: name="vllm:generation_tokens_total", documentation="Number of generation tokens processed.", labelnames=labelnames) - buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] - if not vllm_config.model_config.enforce_eager: - buckets = vllm_config.compilation_config.\ - cudagraph_capture_sizes.copy() - buckets.sort() self.histogram_iteration_tokens = self._histogram_cls( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", labelnames=labelnames, - buckets=buckets) + buckets=[ + 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 + ]) self.histogram_time_to_first_token = self._histogram_cls( name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index f058b13297bb04a11ba687529ef20af5c381377a..eb3ae89394ecc14561531ea24850cdbcd2eb21ce 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -93,6 +93,7 @@ class MQLLMEngineClient(EngineClient): self._errored_with: Optional[BaseException] = None # Get the configs. 
+ self.vllm_config = engine_config self.model_config = engine_config.model_config self.decoding_config = engine_config.decoding_config @@ -100,7 +101,6 @@ class MQLLMEngineClient(EngineClient): self.tokenizer = init_tokenizer_from_configs( model_config=self.model_config, scheduler_config=engine_config.scheduler_config, - parallel_config=engine_config.parallel_config, lora_config=engine_config.lora_config) self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer) @@ -377,6 +377,9 @@ class MQLLMEngineClient(EngineClient): async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): return await self.tokenizer.get_lora_tokenizer_async(lora_request) + async def get_vllm_config(self) -> VllmConfig: + return self.vllm_config + async def get_decoding_config(self) -> DecodingConfig: return self.decoding_config diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 5f126c7571dc88792ab53dc779982c7ac159c03b..126e7da7021600a36207c586a1d4e834593d417d 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -178,7 +178,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): # generates a fixed number of tokens without evaluating stopping # conditions within the block. This can cause an eos token to be # unintentionally ignored. - if not sampling_params.ignore_eos: + if not sampling_params.ignore_eos and self.detokenizer: eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id # Avoiding .index calls as exception throwing in the happy path # is expensive. 
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e2974b02c5ba354c81a3310f65000afbf43f5a01..7e5ac3a2845226c870130509d96b2dfd740cda5c 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from typing import AsyncGenerator, List, Mapping, Optional from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function -from vllm.config import DecodingConfig, ModelConfig +from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.parse import is_explicit_encoder_decoder_prompt @@ -220,6 +220,11 @@ class EngineClient(ABC): """ ... + @abstractmethod + async def get_vllm_config(self) -> VllmConfig: + """Get the vllm configuration of the vLLM engine.""" + ... + @abstractmethod async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6fb7dc2c9763a64a88b145f099be1a2b71b2bc85..fcaa24eec8c84f96601a790732e9c906449bba07 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -27,10 +27,11 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam) from openai.types.chat.chat_completion_content_part_input_audio_param import ( InputAudio) +from pydantic import TypeAdapter # yapf: enable -# pydantic needs the TypedDict from typing_extensions from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin) +# pydantic needs the TypedDict from typing_extensions from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig @@ -482,11 +483,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): if modality in ("image", "image_embeds"): if model_type == "chatglm": return "<|begin_of_image|><|endoftext|><|end_of_image|>" 
- if model_type == "phi3_v": - # Workaround since this token is not defined in the tokenizer + if model_type in ("phi3_v", "phi4mm"): return f"<|image_{current_count}|>" - if model_type == "phi4mm": - return "<|endoftext10|>" # 200010 (see vocab.json in hf model) if model_type in ("minicpmo", "minicpmv"): return "(./)" if model_type in ("blip-2", "florence2", "fuyu", "paligemma", @@ -506,20 +504,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): return "<|image|>" if model_type in ("qwen2_vl", "qwen2_5_vl"): return "<|vision_start|><|image_pad|><|vision_end|>" + if model_type == "qwen2_5_omni": + return "<|vision_start|><|IMAGE|><|vision_end|>" if model_type == "molmo": return "" if model_type == "aria": return "<|fim_prefix|><|img|><|fim_suffix|>" if model_type == "gemma3": return "" + if model_type == "kimi_vl": + return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501 raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": - if model_type == "ultravox": + if model_type in ("ultravox", "granite_speech"): return "<|audio|>" if model_type == "phi4mm": - return "<|endoftext11|>" # 200011 (see vocab.json in hf model) - if model_type == "qwen2_audio": + return f"<|audio_{current_count}|>" + if model_type in ("qwen2_audio", "qwen2_5_omni"): return (f"Audio {current_count}: " f"<|audio_bos|><|AUDIO|><|audio_eos|>") if model_type == "minicpmo": @@ -528,6 +530,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): elif modality == "video": if model_type in ("qwen2_vl", "qwen2_5_vl"): return "<|vision_start|><|video_pad|><|vision_end|>" + if model_type == "qwen2_5_omni": + return "<|vision_start|><|VIDEO|><|vision_end|>" if model_type in ("minicpmo", "minicpmv"): return "()" if model_type.startswith("llava"): @@ -876,12 +880,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], # No need to validate using Pydantic again _TextParser = partial(cast, 
ChatCompletionContentPartTextParam) -_ImageParser = partial(cast, ChatCompletionContentPartImageParam) _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam) -_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) -_VideoParser = partial(cast, ChatCompletionContentPartVideoParam) +# Need to validate url objects +_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python +_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python +_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio] @@ -1092,7 +1097,11 @@ def _parse_chat_message_content( if role == 'assistant': parsed_msg = _AssistantParser(message) - if "tool_calls" in parsed_msg: + # The 'tool_calls' is not None check ensures compatibility. + # It's needed only if downstream code doesn't strictly + # follow the OpenAI spec. + if ("tool_calls" in parsed_msg + and parsed_msg["tool_calls"] is not None): result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) elif role == "tool": parsed_msg = _ToolParser(message) @@ -1189,14 +1198,25 @@ def apply_hf_chat_template( "allowed, so you must provide a chat template if the tokenizer " "does not define one.") - return tokenizer.apply_chat_template( - conversation=conversation, # type: ignore[arg-type] - tools=tools, # type: ignore[arg-type] - chat_template=hf_chat_template, - tokenize=tokenize, - **kwargs, - ) + try: + + return tokenizer.apply_chat_template( + conversation=conversation, # type: ignore[arg-type] + tools=tools, # type: ignore[arg-type] + chat_template=hf_chat_template, + tokenize=tokenize, + **kwargs, + ) + # External library exceptions can sometimes occur despite the framework's + # internal exception management capabilities. 
+ except Exception as e: + + # Log and report any library-related exceptions for further + # investigation. + logger.exception( + "An error occurred in `transformers` while applying chat template") + raise ValueError from e def apply_mistral_chat_template( tokenizer: MistralTokenizer, @@ -1205,6 +1225,8 @@ def apply_mistral_chat_template( tools: Optional[list[dict[str, Any]]], **kwargs: Any, ) -> list[int]: + from mistral_common.exceptions import MistralCommonException + # The return value of resolve_mistral_chat_template is always None, # and we won't use it. resolve_mistral_chat_template( @@ -1222,5 +1244,16 @@ def apply_mistral_chat_template( # if input does not comply with the expected format. # We convert those assertion errors to ValueErrors so they can be # are properly caught in the preprocessing_input step - except AssertionError as e: + except (AssertionError, MistralCommonException) as e: + raise ValueError from e + + # External library exceptions can sometimes occur despite the framework's + # internal exception management capabilities. + except Exception as e: + + # Log and report any library-related exceptions for further + # investigation. + logger.exception( + "An error occurred in `mistral_common` while applying chat " + "template") raise ValueError from e diff --git a/vllm/entrypoints/cli/benchmark/latency.py b/vllm/entrypoints/cli/benchmark/latency.py new file mode 100644 index 0000000000000000000000000000000000000000..5aca16e0b640c6e61615c53470ae22cd5fec88d4 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/latency.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse + +from vllm.benchmarks.latency import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase +from vllm.entrypoints.cli.types import CLISubcommand + + +class BenchmarkLatencySubcommand(BenchmarkSubcommandBase): + """ The `latency` subcommand for vllm bench. 
""" + + def __init__(self): + self.name = "latency" + super().__init__() + + @property + def help(self) -> str: + return "Benchmark the latency of a single batch of requests." + + def add_cli_args(self, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) + + +def cmd_init() -> list[CLISubcommand]: + return [BenchmarkLatencySubcommand()] diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py index 1bcb25be2fcaeb1fa618c5e9e86550a034f76782..9e857af7d6dbd12fa515647c3e867d8931dfb411 100644 --- a/vllm/entrypoints/cli/benchmark/main.py +++ b/vllm/entrypoints/cli/benchmark/main.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import vllm.entrypoints.cli.benchmark.latency import vllm.entrypoints.cli.benchmark.serve +import vllm.entrypoints.cli.benchmark.throughput from vllm.entrypoints.cli.types import CLISubcommand from vllm.utils import FlexibleArgumentParser -# TODO: Add the rest of the benchmark subcommands here, -# e.g., throughput, latency, etc. BENCHMARK_CMD_MODULES = [ + vllm.entrypoints.cli.benchmark.latency, vllm.entrypoints.cli.benchmark.serve, + vllm.entrypoints.cli.benchmark.throughput, ] diff --git a/vllm/entrypoints/cli/benchmark/throughput.py b/vllm/entrypoints/cli/benchmark/throughput.py new file mode 100644 index 0000000000000000000000000000000000000000..88ee6aa0385783bc785bdefa75bc2787bf6d6a97 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/throughput.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse + +from vllm.benchmarks.throughput import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase +from vllm.entrypoints.cli.types import CLISubcommand + + +class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase): + """ The `throughput` subcommand for vllm bench. 
""" + + def __init__(self): + self.name = "throughput" + super().__init__() + + @property + def help(self) -> str: + return "Benchmark offline inference throughput." + + def add_cli_args(self, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) + + +def cmd_init() -> list[CLISubcommand]: + return [BenchmarkThroughputSubcommand()] diff --git a/vllm/entrypoints/cli/collect_env.py b/vllm/entrypoints/cli/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..d5f9f7e729f08cb03389fddcfb6fa49243b43bd2 --- /dev/null +++ b/vllm/entrypoints/cli/collect_env.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse + +from vllm.collect_env import main as collect_env_main +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils import FlexibleArgumentParser + + +class CollectEnvSubcommand(CLISubcommand): + """The `serve` subcommand for the vLLM CLI. 
""" + + def __init__(self): + self.name = "collect-env" + super().__init__() + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + """Collect information about the environment.""" + collect_env_main() + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + serve_parser = subparsers.add_parser( + "collect-env", + help="Start collecting environment information.", + description="Start collecting environment information.", + usage="vllm collect-env") + return make_arg_parser(serve_parser) + + +def cmd_init() -> list[CLISubcommand]: + return [CollectEnvSubcommand()] diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index aa54bd66bed678f911f58e2ab67a57ef6e54aca5..b7c1afce711811158e8ac12bb61c433ab314df3d 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -5,6 +5,7 @@ import signal import sys import vllm.entrypoints.cli.benchmark.main +import vllm.entrypoints.cli.collect_env import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.serve import vllm.version @@ -15,6 +16,7 @@ CMD_MODULES = [ vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve, vllm.entrypoints.cli.benchmark.main, + vllm.entrypoints.cli.collect_env, ] diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index b09ee526f14aebb928c0ad1c47a8f1e9405772a3..a4f70a51ebaf34beaad928ffa140d0dda250abcf 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -12,9 +12,11 @@ from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.engine.protocol import EngineClient from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils import find_process_using_port +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError logger = init_logger(__name__) @@ -40,6 +42,8 @@ 
async def serve_http(app: FastAPI, loop = asyncio.get_running_loop() + watchdog_task = loop.create_task( + watchdog_loop(server, app.state.engine_client)) server_task = loop.create_task( server.serve(sockets=[sock] if sock else None)) @@ -52,6 +56,7 @@ async def serve_http(app: FastAPI, def signal_handler() -> None: # prevents the uvicorn signal handler to exit early server_task.cancel() + watchdog_task.cancel() if ssl_cert_refresher: ssl_cert_refresher.stop() @@ -73,48 +78,69 @@ async def serve_http(app: FastAPI, port, process, " ".join(process.cmdline())) logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() + finally: + watchdog_task.cancel() + + +async def watchdog_loop(server: uvicorn.Server, engine: EngineClient): + """ + # Watchdog task that runs in the background, checking + # for error state in the engine. Needed to trigger shutdown + # if an exception arises in a StreamingResponse() generator. + """ + VLLM_WATCHDOG_TIME_S = 5.0 + while True: + await asyncio.sleep(VLLM_WATCHDOG_TIME_S) + terminate_if_errored(server, engine) + + +def terminate_if_errored(server: uvicorn.Server, engine: EngineClient): + """ + See discussions here on shutting down a uvicorn server + https://github.com/encode/uvicorn/discussions/1103 + In this case we cannot await the server shutdown here + because handler must first return to close the connection + for this request. + """ + engine_errored = engine.errored and not engine.is_running + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored: + server.should_exit = True def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: - """Adds handlers for fatal errors that should crash the server""" + """ + VLLM V1 AsyncLLM catches exceptions and returns + only two types: EngineGenerateError and EngineDeadError. + + EngineGenerateError is raised by the per request generate() + method. This error could be request specific (and therefore + recoverable - e.g. 
if there is an error in input processing). + + EngineDeadError is raised by the background output_handler + method. This error is global and therefore not recoverable. + + We register these @app.exception_handlers to return nice + responses to the end user if they occur and shut down if needed. + See https://fastapi.tiangolo.com/tutorial/handling-errors/ + for more details on how exception handlers work. + + If an exception is encountered in a StreamingResponse + generator, the exception is not raised, since we already sent + a 200 status. Rather, we send an error message as the next chunk. + Since the exception is not raised, this means that the server + will not automatically shut down. Instead, we use the watchdog + background task to check for errored state. + """ @app.exception_handler(RuntimeError) - async def runtime_error_handler(request: Request, __): - """On generic runtime error, check to see if the engine has died. - It probably has, in which case the server will no longer be able to - handle requests. Trigger a graceful shutdown with a SIGTERM.""" - engine = request.app.state.engine_client - if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored - and not engine.is_running): - logger.fatal("AsyncLLMEngine has failed, terminating server " - "process") - # See discussions here on shutting down a uvicorn server - # https://github.com/encode/uvicorn/discussions/1103 - # In this case we cannot await the server shutdown here because - # this handler must first return to close the connection for - # this request. - server.should_exit = True - - return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - @app.exception_handler(AsyncEngineDeadError) - async def async_engine_dead_handler(_, __): - """Kill the server if the async engine is already dead. 
It will - not handle any further requests.""" - if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: - logger.fatal("AsyncLLMEngine is already dead, terminating server " - "process") - server.should_exit = True - - return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - @app.exception_handler(MQEngineDeadError) - async def mq_engine_dead_handler(_, __): - """Kill the server if the mq engine is already dead. It will - not handle any further requests.""" - if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: - logger.fatal("MQLLMEngine is already dead, terminating server " - "process") - server.should_exit = True + @app.exception_handler(EngineDeadError) + @app.exception_handler(EngineGenerateError) + async def runtime_exception_handler(request: Request, __): + terminate_if_errored( + server=server, + engine=request.app.state.engine_client, + ) return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a707087a2e286b3d9021252ee8051d98df743aa7..653e61a11ebd9a6753e3d25522948be4dbe0615a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -40,7 +40,6 @@ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) -from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, is_list_of) @@ -118,7 +117,7 @@ class LLM: disable_async_output_proc: Disable async output processing. This may result in lower performance. hf_token: The token to use as HTTP bearer authorization for remote files - . If `True`, will use the token generated when running + . If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). 
hf_overrides: If a dictionary, contains arguments to be forwarded to the HuggingFace config. If a callable, it is called to update the @@ -252,11 +251,15 @@ class LLM: self.request_counter = Counter() self.default_sampling_params: Union[dict[str, Any], None] = None - def get_tokenizer(self) -> AnyTokenizer: - return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self.llm_engine.get_tokenizer_group().get_lora_tokenizer( + lora_request) def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: - tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup) + tokenizer_group = self.llm_engine.get_tokenizer_group() # While CachedTokenizer is dynamic, have no choice but # compare class name. Misjudgment will arise from @@ -520,11 +523,9 @@ class LLM: prompts: A list of prompts. Each prompt can be a string or a list of token IDs. params: The beam search parameters. - - TODO: how does beam search work together with length penalty, frequency - penalty, and stopping criteria, etc.? """ - + # TODO: how does beam search work together with length penalty, + # frequency penalty, and stopping criteria, etc.? beam_width = params.beam_width max_tokens = params.max_tokens temperature = params.temperature @@ -536,15 +537,18 @@ class LLM: tokenizer.eos_token_id, length_penalty) - # TODO - fix handling of multimodal data for beam search; we pass it - # through in the async version on the abstract EngineClient, but not - # here. 
- if any("multi_modal_data" in prompt - and prompt["multi_modal_data"] is not None - for prompt in prompts): - logger.warning( - "Multimodal data appears to have been provided, but is not" - " currently being passed through in LLM.beam_search()!") + def create_tokens_prompt_from_beam( + beam: BeamSearchSequence) -> TokensPrompt: + token_prompt_kwargs: TokensPrompt = { + "prompt_token_ids": beam.tokens + } + if beam.multi_modal_data is not None: + token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data + + if beam.mm_processor_kwargs is not None: + token_prompt_kwargs[ + "mm_processor_kwargs"] = beam.mm_processor_kwargs + return TokensPrompt(**token_prompt_kwargs) tokenizer = self.get_tokenizer() # generate 2 * beam_width candidates at each step @@ -556,11 +560,20 @@ class LLM: instances: list[BeamSearchInstance] = [] for prompt in prompts: + # Add multimodal processor kwargs & data + mm_kwargs = {} + if "multi_modal_data" in prompt: + mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"] + if "mm_processor_kwargs" in prompt: + mm_kwargs["mm_processor_kwargs"] = prompt[ + "mm_processor_kwargs"] + if is_token_prompt(prompt): prompt_tokens = prompt["prompt_token_ids"] else: prompt_tokens = tokenizer.encode(prompt["prompt"]) - instances.append(BeamSearchInstance(prompt_tokens)) + instances.append( + BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs)) for _ in range(max_tokens): all_beams: list[BeamSearchSequence] = list( @@ -575,8 +588,7 @@ class LLM: break prompts_batch = [ - TokensPrompt(prompt_token_ids=beam.tokens) - for beam in all_beams + create_tokens_prompt_from_beam(beam) for beam in all_beams ] # only runs for one step @@ -602,7 +614,10 @@ class LLM: tokens=current_beam.tokens + [token_id], logprobs=current_beam.logprobs + [logprobs], cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob) + logprob_obj.logprob, + multi_modal_data=current_beam.multi_modal_data, + mm_processor_kwargs=current_beam. 
+ mm_processor_kwargs) if token_id == tokenizer.eos_token_id and \ not ignore_eos: @@ -701,7 +716,7 @@ class LLM: cast(list[ChatCompletionMessageParam], messages) ] - tokenizer = self.get_tokenizer() + tokenizer = self.get_tokenizer(lora_request) model_config = self.llm_engine.get_model_config() resolved_content_format = resolve_chat_template_content_format( chat_template, @@ -724,9 +739,8 @@ class LLM: content_format=resolved_content_format, ) - prompt_data: Union[str, list[int]] if isinstance(tokenizer, MistralTokenizer): - prompt_data = apply_mistral_chat_template( + prompt_token_ids = apply_mistral_chat_template( tokenizer, messages=msgs, chat_template=chat_template, @@ -735,7 +749,7 @@ class LLM: continue_final_message=continue_final_message, ) else: - prompt_data = apply_hf_chat_template( + prompt_str = apply_hf_chat_template( tokenizer, trust_remote_code=model_config.trust_remote_code, conversation=conversation, @@ -744,12 +758,12 @@ class LLM: add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, ) + # Special tokens are already included in chat templates so + # should not be added by the tokenizer in this case. 
+ prompt_token_ids = tokenizer.encode(prompt_str, + add_special_tokens=False) - prompt: Union[TokensPrompt, TextPrompt] - if is_list_of(prompt_data, int): - prompt = TokensPrompt(prompt_token_ids=prompt_data) - else: - prompt = TextPrompt(prompt=prompt_data) + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) if mm_data is not None: prompt["multi_modal_data"] = mm_data @@ -1048,8 +1062,6 @@ class LLM: if len(encoded_output_1) == 1: encoded_output_1 = encoded_output_1 * len(encoded_output_2) - scores: list[PoolingRequestOutput] = [] - scores = _cosine_similarity(tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2) @@ -1384,7 +1396,9 @@ class LLM: grammar=guided_options.guided_grammar, json_object=guided_options.guided_json_object, backend=guided_options.guided_decoding_backend, - whitespace_pattern=guided_options.guided_whitespace_pattern) + whitespace_pattern=guided_options.guided_whitespace_pattern, + structural_tag=guided_options.structural_tag, + ) return params def _run_engine( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6a8bdd0602285ebc8225f0a2da5f01da2353018c..13681958089705294f87667a838c36d30b14efb3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -30,7 +30,7 @@ from starlette.routing import Mount from typing_extensions import assert_never import vllm.envs as envs -from vllm.config import ModelConfig +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore from vllm.engine.multiprocessing.client import MQLLMEngineClient @@ -310,32 +310,33 @@ def mount_metrics(app: FastAPI): # We need to set PROMETHEUS_MULTIPROC_DIR environment variable # before prometheus_client is imported. 
# See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import (CollectorRegistry, make_asgi_app, + from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app, multiprocess) from prometheus_fastapi_instrumentator import Instrumentator + registry = REGISTRY + prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) if prometheus_multiproc_dir_path is not None: logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", prometheus_multiproc_dir_path) registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) - Instrumentator( - excluded_handlers=[ - "/metrics", - "/health", - "/load", - "/ping", - "/version", - ], - registry=registry, - ).add().instrument(app).expose(app) - - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) - else: - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app()) + + Instrumentator( + excluded_handlers=[ + "/metrics", + "/health", + "/load", + "/ping", + "/version", + "/server_info", + ], + registry=registry, + ).add().instrument(app).expose(app) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) # Workaround for 307 Redirect for /metrics metrics_route.path_regex = re.compile("^/metrics(?P.*)$") @@ -687,6 +688,11 @@ TASK_HANDLERS: dict[str, dict[str, tuple]] = { if envs.VLLM_SERVER_DEV_MODE: + @router.get("/server_info") + async def show_server_info(raw_request: Request): + server_info = {"vllm_config": str(raw_request.app.state.vllm_config)} + return JSONResponse(content=server_info) + @router.post("/reset_prefix_cache") async def reset_prefix_cache(raw_request: Request): """ @@ -875,7 +881,8 @@ def build_app(args: Namespace) -> FastAPI: section async for section in response.body_iterator ] response.body_iterator = 
iterate_in_threadpool(iter(response_body)) - logger.info("response_body={%s}", response_body[0].decode()) + logger.info("response_body={%s}", + response_body[0].decode() if response_body else None) return response for middleware in args.middleware: @@ -894,7 +901,7 @@ def build_app(args: Namespace) -> FastAPI: async def init_app_state( engine_client: EngineClient, - model_config: ModelConfig, + vllm_config: VllmConfig, state: State, args: Namespace, ) -> None: @@ -915,6 +922,8 @@ async def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats + state.vllm_config = vllm_config + model_config = vllm_config.model_config resolved_chat_template = load_chat_template(args.chat_template) if resolved_chat_template is not None: @@ -1069,8 +1078,8 @@ async def run_server(args, **uvicorn_kwargs) -> None: async with build_async_engine_client(args) as engine_client: app = build_app(args) - model_config = await engine_client.get_model_config() - await init_app_state(engine_client, model_config, app.state, args) + vllm_config = await engine_client.get_vllm_config() + await init_app_state(engine_client, vllm_config, app.state, args) def _listen_addr(a: str) -> str: if is_valid_ipv6_address(a): diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 218a8fbe10b76139a191682a81b3f4895f19bba2..b3824013f055ad3b15895dfb3c4db2acce0fe41b 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -11,7 +11,7 @@ import ssl from collections.abc import Sequence from typing import Optional, Union, get_args -from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) from vllm.entrypoints.openai.serving_models import (LoRAModulePath, @@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action): def 
make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--host", - type=nullable_str, + type=optional_type(str), default=None, help="Host name.") parser.add_argument("--port", type=int, default=8000, help="Port number.") @@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=["*"], help="Allowed headers.") parser.add_argument("--api-key", - type=nullable_str, + type=optional_type(str), default=None, help="If provided, the server will require this key " "to be presented in the header.") parser.add_argument( "--lora-modules", - type=nullable_str, + type=optional_type(str), default=None, nargs='+', action=LoRAParserAction, @@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "\"base_model_name\": \"id\"}``") parser.add_argument( "--prompt-adapters", - type=nullable_str, + type=optional_type(str), default=None, nargs='+', action=PromptAdapterParserAction, help="Prompt adapter configurations in the format name=path. " "Multiple adapters can be specified.") parser.add_argument("--chat-template", - type=nullable_str, + type=optional_type(str), default=None, help="The file path to the chat template, " "or the template in single-line form " @@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'similar to OpenAI schema. 
' 'Example: ``[{"type": "text", "text": "Hello world!"}]``') parser.add_argument("--response-role", - type=nullable_str, + type=optional_type(str), default="assistant", help="The role name to return if " "``request.add_generation_prompt=true``.") parser.add_argument("--ssl-keyfile", - type=nullable_str, + type=optional_type(str), default=None, help="The file path to the SSL key file.") parser.add_argument("--ssl-certfile", - type=nullable_str, + type=optional_type(str), default=None, help="The file path to the SSL cert file.") parser.add_argument("--ssl-ca-certs", - type=nullable_str, + type=optional_type(str), default=None, help="The CA certificates file.") parser.add_argument( @@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parser.add_argument( "--root-path", - type=nullable_str, + type=optional_type(str), default=None, help="FastAPI root_path when app is behind a path based routing proxy." ) parser.add_argument( "--middleware", - type=nullable_str, + type=optional_type(str), action="append", default=[], help="Additional ASGI middleware to apply to the app. 
" diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4639b4cea06b716bab6f85be54a939ac2f07da32..015943762ab1e80514b250ea6bae293c0d59e271 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -2,6 +2,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import json import re import time from argparse import Namespace @@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel): strict: Optional[bool] = None +class StructuralTag(OpenAIBaseModel): + begin: str + # schema is the field, but that causes conflicts with pydantic so + # instead use structural_tag_schema with an alias + structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, + alias="schema") + end: str + + +class StructuralTagResponseFormat(OpenAIBaseModel): + type: Literal["structural_tag"] + structures: list[StructuralTag] + triggers: list[str] + + class ResponseFormat(OpenAIBaseModel): - # type must be "json_schema", "json_object" or "text" + # type must be "json_schema", "json_object", or "text" type: Literal["text", "json_object", "json_schema"] json_schema: Optional[JsonSchemaResponseFormat] = None +AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat] + + class StreamOptions(OpenAIBaseModel): include_usage: Optional[bool] = True continuous_usage_stats: Optional[bool] = False @@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel): max_completion_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 - response_format: Optional[ResponseFormat] = None + response_format: Optional[AnyResponseFormat] = None seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, list[str]]] = Field(default_factory=list) stream: Optional[bool] = False @@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel): description=( 
"If specified, the output will follow the context free grammar."), ) + structural_tag: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the structural tag schema."), + ) guided_decoding_backend: Optional[str] = Field( default=None, description=( @@ -476,6 +500,12 @@ class ChatCompletionRequest(OpenAIBaseModel): json_schema = self.response_format.json_schema assert json_schema is not None self.guided_json = json_schema.json_schema + elif self.response_format.type == "structural_tag": + structural_tag = self.response_format + assert structural_tag is not None and isinstance( + structural_tag, StructuralTagResponseFormat) + s_tag_obj = structural_tag.model_dump(by_alias=True) + self.structural_tag = json.dumps(s_tag_obj) guided_decoding = GuidedDecodingParams.from_optional( json=self._get_guided_json_from_tool() or self.guided_json, @@ -485,6 +515,7 @@ class ChatCompletionRequest(OpenAIBaseModel): json_object=guided_json_object, backend=self.guided_decoding_backend, whitespace_pattern=self.guided_whitespace_pattern, + structural_tag=self.structural_tag, ) return SamplingParams.from_optional( @@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel): "If true (the default), special tokens (e.g. BOS) will be added to " "the prompt."), ) - response_format: Optional[ResponseFormat] = Field( + response_format: Optional[AnyResponseFormat] = Field( default=None, - description= - ("Similar to chat completion, this parameter specifies the format of " - "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or " - "{'type': 'text' } is supported."), + description=( + "Similar to chat completion, this parameter specifies the format " + "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}" + ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." 
+ ), ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, @@ -1577,14 +1609,6 @@ class TranscriptionRequest(OpenAIBaseModel): """ ## TODO (varun) : Support if set to 0, certain thresholds are met !! - temperature: float = Field(default=0.0) - """The sampling temperature, between 0 and 1. - - Higher values like 0.8 will make the output more random, while lower values - like 0.2 will make it more focused / deterministic. If set to 0, the model - will use [log probability](https://en.wikipedia.org/wiki/Log_probability) - to automatically increase the temperature until certain thresholds are hit. - """ timestamp_granularities: list[Literal["word", "segment"]] = Field( alias="timestamp_granularities[]", default=[]) @@ -1596,6 +1620,7 @@ class TranscriptionRequest(OpenAIBaseModel): timestamps incurs additional latency. """ + # doc: begin-transcription-extra-params stream: Optional[bool] = False """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat @@ -1604,10 +1629,51 @@ class TranscriptionRequest(OpenAIBaseModel): # Flattened stream option to simplify form data. stream_include_usage: Optional[bool] = False stream_continuous_usage_stats: Optional[bool] = False + # doc: end-transcription-extra-params + + # doc: begin-transcription-sampling-params + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + + top_p: Optional[float] = None + """Enables nucleus (top-p) sampling, where tokens are selected from the + smallest possible set whose cumulative probability exceeds `p`. 
+ """ + + top_k: Optional[int] = None + """Limits sampling to the `k` most probable tokens at each step.""" + + min_p: Optional[float] = None + """Filters out tokens with a probability lower than `min_p`, ensuring a + minimum likelihood threshold during sampling. + """ + + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + """The seed to use for sampling.""" + + frequency_penalty: Optional[float] = 0.0 + """The frequency penalty to use for sampling.""" + + repetition_penalty: Optional[float] = None + """The repetition penalty to use for sampling.""" + + presence_penalty: Optional[float] = 0.0 + """The presence penalty to use for sampling.""" + # doc: end-transcription-sampling-params # Default sampling parameters for transcription requests. _DEFAULT_SAMPLING_PARAMS: dict = { - "temperature": 0, + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, } def to_sampling_params( @@ -1619,13 +1685,35 @@ class TranscriptionRequest(OpenAIBaseModel): if default_sampling_params is None: default_sampling_params = {} + # Default parameters if (temperature := self.temperature) is None: temperature = default_sampling_params.get( "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"]) return SamplingParams.from_optional(temperature=temperature, max_tokens=max_tokens, + seed=self.seed, + top_p=top_p, + top_k=top_k, + min_p=min_p, + 
frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + presence_penalty=self.presence_penalty, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY) diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 0d06ba3df23f9018946eb472e81a72c8350bc4ca..fccf459f17dc6d18c3d8decab58e6128b3afa7db 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -12,7 +12,7 @@ import torch from prometheus_client import start_http_server from tqdm import tqdm -from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.logger import RequestLogger, logger # yapf: disable @@ -61,7 +61,7 @@ def parse_args(): "to the output URL.", ) parser.add_argument("--response-role", - type=nullable_str, + type=optional_type(str), default="assistant", help="The role name to return if " "`request.add_generation_prompt=True`.") diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index bbc8eddd8b1b00d95511178556accfe89e8eb27a..49b346a23baf9b7e86628e8968f17fca3ac51242 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -10,6 +10,7 @@ from fastapi import Request from pydantic import Field from starlette.datastructures import Headers +import vllm.envs as envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient # yapf conflicts with isort for this block @@ -125,18 +126,29 @@ class OpenAIServing: self, request: AnyRequest, ) -> Optional[ErrorResponse]: + + error_response = None + if self._is_model_supported(request.model): return None if request.model in [ lora.lora_name for lora in self.models.lora_requests ]: return None + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and ( + 
load_result := await self.models.resolve_lora(request.model)): + if isinstance(load_result, LoRARequest): + return None + if isinstance(load_result, ErrorResponse) and \ + load_result.code == HTTPStatus.BAD_REQUEST.value: + error_response = load_result if request.model in [ prompt_adapter.prompt_adapter_name for prompt_adapter in self.models.prompt_adapter_requests ]: return None - return self.create_error_response( + + return error_response or self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index 7a68452efc653017db60d706a0342fe1f7c947b0..74433a1a3c3f5703188e00b09dbf7d9a1d82cfc6 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -2,6 +2,8 @@ import json import pathlib +from asyncio import Lock +from collections import defaultdict from dataclasses import dataclass from http import HTTPStatus from typing import Optional, Union @@ -15,6 +17,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, UnloadLoRAAdapterRequest) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.utils import AtomicCounter @@ -63,11 +66,19 @@ class OpenAIServingModels: self.base_model_paths = base_model_paths self.max_model_len = model_config.max_model_len self.engine_client = engine_client + self.model_config = model_config self.static_lora_modules = lora_modules self.lora_requests: list[LoRARequest] = [] self.lora_id_counter = AtomicCounter(0) + self.lora_resolvers: list[LoRAResolver] = [] + for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers( + ): + self.lora_resolvers.append( + LoRAResolverRegistry.get_resolver(lora_resolver_name)) + 
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) + self.prompt_adapter_requests = [] if prompt_adapters is not None: for i, prompt_adapter in enumerate(prompt_adapters, start=1): @@ -234,6 +245,65 @@ class OpenAIServingModels: return None + async def resolve_lora( + self, lora_name: str) -> Union[LoRARequest, ErrorResponse]: + """Attempt to resolve a LoRA adapter using available resolvers. + + Args: + lora_name: Name/identifier of the LoRA adapter + + Returns: + LoRARequest if found and loaded successfully. + ErrorResponse (404) if no resolver finds the adapter. + ErrorResponse (400) if adapter(s) are found but none load. + """ + async with self.lora_resolver_lock[lora_name]: + # First check if this LoRA is already loaded + for existing in self.lora_requests: + if existing.lora_name == lora_name: + return existing + + base_model_name = self.model_config.model + unique_id = self.lora_id_counter.inc(1) + found_adapter = False + + # Try to resolve using available resolvers + for resolver in self.lora_resolvers: + lora_request = await resolver.resolve_lora( + base_model_name, lora_name) + + if lora_request is not None: + found_adapter = True + lora_request.lora_int_id = unique_id + + try: + await self.engine_client.add_lora(lora_request) + self.lora_requests.append(lora_request) + logger.info( + "Resolved and loaded LoRA adapter '%s' using %s", + lora_name, resolver.__class__.__name__) + return lora_request + except BaseException as e: + logger.warning( + "Failed to load LoRA '%s' resolved by %s: %s. " + "Trying next resolver.", lora_name, + resolver.__class__.__name__, e) + continue + + if found_adapter: + # An adapter was found, but all attempts to load it failed. 
+ return create_error_response( + message=(f"LoRA adapter '{lora_name}' was found " + "but could not be loaded."), + err_type="BadRequestError", + status_code=HTTPStatus.BAD_REQUEST) + else: + # No adapter was found + return create_error_response( + message=f"LoRA adapter {lora_name} does not exist", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + def create_error_response( message: str, diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 20c3238fb3dfe1650ee79462d87ecbb94d619932..5c181616aa01dbe52659b937541bbdde44caa3fe 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -27,6 +27,7 @@ logger = init_logger(__name__) @ToolParserManager.register_module("llama3_json") +@ToolParserManager.register_module("llama4_json") class Llama3JsonToolParser(ToolParser): """ Tool call parser for Llama 3.1 models intended for use with the diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 0661445639d74bd739ebe4848c2aad19290aab16..9dbfe85ecc686be69044f868105328fe13d22b7e 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -38,6 +38,10 @@ class MistralToolCall(ToolCall): # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 return "".join(choices(ALPHANUMERIC, k=9)) + @staticmethod + def is_valid_id(id: str) -> bool: + return id.isalnum() and len(id) == 9 + @ToolParserManager.register_module("mistral") class MistralToolParser(ToolParser): @@ -70,6 +74,19 @@ class MistralToolParser(ToolParser): "Mistral Tool Parser could not locate the tool call token in " "the tokenizer!") + def adjust_request( + self, request: 
ChatCompletionRequest) -> ChatCompletionRequest: + if not isinstance( + self.model_tokenizer, MistralTokenizer + ) and request.tools and request.tool_choice != 'none': + # Do not skip special tokens when using chat template + # with Mistral parser as TOOL_CALL token is needed + # for tool detection. + # Note: we don't want skip_special_tokens=False + # with MistralTokenizer as it is incompatible + request.skip_special_tokens = False + return request + def extract_tool_calls( self, model_output: str, diff --git a/vllm/env_override.py b/vllm/env_override.py index a351fc78bb87dcee9b68a232f4ce20aa775153f7..9c084aa0afeb22f2132805e4dc7c92db89e343b2 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -8,8 +8,21 @@ import torch # that interact with vllm workers. # they are executed whenever `import vllm` is called. -# see https://github.com/NVIDIA/nccl/issues/1234 -os.environ['NCCL_CUMEM_ENABLE'] = '0' +if not os.path.exists('/dev/nvidia-caps-imex-channels'): + # normally, we disable NCCL_CUMEM_ENABLE because it + # will cost 1~2 GiB GPU memory with cudagraph+allreduce, + # see https://github.com/NVIDIA/nccl/issues/1234 + # for more details. + # However, NCCL requires NCCL_CUMEM_ENABLE to work with + # multi-node NVLink, typically on GB200-NVL72 systems. + # The ultimate way to detect multi-node NVLink is to use + # NVML APIs, which are too expensive to call here. + # As an approximation, we check the existence of + # /dev/nvidia-caps-imex-channels, used by + # multi-node NVLink to communicate across nodes. + # This will still cost some GPU memory, but it is worthwhile + # because we can get very fast cross-node bandwidth with NVLink. 
+ os.environ['NCCL_CUMEM_ENABLE'] = '0' # see https://github.com/vllm-project/vllm/pull/15951 # it avoids unintentional cuda initialization from torch.cuda.is_available() diff --git a/vllm/envs.py b/vllm/envs.py index 5a49fe728629afa65908beca34ed146376eb8135..a062f620232893948516469f3a061d7667bcb18d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -86,10 +86,12 @@ if TYPE_CHECKING: VLLM_DISABLED_KERNELS: list[str] = [] VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False + VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True VLLM_ROCM_USE_AITER_MOE: bool = True - VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_USE_AITER_MLA: bool = True + VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True @@ -107,6 +109,7 @@ if TYPE_CHECKING: VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: Optional[str] = None VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True + VLLM_HPU_USE_DELAYED_SAMPLING: bool = False VLLM_DP_RANK: int = 0 VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 @@ -114,10 +117,10 @@ if TYPE_CHECKING: VLLM_DP_MASTER_PORT: int = 0 VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False - VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 + VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 def get_default_cache_root(): @@ -586,6 +589,12 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")), + # Whether to use aiter paged attention. + # By default is disabled. 
+ "VLLM_ROCM_USE_AITER_PAGED_ATTN": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in + ("true", "1")), + # use aiter linear op if aiter ops are enabled # The following list of related ops # - scaled_mm (per-tensor / rowwise) @@ -599,18 +608,21 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in ("true", "1")), - # Whether to use aiter block scaled moe kernel. - # By default this is disabled. - "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE": - lambda: - (os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in - ("true", "1")), - # use aiter rms norm op if aiter ops are enabled. "VLLM_ROCM_USE_AITER_RMSNORM": lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")), + # Whether to use aiter mla ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_MLA": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in + ("true", "1")), + # use rocm skinny gemms + "VLLM_ROCM_USE_SKINNY_GEMM": + lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in + ("true", "1")), + # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), @@ -700,6 +712,12 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in ("1", "true"), + # Use delayed sampling for HPU to reduce host cpu overhead + # between each step. 
+ "VLLM_HPU_USE_DELAYED_SAMPLING": + lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in + ("1", "true"), + # Rank of the process in the data parallel setting "VLLM_DP_RANK": lambda: int(os.getenv("VLLM_DP_RANK", "0")), @@ -745,11 +763,6 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", - # If set, disables TPU-specific optimization for top-k & top-p sampling - "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION": - lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"])) - if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None, - # Gap between padding buckets for the forward pass. So we have # 8, we will run forward pass with [16, 24, 32, ...]. "VLLM_TPU_BUCKET_PADDING_GAP": @@ -765,6 +778,16 @@ environment_variables: dict[str, Callable[[], Any]] = { # It can be changed with this variable if needed for some reason. "VLLM_XGRAMMAR_CACHE_MB": lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), + + # Control the threshold for msgspec to use 'zero copy' for + # serialization/deserialization of tensors. Tensors below + # this limit will be encoded into the msgpack buffer, and + # tensors above will instead be sent via a separate message. + # While the sending side still actually copies the tensor + # in all cases, on the receiving side, tensors above this + # limit will actually be zero-copy decoded. + "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": + lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), } # end-env-vars-definition @@ -803,7 +826,7 @@ def compute_hash() -> str: variables, ensure that it is included in the factors list if it affects the computation graph. For example, different values of VLLM_PP_LAYER_PARTITION will generate different computation - graphs, so it is included in the factors list. The env vars that + graphs, so it is included in the factors list. 
The env vars that affect the choice of different kernels or attention backends should also be included in the factors list. """ @@ -832,6 +855,7 @@ def compute_hash() -> str: if key in environment_variables: factorize(key) - hash_str = hashlib.md5(str(factors).encode()).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() return hash_str \ No newline at end of file diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 8c004c790fcbc9624e9d870b6d3ddfb3e8306a4b..2e4b47c1e24a0c70099069d95fdb0aa93f2f7683 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -34,13 +34,13 @@ class UniProcExecutor(ExecutorBase): if len(device_info) > 1: local_rank = int(device_info[1]) rank = 0 + is_driver_worker = True kwargs = dict( vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - is_driver_worker=(not self.parallel_config) - or (rank % self.parallel_config.tensor_parallel_size == 0), + is_driver_worker=is_driver_worker, ) self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") diff --git a/vllm/forward_context.py b/vllm/forward_context.py index e195a03c5cac8a38be7ca14c553c36cfc1d3ae41..06790d8ee2f8c55970a2e07e8696af88623adcff 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,6 +11,10 @@ import torch.distributed as dist import vllm.envs as envs from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: @@ -98,6 +102,17 @@ def set_forward_context(attn_metadata: Any, virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata) + + # KVConnector: trigger (possibly async) load before forward. 
+ # Each attn layer will block until the reading is complete. + trigger_kv_transfer = (attn_metadata is not None + and has_kv_transfer_group() + and is_v1_kv_transfer_group()) + if trigger_kv_transfer: + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + kv_connector.start_load_kv(_forward_context) + try: yield finally: @@ -133,4 +148,12 @@ def set_forward_context(attn_metadata: Any, logger.info(("Batchsize forward time stats " "(batchsize, count, median_time(ms)): %s"), forward_stats) + + # KVConnector: each attn layer triggers (possibly async) save. + # Ensure all those operations complete before forward() is done. + if trigger_kv_transfer: + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + kv_connector.wait_for_save() + _forward_context = prev_context diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 6f8f2cd758f7bf9768b1a552dbb4e2ad980c7278..ca706e202836d8f5cc010d14655e5f8b064fe22f 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -2,10 +2,9 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonInputsAdapter, SingletonPrompt, - TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - token_inputs, zip_enc_dec_prompts) + SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import (DummyData, InputContext, InputProcessingContext, InputRegistry) @@ -27,7 +26,6 @@ __all__ = [ "EncoderDecoderInputs", "ProcessorInputs", "SingletonInputs", - "SingletonInputsAdapter", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 
02b9ae9f1fcbb625b2f95870ec214abfbf05492b..a75d73e2a7678c15cbcbec254c0731c419dfbaea 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,17 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 - from collections.abc import Iterable -from dataclasses import dataclass -from functools import cached_property from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast -import torch -from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never +from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: - from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict) - from vllm.multimodal.inputs import MultiModalInputs + from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs class TextPrompt(TypedDict): @@ -147,46 +141,11 @@ class TokenInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ - multi_modal_data: NotRequired["MultiModalDataDict"] - """ - Optional multi-modal data to pass to the model, - if the model supports it. - """ - - multi_modal_inputs: NotRequired["MultiModalKwargs"] - """ - Optional multi-modal inputs to pass to the model, - if the model supports it. - """ - - multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"] - """ - Placeholder ranges for the multi-modal data. - """ - - multi_modal_hashes: NotRequired[list[str]] - """ - The hashes of the multi-modal data. - """ - - mm_processor_kwargs: NotRequired[dict[str, Any]] - """ - Optional multi-modal processor kwargs to be forwarded to the - multimodal input mapper & processor. Note that if multiple modalities - have registered mappers etc for the model being considered, we attempt - to pass the mm_processor_kwargs to each of them. 
- """ - def token_inputs( prompt_token_ids: list[int], token_type_ids: Optional[list[int]] = None, prompt: Optional[str] = None, - multi_modal_data: Optional["MultiModalDataDict"] = None, - multi_modal_inputs: Optional["MultiModalKwargs"] = None, - multi_modal_hashes: Optional[list[str]] = None, - multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, - mm_processor_kwargs: Optional[dict[str, Any]] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) @@ -195,16 +154,6 @@ def token_inputs( inputs["prompt"] = prompt if token_type_ids is not None: inputs["token_type_ids"] = token_type_ids - if multi_modal_data is not None: - inputs["multi_modal_data"] = multi_modal_data - if multi_modal_inputs is not None: - inputs["multi_modal_inputs"] = multi_modal_inputs - if multi_modal_hashes is not None: - inputs["multi_modal_hashes"] = multi_modal_hashes - if multi_modal_placeholders is not None: - inputs["multi_modal_placeholders"] = multi_modal_placeholders - if mm_processor_kwargs is not None: - inputs["mm_processor_kwargs"] = mm_processor_kwargs return inputs @@ -237,112 +186,6 @@ A processed :class:`SingletonPrompt` which can be passed to :class:`vllm.sequence.Sequence`. """ - -@dataclass -class SingletonInputsAdapter: - """ - Unified interface to access the components of :class:`SingletonInputs`. 
- """ - inputs: SingletonInputs - - @cached_property - def prompt(self) -> Optional[str]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("prompt") - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def prompt_token_ids(self) -> list[int]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("prompt_token_ids", []) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def token_type_ids(self) -> list[int]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("token_type_ids", []) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def prompt_embeds(self) -> Optional[torch.Tensor]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return None - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_data(self) -> "MultiModalDataDict": - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_data", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_kwargs", {}) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_inputs", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_kwargs", {}) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_hashes(self) -> list[str]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_hashes", []) - - if inputs["type"] == "multimodal": - # only the case when we use MultiModalInputs - return inputs.get("mm_hashes", []) # type: ignore[return-value] - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - 
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_placeholders", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_placeholders", {}) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def mm_processor_kwargs(self) -> dict[str, Any]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("mm_processor_kwargs", {}) - - if inputs["type"] == "multimodal": - return {} - - assert_never(inputs) # type: ignore[arg-type] - - ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] """ The inputs to :data:`vllm.inputs.InputProcessor`. diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 669fb96e6653a95eec2f9105400f19b104c7ae2a..0edb6da0620935ea6e0641aab5cc56f94c206b49 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -13,7 +13,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs) from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) @@ -27,7 +27,7 @@ class InputPreprocessor: def __init__( self, model_config: ModelConfig, - tokenizer: Optional[BaseTokenizerGroup], + tokenizer: Optional[TokenizerGroup], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ) -> None: super().__init__() @@ -36,7 +36,7 @@ class InputPreprocessor: self.tokenizer = tokenizer self.mm_registry = mm_registry - def get_tokenizer_group(self) -> BaseTokenizerGroup: + def get_tokenizer_group(self) -> TokenizerGroup: if self.tokenizer is None: raise ValueError("You cannot pass text prompts when 
" "`skip_tokenizer_init` is True") @@ -223,28 +223,6 @@ class InputPreprocessor: lora_request=lora_request, add_special_tokens=add_special_tokens) - def _can_process_multimodal(self) -> bool: - model_config = self.model_config - - if not model_config.is_multimodal_model: - raise ValueError("Your model does not support multi-modal inputs") - - # Interim measure so we can handle models that have yet to be - # updated to use the new multi-modal processor - can_process_multimodal = self.mm_registry.has_processor(model_config) - if not can_process_multimodal: - from vllm.model_executor.models.registry import _VLLM_MODELS - if not any(arch in _VLLM_MODELS - for arch in model_config.architectures): - logger.warning_once( - "Your model uses the legacy input pipeline, which will be " - "removed in an upcoming release. " - "Please upgrade to the new multi-modal processing pipeline " - "(https://docs.vllm.ai/en/latest/design/mm_processing.html)" - ) - - return can_process_multimodal - def _process_multimodal( self, prompt: Union[str, list[int]], @@ -258,8 +236,7 @@ class InputPreprocessor: returning the corresponding token IDs and metadata. """ # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal - # input. + # initialized without a tokenizer while using also multi-modal input if not self.tokenizer: tokenizer = object() # Dummy else: @@ -285,8 +262,7 @@ class InputPreprocessor: ) -> MultiModalInputs: """Async version of :meth:`_process_multimodal`.""" # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal - # input. 
+ # initialized without a tokenizer while using also multi-modal input if not self.tokenizer: tokenizer = object() # Dummy else: @@ -343,7 +319,7 @@ class InputPreprocessor: multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return self._process_multimodal( prompt_token_ids, multi_modal_data, @@ -355,8 +331,6 @@ class InputPreprocessor: return token_inputs( prompt_token_ids=prompt_token_ids, token_type_ids=token_type_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, ) if parsed["type"] == "text": @@ -366,7 +340,7 @@ class InputPreprocessor: multi_modal_data = text_content.get("multi_modal_data") mm_processor_kwargs = text_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return self._process_multimodal( prompt_text, multi_modal_data, @@ -383,8 +357,6 @@ class InputPreprocessor: return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, ) assert_never(parsed) @@ -417,7 +389,7 @@ class InputPreprocessor: multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return await self._process_multimodal_async( prompt_token_ids, multi_modal_data, @@ -426,11 +398,7 @@ class InputPreprocessor: return_mm_hashes=return_mm_hashes, ) - return token_inputs( - prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, - ) + return token_inputs(prompt_token_ids=prompt_token_ids) if parsed["type"] == "text": text_content = parsed["content"] @@ -439,7 +407,7 @@ class 
InputPreprocessor: multi_modal_data = text_content.get("multi_modal_data") mm_processor_kwargs = text_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return await self._process_multimodal_async( prompt_text, multi_modal_data, @@ -456,8 +424,6 @@ class InputPreprocessor: return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, ) assert_never(parsed) @@ -594,15 +560,13 @@ class InputPreprocessor: decoder_inputs = self._prompt_to_llm_inputs(decoder_input) # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. - if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) else: inputs = self._prompt_to_llm_inputs(prompt) - if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( @@ -637,15 +601,13 @@ class InputPreprocessor: # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. 
- if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) else: inputs = await self._prompt_to_llm_inputs_async(prompt) - if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 0579893e5d76743705c2f9c6e0feeb1e433062f5..4c334ab62d3e92507b155499229c1da48a225498 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,24 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 - -import functools -from collections import UserDict from collections.abc import Mapping from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional, - Protocol, Union) +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union -from torch import nn from transformers import BatchFeature, PretrainedConfig, ProcessorMixin -from typing_extensions import TypeVar, assert_never +from typing_extensions import TypeVar -from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, - resolve_mm_processor_kwargs) - -from .data import ProcessorInputs, SingletonInputs -from .parse import split_enc_dec_inputs +from vllm.utils import resolve_mm_processor_kwargs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -26,8 +16,6 @@ if TYPE_CHECKING: MultiModalRegistry) from vllm.sequence import SequenceData -logger = init_logger(__name__) - _T = TypeVar("_T") _C = TypeVar("_C", bound=PretrainedConfig, 
default=PretrainedConfig) _P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) @@ -172,142 +160,23 @@ class InputProcessingContext(InputContext): raise RuntimeError(msg) from exc -N = TypeVar("N", bound=type[nn.Module]) - - class DummyData(NamedTuple): - """Dummy data used for profiling.""" + """ + Dummy data used for profiling. + + Note: This is only used in V0. + """ seq_data: "SequenceData" multi_modal_data: Optional["MultiModalDataDict"] = None multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None -class DummyDataFactory(Protocol): - - def __call__( - self, - ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - **mm_processor_kwargs: Any, - ) -> DummyData: - """ - Create dummy data to be inputted into the model. - - Note: - :data:`InputProcessor` is not applied to the dummy data. - - The :code:`mm_processor_kwargs` are overrides provided at - initialization time to values in the config whose values - may affect the number of tokens per instance. - """ - ... - - -class _MultiModalCounts(UserDict[str, int]): - """ - Wraps `mm_counts` for a more informative error message - when attempting to access a plugin that does not exist. - """ - - def __getitem__(self, key: str) -> int: - try: - return super().__getitem__(key) - except KeyError as exc: - msg = (f"There is no multi-modal plugin with the key: {key}. " - f"Available keys: {set(self.keys())}") - raise KeyError(msg) from exc - - -InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs] -"""Preprocess the inputs to the model.""" - - class InputRegistry: """ - A registry to dispatch data processing - according to the target model. + Note: This is only used in V0. 
""" - def __init__(self) -> None: - self._dummy_factories_by_model_type = \ - ClassRegistry[nn.Module, DummyDataFactory]() - self._dummy_encoder_factories_by_model_type = \ - ClassRegistry[nn.Module, DummyDataFactory]() - self._input_processors_by_model_type = \ - ClassRegistry[nn.Module, InputProcessor]() - - def _default_dummy_data_factory( - self, - ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> DummyData: - """ - The default dummy data factory represents the longest possible text - that can be inputted to the model. - - Note: - :data:`InputProcessor` is not applied to the dummy data. - """ - # Avoid circular import - from vllm.sequence import SequenceData - - return DummyData(SequenceData.from_prompt_token_counts((0, seq_len))) - - def register_dummy_data(self, factory: DummyDataFactory): - """ - Register a dummy data factory to a model class. - - During memory profiling, the provided function is invoked to create - dummy data to be inputted into the model. The resulting memory usage - should be an upper bound of what the model would use at inference time. - """ - - def wrapper(model_cls: N) -> N: - if self._dummy_factories_by_model_type.contains(model_cls, - strict=True): - logger.warning( - "Model class %s already has dummy data " - "registered to %s. It is overwritten by the new one.", - model_cls, self) - - self._dummy_factories_by_model_type[model_cls] = factory - - return model_cls - - return wrapper - - def _get_dummy_data_factory(self, model_cls: type[nn.Module]): - return self._dummy_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) - - def register_dummy_encoder_data(self, factory: DummyDataFactory): - """ - Register a dummy encoder data factory to a model class - - This is similar to :meth:`~register_dummy_data`, but for encoder input. 
- """ - - def wrapper(model_cls: N) -> N: - if self._dummy_encoder_factories_by_model_type.contains( - model_cls, strict=True): - logger.warning( - "Model class %s already has dummy encoder data " - "registered to %s. It is overwritten by the new one.", - model_cls, self) - - self._dummy_encoder_factories_by_model_type[model_cls] = factory - - return model_cls - - return wrapper - - def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]): - return self._dummy_encoder_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) - def dummy_data_for_profiling( self, model_config: "ModelConfig", @@ -319,169 +188,25 @@ class InputRegistry: Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. - - Note: - This should be called after - :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`. """ # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.profiling import MultiModalProfiler from vllm.sequence import SequenceData - if mm_registry.has_processor(model_config): - processor = mm_registry.create_processor(model_config, - disable_cache=True) - profiler = MultiModalProfiler(processor) - - dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len) - if is_encoder_data else - profiler.get_decoder_dummy_data(seq_len)) - _seq_data = SequenceData.from_seqs( - dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined] - - dummy_data = DummyData( - seq_data=_seq_data, - multi_modal_data=getattr(dummy_data_v1, "multi_modal_data", - None), - multi_modal_placeholders=getattr(dummy_data_v1, - "multi_modal_placeholders", - None), - ) - else: - model_cls, _ = get_model_architecture(model_config) - if is_encoder_data: - dummy_factory = self._get_dummy_encoder_data_factory(model_cls) - else: - dummy_factory = self._get_dummy_data_factory(model_cls) - mm_counts = 
mm_registry.get_mm_limits_per_prompt(model_config) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - dummy_factory, - overrides=model_config.mm_processor_kwargs, - requires_kw_only=False, - allow_var_kwargs=True, - ) - - dummy_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) - - # Having more tokens is over-conservative but otherwise fine - num_tokens = dummy_data.seq_data.prompt_token_ids - if len(num_tokens) < seq_len: - if is_encoder_data: - logger.warning_once( - f"Expected at least {seq_len} dummy encoder tokens for " - f"profiling, but found {len(num_tokens)} tokens instead.") - else: - raise AssertionError( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(num_tokens)} tokens instead.") - - if (dummy_data.multi_modal_data is not None and - not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)): - for k, v in dummy_data.multi_modal_data.items(): - num_items = len(v) if isinstance(v, list) else 1 - num_expected = mm_counts[k] - assert num_items >= num_expected, ( - f"Expected at least {num_expected} dummy '{k}' instances " - f"for profiling, but found {num_items} instances instead.") - - return dummy_data - - def _default_input_processor( - self, - ctx: InputContext, - inputs: ProcessorInputs, - **kwargs: object, - ) -> ProcessorInputs: - """The default input processor is a no-op.""" - return inputs - - def register_input_processor(self, processor: InputProcessor): - """ - Register an input processor to a model class. - - The provided function is invoked on each input to the model. This - happens before - :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`. - """ - - def wrapper(model_cls: N) -> N: - if self._input_processors_by_model_type.contains(model_cls, - strict=True): - logger.warning( - "Model class %s already has input processor " - "registered to %s. 
It is overwritten by the new one.", - model_cls, self) - - self._input_processors_by_model_type[model_cls] = processor - - return model_cls + if not model_config.is_multimodal_model: + seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) + return DummyData(seq_data=seq_data) - return wrapper + # Encoder dummy data does not contain multi-modal data + if is_encoder_data: + enc_data = mm_registry.get_encoder_dummy_data( + model_config, seq_len) + seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) + return DummyData(seq_data=seq_data) - def _get_model_input_processor(self, model_cls: type[nn.Module]): - return self._input_processors_by_model_type \ - .get(model_cls, self._default_input_processor) - - def _ensure_mm_kwargs( - self, - inputs: SingletonInputs, - mm_processor_kwargs: dict[str, Any], - ): - if inputs["type"] == "token": - # In case the input processor for that model fails to set it - if "mm_processor_kwargs" not in inputs: - inputs["mm_processor_kwargs"] = mm_processor_kwargs - elif inputs["type"] == "multimodal": - # Be more strict in V2 - assert "mm_kwargs" in inputs - else: - assert_never(inputs["type"]) # type: ignore[arg-type] - - def process_input(self, model_config: "ModelConfig", - inputs: ProcessorInputs) -> ProcessorInputs: - """ - Apply an input processor to an instance of model inputs. - - The model is identified by ``model_config``. 
- """ - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - processor = self._get_model_input_processor(model_cls) - - # Handle multimodal processor kwargs with priority: - # Inference kwargs -> Init kwargs -> {} - # If it's empty, it'll fall back to the default kwarg values - mm_processor_kwargs = resolve_mm_processor_kwargs( - model_config.mm_processor_kwargs, - inputs.get("mm_processor_kwargs", {}), # type: ignore - processor, - requires_kw_only=False, - allow_var_kwargs=True, - ) + dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) - processed_inputs = processor( - InputContext(model_config), - inputs, - **mm_processor_kwargs, + return DummyData( + seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), + multi_modal_data=dec_data.multi_modal_data, + multi_modal_placeholders=dec_data.multi_modal_placeholders, ) - - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - if encoder_inputs is not None: - self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs) - if decoder_inputs is not None: - self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs) - - return processed_inputs - - def create_input_processor(self, model_config: "ModelConfig"): - """ - Create an input processor (see :meth:`_process_input`) for a - specific model. 
- """ - return functools.partial(self.process_input, model_config) diff --git a/vllm/lora/resolver.py b/vllm/lora/resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..6726ca9a903ff75f70cbfae0220e84ddf850d5bf --- /dev/null +++ b/vllm/lora/resolver.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import AbstractSet, Dict, Optional + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest + +logger = init_logger(__name__) + + +class LoRAResolver(ABC): + """Base class for LoRA adapter resolvers. + + This class defines the interface for resolving and fetching LoRA adapters. + Implementations of this class should handle the logic for locating and + downloading LoRA adapters from various sources (e.g. S3, cloud storage, + etc.). + """ + + @abstractmethod + async def resolve_lora(self, base_model_name: str, + lora_name: str) -> Optional[LoRARequest]: + """Abstract method to resolve and fetch a LoRA model adapter. + + Implements logic to locate and download LoRA adapter based on the name. + Implementations might fetch from a blob storage or other sources. + + Args: + base_model_name: The name/identifier of the base model to resolve. + lora_name: The name/identifier of the LoRA model to resolve. + + Returns: + Optional[LoRARequest]: The resolved LoRA model information, or None + if the LoRA model cannot be found. + """ + pass + + +@dataclass +class _LoRAResolverRegistry: + resolvers: Dict[str, LoRAResolver] = field(default_factory=dict) + + def get_supported_resolvers(self) -> AbstractSet[str]: + """Get all registered resolver names.""" + return self.resolvers.keys() + + def register_resolver( + self, + resolver_name: str, + resolver: LoRAResolver, + ) -> None: + """Register a LoRA resolver. + Args: + resolver_name: Name to register the resolver under. + resolver: The LoRA resolver instance to register. 
+ """ + if resolver_name in self.resolvers: + logger.warning( + "LoRA resolver %s is already registered, and will be " + "overwritten by the new resolver instance %s.", resolver_name, + resolver) + + self.resolvers[resolver_name] = resolver + + def get_resolver(self, resolver_name: str) -> LoRAResolver: + """Get a registered resolver instance by name. + Args: + resolver_name: Name of the resolver to get. + Returns: + The resolver instance. + Raises: + KeyError: If the resolver is not found in the registry. + """ + if resolver_name not in self.resolvers: + raise KeyError( + f"LoRA resolver '{resolver_name}' not found. " + f"Available resolvers: {list(self.resolvers.keys())}") + return self.resolvers[resolver_name] + + +LoRAResolverRegistry = _LoRAResolverRegistry() diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 610cbf87f66a30ea1e9c38abc4d98065a5de787d..883ca938ea1ac4cceffa474dead3d390ac3072f0 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -114,7 +114,7 @@ def parse_fine_tuned_lora_name( is_bias whether the tensor is lora bias. """ - # LoRA weight qualified name always starts with `base_model.model.`, + # LoRA weight qualified name usually starts with `base_model.model.`, # so we remove the prefix `base_model.model.` to make the following # mapping correctly. if "base_model.model." in name: @@ -123,18 +123,23 @@ def parse_fine_tuned_lora_name( # recover the prefix `base_model.model.` name = "base_model.model." + name + # In some situations, we may not start with `base_model.model.`. + # If we don't (e.g., ibm-granite/granite-speech-3.3-8b), + # we should keep the prefix intact. + start_index = 2 if "base_model.model." 
in name else 0 + parts = name.split(".") if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"): - new_name = ".".join(parts[2:-2]) + new_name = ".".join(parts[start_index:-2]) return new_name, parts[-2] == "lora_A", False if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": - new_name = ".".join(parts[2:-1]) + new_name = ".".join(parts[start_index:-1]) return new_name, parts[-1] == "lora_embedding_A", False if parts[-1] == "bias": - new_name = ".".join(parts[2:-2]) + new_name = ".".join(parts[start_index:-2]) return new_name, False, True raise ValueError(f"{name} is unsupported LoRA weight") diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index d4ee1be9a445d6c8796bb8f4b335cce1f613c0c4..8fdcdcafa9806b2748807bdb054ec6e0b99303b7 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -65,7 +65,7 @@ def maybe_backend_fallback( fallback_or_error( guided_params, "xgrammar does not support advanced JSON schema features like " - "enums, patterns or numeric ranges.", "outlines") + "string length, item limits, or property bounds.", "outlines") # xgrammar only supports GBNF grammars, so we must convert Lark. 
# We must check if the grammar is likely Lark and if that diff --git a/vllm/model_executor/guided_decoding/guidance_decoding.py b/vllm/model_executor/guided_decoding/guidance_decoding.py index f19ebcbe420e3f0fdce04fe35c7f811142897d30..95b7c71107aab7b2f858a0b0f0997f24c6e6f4df 100644 --- a/vllm/model_executor/guided_decoding/guidance_decoding.py +++ b/vllm/model_executor/guided_decoding/guidance_decoding.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import json from re import escape as regex_escape import llguidance @@ -7,6 +8,8 @@ from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.guidance_logits_processors import ( GuidanceLogitsProcessor) from vllm.sampling_params import GuidedDecodingParams +from vllm.v1.structured_output.backend_guidance import ( + process_for_additional_properties) def get_local_guidance_guided_decoding_logits_processor( @@ -20,9 +23,17 @@ def get_local_guidance_guided_decoding_logits_processor( grm = "" any_whitespace = 'disable-any-whitespace' not in \ guided_params.backend_options() - if guided_params.json: + if (guide_json := guided_params.json) is not None: + # Optionally set additionalProperties to False at the top-level + # By default, other backends do not allow additional top-level + # properties, so this makes guidance more similar to other backends + if 'no-additional-properties' in guided_params.backend_options(): + if not isinstance(guide_json, str): + guide_json = json.dumps(guide_json) + guide_json = process_for_additional_properties(guide_json) + grm = llguidance.LLMatcher.grammar_from_json_schema( - guided_params.json, + guide_json, overrides={"whitespace_pattern": guided_params.whitespace_pattern}, defaults={ "whitespace_flexible": any_whitespace, diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index db4ce26806c1ffb48b9ad31ac70af27938d69f1d..1593868a164aa0218ca98db04231cdf8e9e77543 100644 --- 
a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -27,14 +27,15 @@ class GuidedDecodingRequest: guided_decoding_backend: Optional[str] = None guided_whitespace_pattern: Optional[str] = None guided_json_object: Optional[bool] = None + structural_tag: Optional[str] = None def __post_init__(self): """Validate that some fields are mutually exclusive.""" - guide_count = sum([ - self.guided_json is not None, self.guided_regex is not None, - self.guided_choice is not None, self.guided_grammar is not None, - self.guided_json_object is not None - ]) + guide_count = sum(x is not None + for x in (self.guided_json, self.guided_regex, + self.guided_choice, self.guided_grammar, + self.guided_json_object, + self.structural_tag)) if guide_count > 1: raise ValueError( "You can only use one kind of guided decoding but multiple are " diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py index ba7c1025269972e563bed0720c961ae86982460c..1ad1ef8fbf1662ad2c26ed96c58fde2759706742 100644 --- a/vllm/model_executor/guided_decoding/utils.py +++ b/vllm/model_executor/guided_decoding/utils.py @@ -10,16 +10,8 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool: if not isinstance(obj, dict): return False - # Check for pattern restrictions - if "pattern" in obj: - return True - # Check for numeric ranges - if obj.get("type") in ("integer", "number") and any( - key in obj for key in [ - "minimum", "maximum", "exclusiveMinimum", - "exclusiveMaximum", "multipleOf" - ]): + if obj.get("type") in ("integer", "number") and ("multipleOf" in obj): return True # Check for array unsupported keywords diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 59f3c87278b0cb05abaed28151ac09f48b4ef61d..0eaf744b974e6001675221c7c5c15b2ba269b1d5 100644 --- a/vllm/model_executor/layers/activation.py +++ 
b/vllm/model_executor/layers/activation.py @@ -364,6 +364,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module: _ACTIVATION_AND_MUL_REGISTRY = LazyDict({ "gelu": lambda: GeluAndMul(), "silu": lambda: SiluAndMul(), + "gelu_and_mul": lambda: GeluAndMul(), }) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000000000000000000000000000000..555d173644522713f502ee8ee21efa5f04916184 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 
2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000000000000000000000000000000..5de5605d401c2e84b42134e4a3ed7e5811a8ffe3 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json new file mode 100644 index 0000000000000000000000000000000000000000..2221e99cd1adccee247d3bc3f221e47210c9cbaf --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000000000000000000000000000000..74374c573f3fcb5d407b92fcb64de2a9d640f079 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, 
+ "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..b34b6e4e8a8e7985384acc0b88975a9cb30384b1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 
+ }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json new file mode 100644 index 
0000000000000000000000000000000000000000..ab169a0183ddc11ace79bc480aefd7db154bea67 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..324ad7b22fedf6b353029e6fe38675fb73968419 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json 
b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000000000000000000000000000000..ab6e15552909b795ad63eff23c3161fd29c7b824 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + 
}, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000000000000000000000000000000..249359fb93d77432712a11f83e4cde87d8a8005f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + 
}, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..b4efc9b7e44ceca6da12658441d1303c71ae925b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json new file mode 100644 index 0000000000000000000000000000000000000000..03dfc73b6c0a1157baeba25098b00e7a87cd3559 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..9c07695ba9101c1697ca839787fa01cc12abf4d7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000000000000000000000000000000..beaac7f641e442734102dfadb36dce4083dec392 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 
64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, 
+ "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json new file mode 100644 index 0000000000000000000000000000000000000000..ebff99e26dc7fac0a3e4007593bd3821dbd65a6b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + 
}, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..857d11e488917b22dabd44f58de013bf61f754c6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/README b/vllm/model_executor/layers/fused_moe/configs/README index 787bd061166468b3f213eced25aa963774fbd090..85970e2d1cea5b5dedb55c03562d093f3d439503 100644 --- a/vllm/model_executor/layers/fused_moe/configs/README +++ b/vllm/model_executor/layers/fused_moe/configs/README @@ -9,5 +9,4 @@ The example configurations provided are for the Mixtral model for TP2 on H100 and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have N = 7168 and for TP4 we have N = 3584. -Please feel free to tune the configurations using scripts in `benchmarks/kernels/benchmark_moe.py` -Some of the configurations files are copied from the SGLang repository. Thank you! +See `benchmark/kernels/benchmark_moe.py` on how to generate these config files. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d6a27aa0ddc47a4eb55fc530d0d091cf49d716de..960c7f834857162b97db224636f9024eef5ff6fb 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -15,7 +15,7 @@ def cutlass_moe_fp8( w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + topk_ids_: torch.Tensor, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, @@ -23,6 +23,7 @@ def cutlass_moe_fp8( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.half, + expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ @@ -57,12 +58,19 @@ def cutlass_moe_fp8( quantize the intermediate result between the gemms. Shape: scalar or [M] - out_dtype (torch.Tensor): The output tensor type. 
+ - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, + every Rank is responsible for a subset of experts. expert_map is a + mapping from global expert-id to local expert-id. When expert_map[i] + is -1, it means that this Rank is not responsible for global + expert-id i. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is 1. Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. """ - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" assert w1_q.dtype == torch.float8_e4m3fn assert w2_q.dtype == torch.float8_e4m3fn assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" @@ -96,7 +104,13 @@ def cutlass_moe_fp8( k = w1_q.size(1) n = w2_q.size(1) - topk = topk_ids.size(1) + local_topk_ids = topk_ids_ + if expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where(expert_map[topk_ids_] != -1, + expert_map[topk_ids_], -1) + + topk = local_topk_ids.size(1) per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) @@ -120,10 +134,23 @@ def cutlass_moe_fp8( dtype=torch.int32, device=device) - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - - ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + a_map_initializer = torch.empty + c2_initializer = torch.empty + if expert_map is not None: + # With expert_map each Rank processes only a subset of experts. As + # a result not all of a_map and c2 tensors are filled. We fill it + # zeros for correctness. 
+ a_map_initializer = torch.zeros + c2_initializer = torch.zeros + + a_map = a_map_initializer((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + c_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1, problem_sizes2, a_map, c_map, num_experts, n, k) @@ -131,7 +158,7 @@ def cutlass_moe_fp8( rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) - c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) + c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype) ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, expert_offsets[:-1], problem_sizes1, ab_strides1, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index ee158d7ee474eb3455296384572090d6e09b5edf..62614a59cbe9a6e233004ee14a173da5fd32851d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -5,17 +5,16 @@ from typing import Optional import torch +import vllm._custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) -from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import direct_register_custom_op def get_scalar_type(num_bits: int, has_zp: bool): if has_zp: - assert num_bits == 4 - return scalar_types.uint4 + return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 else: return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 @@ -27,9 +26,12 @@ def single_marlin_moe( gating_output: torch.Tensor, topk: int, renormalize: bool, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None, sort_indices: 
Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -62,7 +64,7 @@ def single_marlin_moe( assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous" - assert hidden_states.dtype == torch.float16 + assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert num_bits in [4, 8] M, K = hidden_states.shape @@ -83,39 +85,54 @@ def single_marlin_moe( block_size_m = config['BLOCK_SIZE_M'] - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = (N // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=hidden_states.device, - requires_grad=False) - - has_zero_point = w_zeros is not None - if w_zeros is None: - w_zeros = torch.empty((0, 0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) - - if g_idx is None: - g_idx = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - - if sort_indices is None: - sort_indices = torch.empty((0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - - scalar_type = get_scalar_type(num_bits, has_zero_point) + if global_num_experts == -1: + global_num_experts = E + sorted_token_ids, expert_ids, num_tokens_post_padded = \ + moe_align_block_size(topk_ids, block_size_m, E, expert_map) + + if workspace is None: + max_workspace_size = (max(2 * N, K) // 64) * \ + (sorted_token_ids.size(0) // block_size_m) + device = hidden_states.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + max_workspace_size = min(max_workspace_size, sms) + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + scalar_type = 
get_scalar_type(num_bits, w_zeros is not None) + intermediate_cache = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) - intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K, - is_k_full, E, topk, block_size_m, True, False) + ops.moe_wna16_marlin_gemm(hidden_states, + intermediate_cache, + w, + scales, + w_zeros, + g_idx, + sort_indices, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=topk, + mul_topk_weights=False, + is_ep=expert_map is not None, + b_q_type=scalar_type, + size_m=M, + size_n=N, + size_k=K, + is_k_full=is_k_full, + use_atomic_add=False, + use_fp32_reduce=True, + is_zp_float=False) + intermediate_cache = intermediate_cache.view(-1, topk, N) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -127,9 +144,12 @@ def single_marlin_moe_fake( gating_output: torch.Tensor, topk: int, renormalize: bool, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -144,24 +164,26 @@ direct_register_custom_op( ) -def fused_marlin_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - num_bits: int = 8, 
- is_k_full: bool = True, -) -> torch.Tensor: +def fused_marlin_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + inplace: bool = False) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -196,27 +218,12 @@ def fused_marlin_moe( 1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[1] == w2.shape[2] // ( num_bits // 2), "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype == torch.float16 + assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert num_bits in [4, 8] - has_no_act_order = (g_idx1 is None and g_idx2 is None - and sort_indices1 is None and sort_indices2 is None) - has_all_act_order = (g_idx1 is not None and g_idx2 is not None - and sort_indices1 is not None - and sort_indices2 is not None) - assert has_no_act_order or has_all_act_order, ( - "g_idx and sorted_indices " - "must be all not None or must be all None") - - has_no_zp = w1_zeros is None and w2_zeros is None - has_all_zp = w1_zeros is not None and w2_zeros is not None - assert has_no_zp or has_all_zp, ("zero 
points must be both not None or " - "must be both None") - M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 @@ -234,122 +241,128 @@ def fused_marlin_moe( block_size_m = config["BLOCK_SIZE_M"] - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=current_platform.device_type, - requires_grad=False) - - if has_no_zp: - w1_zeros = torch.empty((0, 0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) - w2_zeros = torch.empty((0, 0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) - - if has_no_act_order: - g_idx1 = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - g_idx2 = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - sort_indices1 = torch.empty((0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - sort_indices2 = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - - scalar_type1 = get_scalar_type(num_bits, has_all_zp) - scalar_type2 = get_scalar_type(num_bits, has_all_zp) + if global_num_experts == -1: + global_num_experts = E + sorted_token_ids, expert_ids, num_tokens_post_padded = \ + moe_align_block_size(topk_ids, block_size_m, global_num_experts, + expert_map) + + if workspace is None: + max_workspace_size = (max(2 * N, K) // 64) * \ + (sorted_token_ids.size(0) // block_size_m) + device = hidden_states.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + max_workspace_size = min(max_workspace_size, sms * 4) + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) 
intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype, ) + intermediate_cache13 = torch.empty( + (M * topk_ids.shape[1] * max(2 * N, K), ), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] + intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) + intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] + intermediate_cache3 = intermediate_cache3.view(-1, K) + + use_atomic_add = hidden_states.dtype == torch.half or \ + torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache1 = ops.moe_wna16_marlin_gemm( hidden_states, + intermediate_cache1, w1, - sorted_token_ids, - topk_weights, - topk_ids, w1_scale, w1_zeros, g_idx1, sort_indices1, workspace, - scalar_type1.id, - M, - 2 * N, - K, - is_k_full, - E, - topk, - block_size_m, - True, - False, - ) + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=topk, + mul_topk_weights=False, + is_ep=expert_map is not None, + b_q_type=scalar_type1, + size_m=M, + size_n=2 * N, + size_k=K, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False) torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + if expert_map is not None: + intermediate_cache3.zero_() + + intermediate_cache3 = ops.moe_wna16_marlin_gemm( intermediate_cache2, + intermediate_cache3, w2, - sorted_token_ids, - topk_weights, - topk_ids, w2_scale, w2_zeros, g_idx2, sort_indices2, workspace, - scalar_type2.id, - M, - K, - N, - is_k_full, - E, - topk, - block_size_m, - False, - True, - ) - + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=1, + 
mul_topk_weights=True, + is_ep=expert_map is not None, + b_q_type=scalar_type2, + size_m=M * topk, + size_n=K, + size_k=N, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False).view(-1, topk, K) + + output = hidden_states if inplace else torch.empty_like(hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) - - -def fused_marlin_moe_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - num_bits: int = 8, - is_k_full: bool = True, -) -> torch.Tensor: + dim=1, + out=output) + + +def fused_marlin_moe_fake(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + inplace: bool = False) -> torch.Tensor: return torch.empty_like(hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index dbc00c815f1d87c3d289dba2e1a9312accdae11f..64f044d0abf2009ff3795f2db060790f0a1c233a 100644 --- 
a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op -from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled, - rocm_aiter_fused_experts, - rocm_aiter_topk_softmax) +from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled logger = init_logger(__name__) device_name = current_platform.get_device_name().replace(" ", "_") @@ -1048,6 +1046,7 @@ def get_default_config( "num_warps": 4, "num_stages": 3, } + # elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: # # moe wna16 kernels # # only set BLOCK_SIZE_M @@ -1063,6 +1062,18 @@ def get_default_config( # config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} # else: # config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} + elif is_marlin: + for block_size_m in [8, 16, 32, 48, 64]: + if M * topk / E / block_size_m < 0.9: + break + return {"BLOCK_SIZE_M": block_size_m} + elif M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } else: config = { "BLOCK_SIZE_M": 64, @@ -1070,14 +1081,7 @@ def get_default_config( "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - } + if use_nn_moe: config["num_ldmatrixes"] = 1 return config @@ -1138,6 +1142,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: if is_rocm_aiter_moe_enabled(): + from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax return rocm_aiter_topk_softmax return vllm_topk_softmax @@ -1401,6 +1406,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> 
torch.Tensor: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: if is_rocm_aiter_moe_enabled(): + from .rocm_aiter_fused_moe import rocm_aiter_fused_experts return rocm_aiter_fused_experts if inplace: return torch_vllm_inplace_fused_experts diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5399447c26f865ff5cd56783d6f9da855008ff8d..b8a3dc8ca57b1e82b23c61769755f6d231f43b2e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -131,12 +131,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w13_weight.data), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w2_weight.data), - requires_grad=False) + # Padding the weight for better performance on ROCm + layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) # Lazy import to avoid importing triton. 
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled, shuffle_weights) @@ -145,10 +142,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight.data, layer.w2_weight.data) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight.data = shuffled_w13 + layer.w2_weight.data = shuffled_w2 if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: @@ -446,6 +441,7 @@ class FusedMoE(torch.nn.Module): if params_dtype is None: params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype # Note: here we guard against accessing the TP and DP groups when # uninitialized (this happens when testing) @@ -496,6 +492,7 @@ class FusedMoE(torch.nn.Module): self.global_num_experts = num_experts assert intermediate_size % self.tp_size == 0 + self.hidden_size = hidden_size self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index ac158a7eee53450f6267d5330eeff5bf0758e630..acaa93f5a23edceccec5dc839139674ad2ba52fd 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,126 +1,385 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from functools import cache +from typing import List, Optional, Tuple import torch -import vllm.envs as envs +from vllm import envs from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +@cache def is_rocm_aiter_moe_enabled() -> bool: return current_platform.is_rocm() \ and envs.VLLM_ROCM_USE_AITER_MOE \ - and 
envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER -def is_rocm_aiter_block_scaled_moe_enabled() -> bool: - return is_rocm_aiter_moe_enabled() and \ - envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE - - -def rocm_aiter_fused_experts( - *, +def rocm_aiter_asm_moe_tkw1_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_str: str = "silu") -> torch.Tensor: + + from aiter import ActivationType + from aiter.fused_moe_bf16_asm import asm_moe_tkw1 + + activation = \ + ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu + + return asm_moe_tkw1(hidden_states, + w1, + w2, + topk_weight, + topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + per_tensor_quant_scale=per_tensor_quant_scale, + expert_mask=expert_mask, + activation=activation) + + +def rocm_aiter_asm_moe_tkw1_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, expert_mask: Optional[torch.Tensor] = None, - **kwagrs # Ignore additional keyword arguments -) -> torch.Tensor: + 
activation_str: str = "silu") -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + from aiter import ck_moe + return ck_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids) + - import aiter as rocm_aiter +def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_states_dtype: torch.dtype, + expert_mask: torch.Tensor, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + block_shape: List[int], + smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + from aiter import fmoe_fp8_blockscale_g1u1 + from aiter.fused_moe_bf16_asm import moe_sorting_ck + + topk = topk_ids.shape[1] + model_dim = w1.shape[-1] + local_E = E = w1.shape[0] + if expert_mask is not None: + E = expert_mask.numel() + + ( + sorted_token_ids, + sorted_weight_buf, + sorted_expert_ids, + num_valid_ids, + out_asm, + ) = moe_sorting_ck(topk_ids, + topk_weights, + E, + model_dim, + hidden_states_dtype, + expert_mask=expert_mask) + + fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids, + sorted_weight_buf, sorted_expert_ids, + num_valid_ids, topk, w1_scale.view(local_E, -1), + w2_scale.view(local_E, -1), + a1_scale.t().contiguous(), *block_shape, + smooth_scale) + + return out_asm + + +def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_states_dtype: torch.dtype, + expert_mask: torch.Tensor, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: 
torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + block_shape: List[int], + smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + + return torch.empty_like(a1, dtype=torch.bf16) + + +def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + activation: str = "silu") -> torch.Tensor: import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe + from aiter import ActivationType + + assert activation in ["silu", "gelu"], "The given activation:" \ + f" {activation}" \ + " is not supported in" \ + " AITER." + if activation == "silu": + aiter_activation = ActivationType.Silu + else: + aiter_activation = ActivationType.Gelu + + return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weight=topk_weight, + topk_ids=topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + activation=aiter_activation) + + +def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + activation: str = "silu") -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> None: + from aiter import topk_softmax + topk_softmax(topk_weights, topk_indices, token_expert_indices, + 
gating_output, renormalize) + + +def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> None: + pass + + +if current_platform.is_rocm(): + + direct_register_custom_op( + op_name="rocm_aiter_asm_moe_tkw1", + op_func=rocm_aiter_asm_moe_tkw1_impl, + mutates_args=[], + fake_impl=rocm_aiter_asm_moe_tkw1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_ck_moe", + op_func=rocm_aiter_ck_moe_impl, + mutates_args=[], + fake_impl=rocm_aiter_ck_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1", + op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl, + mutates_args=[], + fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_asm_moe", + op_func=rocm_aiter_asm_moe_impl, + mutates_args=[], + fake_impl=rocm_aiter_asm_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_topk_softmax", + op_func=rocm_aiter_topk_softmax_impl, + mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], + fake_impl=rocm_aiter_topk_softmax_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +def rocm_aiter_fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, 
+ w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) - if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8: + # All AITER Fused MoE kernels are expecting the following datatypes + topk_weights = topk_weights.to(torch.float32) + topk_ids = topk_ids.to(torch.int32) + + # w8a8 block-scaled + if block_shape is not None and use_fp8_w8a8: + assert not apply_router_weight_on_input, ( + "apply_router_weight_on_input is not supported for block scaled moe" + ) assert w1_scale is not None assert w2_scale is not None - local_E = E = w1.shape[0] - if expert_mask is not None: - E = expert_mask.numel() - - topk = topk_ids.shape[1] - model_dim = w1.shape[-1] - dtype = hidden_states.dtype # The default block sizes are 128 in AITER. 
- if block_shape is None: - block_shape = [128, 128] - - scale_blk_k = block_shape[1] - - ( - sorted_token_ids, - sorted_weight_buf, - sorted_expert_ids, - num_valid_ids, - out_asm, - ) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids, - topk_weights, - E, - model_dim, - dtype, - expert_mask=expert_mask) - - a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k) - rocm_aiter.fmoe_fp8_blockscale_g1u1( - out_asm, - a1, + block_shape = [128, 128] if block_shape is None else block_shape + + a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1]) + + return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1( + topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1, + w2, w1_scale, w2_scale, a1_scale, block_shape, None) + + # w8a8 per-channel quantization + elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: + # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` + # This applies topk_weights on the GEMM output of the first FC layer + # rather than the second FC. 
+ assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + assert topk_weights.shape[-1] == 1, ( + "Only support topk=1 when" + " `apply_router_weight_on_input` is True") + + return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + hidden_states, w1, w2, - sorted_token_ids, - sorted_weight_buf, - sorted_expert_ids, - num_valid_ids, - topk, - w1_scale.view(local_E, -1), - w2_scale.view(local_E, -1), - a1_scale.t().contiguous(), - block_shape[0], - block_shape[1], - None, - ) - return out_asm - + topk_weights, + topk_ids, + fc1_scale=w1_scale, + fc2_scale=w2_scale, + fc1_smooth_scale=None, + fc2_smooth_scale=None, + a16=False, + per_tensor_quant_scale=None, + expert_mask=expert_map, + activation_str=activation) + + # w8a8 per-tensor activation per-tensor weight elif use_fp8_w8a8: - return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weight=topk_weights, - topk_ids=topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, - fc1_smooth_scale=None, - fc2_smooth_scale=None, - a16=False) - - return rocm_aiter.ck_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids) + assert not apply_router_weight_on_input, ( + "apply_router_weight_on_input is not supported for fp8_w8a8") + return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weight=topk_weights, + topk_ids=topk_ids, + fc1_scale=w1_scale, + fc2_scale=w2_scale, + fc1_smooth_scale=None, + fc2_smooth_scale=None, + a16=False, + activation=activation) + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + + hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) + topk_ids = topk_ids.to(torch.int32) + topk_weights = torch.ones_like(topk_weights, dtype=torch.float32) + 
+ # w16a16 fallback to rocm_aiter_ck_moe w16a16 + return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids) def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, - renormalize: bool) -> tuple[torch.Tensor, ...]: - import aiter as rocm_aiter - rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices, - gating_output, renormalize) - + renormalize: bool) -> Tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices, + token_expert_indices, gating_output, + renormalize) return topk_weights, topk_indices -def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: +def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Applies shuffle_weight function from AITER to each input tensor and returns them. @@ -129,15 +388,14 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: *tensors: Variable number of torch.Tensor objects. Returns: - A tuple of shuffled tensors. + A Tuple of shuffled tensors. """ from aiter.ops.shuffle import shuffle_weight - return tuple(shuffle_weight(tensor) for tensor in tensors) def expand_weights(*tensors: torch.Tensor, - expansion_dims: list[int]) -> tuple[torch.Tensor, ...]: + expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]: """ Expands the dimensions of input tensors. @@ -147,7 +405,7 @@ def expand_weights(*tensors: torch.Tensor, corresponding to each tensor. Returns: - A tuple of tensors with expanded dimensions. + A Tuple of tensors with expanded dimensions. 
""" assert len(tensors) == len(expansion_dims), \ diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 402670b606b60d638453662969b9abbd49632b9f..a4ae8978f86a1ef4e6bb04c005b5de987fce3423 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -185,7 +185,8 @@ class RMSNorm(CustomOp): x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - from vllm_hpu_extension.ops import HPUFusedRMSNorm + from vllm_hpu_extension.kernels import rms_norm + HPUFusedRMSNorm = rms_norm() if HPUFusedRMSNorm is None: return self.forward_native(x, residual) if residual is not None: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 0172aa80998ced4e982e6456104a754be24c095b..ea4195d1aac60778cebd3770ee8eddeba1346607 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -6,7 +6,6 @@ from typing import Any, Literal, Optional, Union import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -17,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm # yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, BlockQuantScaleParameter, @@ -34,6 +34,8 @@ logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", + "BitBLASLinearMethod", + "GPTQBitBLASLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", @@ -54,6 +56,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [ ] +def adjust_bitblas_shard(param, 
shard_size, shard_offset): + bitblas_tile_size = getattr(param, "bitblas_tile_size", None) + if bitblas_tile_size is not None: + return (shard_size // bitblas_tile_size, + shard_offset // bitblas_tile_size) + + return shard_size, shard_offset + + def adjust_marlin_shard(param, shard_size, shard_offset): marlin_tile_size = getattr(param, "marlin_tile_size", None) if marlin_tile_size is None: @@ -208,7 +219,7 @@ class UnquantizedLinearMethod(LinearMethodBase): else: return torch.matmul(x, layer.weight) else: - return F.linear(x, layer.weight, bias) + return dispatch_unquantized_gemm()(x, layer.weight, bias) class LinearBase(torch.nn.Module): @@ -635,6 +646,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) + if use_bitsandbytes_4bit: index = list(itertools.accumulate([0] + self.output_sizes)) orig_offsets = { @@ -666,6 +680,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) @@ -936,6 +952,15 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offset = self._get_shard_offset_mapping(loaded_shard_id) shard_size = self._get_shard_size_mapping(loaded_shard_id) + # Note(simon): This is needed for Qwen3's fp8 quantization. 
+ if isinstance(param, BlockQuantScaleParameter): + assert self.quant_method is not None + assert hasattr(self.quant_method, "quant_config") + weight_block_size = self.quant_method.quant_config.weight_block_size + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = (shard_offset + block_n - 1) // block_n + shard_size = (shard_size + block_n - 1) // block_n + param.load_qkv_weight(loaded_weight=loaded_weight, num_heads=self.num_kv_head_replicas, shard_id=loaded_shard_id, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index b31b980fbe84a5b1de45ceac0c702498bc1dfe05..9fbad9d2f91e559461353cc6acce4258410ed446 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -10,8 +10,10 @@ from packaging import version from vllm import _custom_ops as ops from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.triton_utils import HAS_TRITON -TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") +TRITON3 = HAS_TRITON and (version.parse(triton.__version__) + >= version.parse("3.0.0")) if TRITON3: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 92904937dbab9df9d9d5e023f8f0ea1d196c2e8e..a58b4e6b7f6a8ae7068b90632835b4b8b98fa264 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Type +from typing import Literal, Type, get_args from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -QUANTIZATION_METHODS: List[str] = [ +QuantizationMethods = Literal[ "aqlm", "awq", "deepspeedfp", @@ -15,12 +15,12 @@ QUANTIZATION_METHODS: List[str] = [ "fbgemm_fp8", "modelopt", "nvfp4", - # The order of gptq methods is important for config.py iteration over - # 
override_quantization_method(..) "marlin", + "bitblas", "gguf", "gptq_marlin_24", "gptq_marlin", + "gptq_bitblas", "awq_marlin", "gptq", "compressed-tensors", @@ -36,6 +36,7 @@ QUANTIZATION_METHODS: List[str] = [ "blockwise_int8", "w8a8_int8" ] +QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) # The customized quantization methods which will be added to this dict. _CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} @@ -87,6 +88,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .aqlm import AQLMConfig from .awq import AWQConfig from .awq_marlin import AWQMarlinConfig + from .bitblas import BitBLASConfig from .bitsandbytes import BitsAndBytesConfig from .compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) @@ -96,6 +98,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .fp8 import Fp8Config from .gguf import GGUFConfig from .gptq import GPTQConfig + from .gptq_bitblas import GPTQBitBLASConfig from .gptq_marlin import GPTQMarlinConfig from .gptq_marlin_24 import GPTQMarlin24Config from .hqq_marlin import HQQMarlinConfig @@ -111,7 +114,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .blockwise_int8 import BlockInt8Config from .w8a8_int8 import W8A8Int8Config - method_to_config: Dict[str, Type[QuantizationConfig]] = { + method_to_config: dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, @@ -120,12 +123,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, "nvfp4": ModelOptNvFp4Config, - # The order of gptq methods is important for config.py iteration over - # override_quantization_method(..) 
"marlin": MarlinConfig, + "bitblas": BitBLASConfig, "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, + "gptq_bitblas": GPTQBitBLASConfig, "awq_marlin": AWQMarlinConfig, "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, @@ -150,6 +153,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: __all__ = [ "QuantizationConfig", + "QuantizationMethods", "get_quantization_config", "QUANTIZATION_METHODS", -] +] \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index cb1d5400f3a077a3b6f8ac28cf26e4ae616e7184..ef4a7765d61efe0a5cd80111f124341c2e8b95e6 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig, is_layer_skipped_awq) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - check_marlin_supports_layer, marlin_make_empty_g_idx, - marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales, - moe_awq_to_marlin_zero_points, verify_marlin_supported, - verify_marlin_supports_shape) + check_marlin_supports_layer, check_moe_marlin_supports_layer, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, moe_awq_to_marlin_zero_points, + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, 
PackedvLLMParameter) @@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig): self.full_config).get_quant_method(layer, prefix) return AWQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): - if layer.local_num_experts > 32: - # For MoEs with many experts the moe_wna16 kernel is faster + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + if not check_moe_marlin_supports_layer(layer, self.group_size): + logger.warning_one( + f"Layer '{prefix}' is not supported by AWQMoeMarlin. " + "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - else: - return AWQMoEMethod(self) + return AWQMoEMethod(self) return None @classmethod @@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase): layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) + device = layer.w13_qweight.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + layer.workspace = torch.zeros((sms * 4, ), + dtype=torch.int, + device=device, + requires_grad=False) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_qweight.shape[0] device = layer.w13_qweight.device @@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase): activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." 
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
    BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


def _release_tuple(version: str) -> tuple:
    """Return the leading numeric components of *version* as ints.

    Parsing stops at the first dot-separated piece that does not start with
    an ASCII digit, so ``"0.1.0.dev3"`` and ``"0.1.0rc1"`` both give
    ``(0, 1, 0)``.
    """
    release = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not "0" <= ch <= "9":
                break
            digits += ch
        if not digits:
            break
        release.append(int(digits))
    return tuple(release)


def _is_older_version(installed: str, minimum: str) -> bool:
    """Return True if *installed* is numerically older than *minimum*.

    Plain string comparison orders versions character by character (for
    example ``"0.10.0" < "0.9.0"``), so the release components are compared
    as integers instead. The shorter version is zero-padded so that ``"0.1"``
    and ``"0.1.0"`` compare equal.
    """
    have = _release_tuple(installed)
    need = _release_tuple(minimum)
    width = max(len(have), len(need))
    have += (0, ) * (width - len(have))
    need += (0, ) * (width - len(need))
    return have < need


class BitBLASConfig(QuantizationConfig):
    """Config class for BitBLAS.

    Reference: https://github.com/Microsoft/BitBLAS
    """
    TORCH_DTYPE = torch.float16
    STORAGE_DTYPE = "int8"  # assume int8 storage
    TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
    # "original" or "rescale" or "quantized",
    # gptq_with_bitblas prefer "quantized implementation"
    ZEROS_MODE = "quantized"

    def __init__(
        self,
        weight_bits: int,
        group_size: Optional[int],
        desc_act: Optional[bool],
        is_sym: Optional[bool],
        quant_method: Optional[str],
        lm_head_quantized: bool,
    ) -> None:
        """Validate the BitBLAS installation and store the quantization setup.

        Args:
            weight_bits: Bit width of a quantized weight; must be one of
                ``BITBLAS_SUPPORTED_NUM_BITS``.
            group_size: Quantization group size; ``-1`` means one group per
                output channel.
            desc_act: Activation-order flag; forced to ``False`` when
                ``group_size == -1``.
            is_sym: Whether quantization is symmetric; must be one of
                ``BITBLAS_SUPPORTED_SYM``.
            quant_method: Name of the quantization method of the checkpoint.
            lm_head_quantized: Whether the LM head is quantized as well.

        Raises:
            ValueError: If bitblas cannot be imported or is older than
                ``MINIMUM_BITBLAS_VERSION``, or if ``weight_bits`` or
                ``is_sym`` is not supported.
        """
        # Let the base class set up whatever state it defines.
        super().__init__()

        try:
            import bitblas

            # Compare numerically; comparing the raw strings is lexicographic.
            if _is_older_version(bitblas.__version__,
                                 MINIMUM_BITBLAS_VERSION):
                raise ImportError(
                    "bitblas version is wrong. Please "
                    f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
        except ImportError as e:
            # The requirement is quoted so that a shell does not interpret
            # ">=" as an output redirection when the command is copied.
            raise ValueError(
                "Trying to use the bitblas backend, but could not import "
                f"with the following error: {e}. "
                "Please install bitblas through the following command: "
                f"`pip install \"bitblas>={MINIMUM_BITBLAS_VERSION}\"`"
            ) from e

        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.quant_method = quant_method
        self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"BitBLAS does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
                "are supported.")

        if self.is_sym not in BITBLAS_SUPPORTED_SYM:
            raise ValueError(
                f"BitBLAS does not support is_sym = {self.is_sym}. "
                f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")

        storage_dtype = self.STORAGE_DTYPE
        # Bit width of the storage dtype, taken from its name ("int8" -> 8).
        storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))

        self.storage_dtype = storage_dtype
        self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
        # Number of quantized weights packed into one storage element
        # (e.g. two 4-bit weights per int8).
        self.pack_factor = storage_nbit // weight_bits
        self.nbits = weight_bits

        # Zeros type for the quantized weights.
        self.zeros_mode = self.ZEROS_MODE

    def __repr__(self) -> str:
        return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"is_sym={self.is_sym}, "
                f"quant_method={self.quant_method})")

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "bitblas"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method."""
        # TODO: Need to figure it out
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["quantize_config.json"]

    @staticmethod
    def get_from_keys(config: Dict[str, Any],
                      keys: List[str],
                      default: Any = None) -> Any:
        """Get a value from the model's quantization config.

        Returns the value of the first key in ``keys`` that is present in
        ``config``, or ``default`` when none of them is.
        """
        for key in keys:
            if key in config:
                return config[key]
        return default

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
        """Build a config from a checkpoint's quantization config dict."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"], -1)
        desc_act = cls.get_from_keys(config, ["desc_act"], False)
        is_sym = cls.get_from_keys(config, ["sym"], False)
        quant_method = cls.get_from_keys(config, ["quant_method"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
                   lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Select this method for checkpoints serialized in BitBLAS format.

        Returns the method name when the checkpoint is marked as BitBLAS
        format and the user asked for no method, ``"gptq"`` or ``"bitblas"``;
        otherwise returns ``None``.
        """
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_bitblas_format: bool
        is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
                             or hf_quant_cfg.get("is_bitblas_format", False))

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "bitblas")

        if is_bitblas_format and is_valid_user_quant:
            msg = ("The model is serialized in {} format. Using {} kernel.".
                   format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["BitBLASLinearMethod"]:
        """Return the linear method for ``layer``, or ``None`` if unsupported.

        Linear layers are always handled; a ``ParallelLMHead`` only when the
        LM head is quantized.
        """
        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
                                             and self.lm_head_quantized):
            return BitBLASLinearMethod(self)
        return None


class BitBLASLinearMethod(LinearMethodBase):
    """Linear method for BitBLAS.

    Args:
        quant_config: The BitBLAS quantization config.
    """
    # USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS
    # Instead of BITBLAS_OPTIMIZE_FEATURES
    # If you want to high contiguous batching
    # performance
    OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
    ENABLE_TUNING = True
    # torch.half is an alias of torch.float16, so a single entry covers both.
    BITBLAS_DTYPES = {
        torch.float32: "float32",
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.int8: "int8",
    }

    def __init__(self, quant_config: BitBLASConfig):
        self.quant_config = quant_config

    def create_weights_gptq(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Create the quantized parameters and register them on ``layer``.

        Configures the BitBLAS matmul operator for this layer's shape, then
        registers three parameters on ``layer``: ``qweight`` (packed
        weights), ``scales`` and ``zeros``.

        Args:
            layer: The module that receives the parameters.
            input_size_per_partition: The size of the input partition.
            output_partition_sizes: Sizes of the output partitions; their sum
                is the output size of this layer.
            input_size: The total size of the input (unused).
            output_size: The total size of the output (unused).
            params_dtype: The data type of the parameters; must be one of
                the config's supported activation dtypes.
            **extra_weight_attrs: Must contain ``"weight_loader"``.

        Raises:
            ValueError: If ``params_dtype`` is not supported or if the input
                size per partition is not divisible by the group size in
                ``quant_config``.
        """
        del input_size, output_size  # Unused arguments.
        weight_loader = extra_weight_attrs["weight_loader"]

        supported_dtypes = self.quant_config.get_supported_act_dtypes()
        if params_dtype not in supported_dtypes:
            raise ValueError(
                f"Parameter data type must be one of {supported_dtypes}, "
                f"but got {params_dtype}")
        group_size = self.quant_config.group_size
        if group_size is None:
            group_size = -1
        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if (group_size != -1 and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must "
                f"be divisible by group size ({group_size}).")

        # Initialize or retrieve the BitBLAS matrix multiplication operator.
        self._configure_bitblas_matmul(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype=params_dtype,
            enable_tuning=self.ENABLE_TUNING,
            bias=False,
            layout="nt",
            bits=self.quant_config.weight_bits,
        )

        # Packed quantized weights; the shape is dictated by the operator.
        qweight = PackedvLLMParameter(
            data=torch.empty(
                self.bitblas_matmul.retrieve_weight_shape(),
                device="cuda",
                dtype=self.quant_config.storage_torch_dtype,
                requires_grad=False,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2]
                               if self.bitblas_matmul.propagate_b else None),
            weight_loader=weight_loader,
        )

        # Number of quantization groups along the input dimension
        # (1 for channel-wise quantization).
        input_groups = (1 if group_size == -1 else input_size_per_partition //
                        group_size)

        # Initialize scales and zeros for the quantized weights.
        weight_scale_args = {
            "data":
            torch.empty(
                output_size_per_partition,
                input_groups,
                device="cuda",
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(output_dim=0,
                                                **weight_scale_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=0,
                                              input_dim=1,
                                              **weight_scale_args)

        if self.quant_config.zeros_mode == "quantized":
            zeros = PackedvLLMParameter(
                data=torch.empty(
                    input_groups,
                    output_size_per_partition // self.quant_config.pack_factor,
                    device="cuda",
                    dtype=self.quant_config.storage_torch_dtype,
                    requires_grad=False,
                ),
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                weight_loader=weight_loader,
            )

        else:
            zeros = BasevLLMParameter(
                torch.empty(output_size_per_partition,
                            input_groups,
                            device="cuda",
                            dtype=params_dtype),
                weight_loader=weight_loader,
            )
            # Set attributes to indicate how scales and zeros are applied.
            set_weight_attrs(zeros, {
                "input_dim": None if input_groups == 1 else 1,
                "output_dim": 0,
            })

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros", zeros)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the layer's parameters for the configured quant method.

        Only ``quant_method == "gptq"`` is supported.

        Raises:
            ValueError: If the configured quant method is not ``"gptq"``.
        """
        if self.quant_config.quant_method == "gptq":
            return self.create_weights_gptq(layer, input_size_per_partition,
                                            output_partition_sizes, input_size,
                                            output_size, params_dtype,
                                            **extra_weight_attrs)
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")

    def _configure_bitblas_matmul(
        self,
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
        out_dtype="float16",
    ):
        """Build the BitBLAS ``MatmulConfig`` and store the operator.

        The resulting operator is kept in ``self.bitblas_matmul``.

        Raises:
            ValueError: If the configured quant method is not ``"gptq"``.
        """
        from bitblas import MatmulConfig
        bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]

        with_scaling = False
        with_zeros = False
        # Treat a missing group size as -1, the same way create_weights_gptq
        # does, so both code paths describe the same grouping.
        group_size = self.quant_config.group_size
        if group_size is None:
            group_size = -1
        zeros_mode = self.quant_config.zeros_mode
        if self.quant_config.quant_method == "gptq":
            with_scaling = True
            with_zeros = True
            W_dtype = f"uint{bits}"
            if self.quant_config.is_sym:
                # Symmetric quantization: signed weights, no zero points.
                with_zeros = False
                W_dtype = f"int{bits}"
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")

        # NOTE(review): out_dtype defaults to "float16" even when
        # params_dtype is bfloat16 -- confirm this is intended.
        matmul_config = MatmulConfig(
            N=outfeatures,
            K=infeatures,
            A_dtype=bitblas_dtype,
            W_dtype=W_dtype,
            out_dtype=out_dtype,
            accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
            storage_dtype=self.quant_config.STORAGE_DTYPE,
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=group_size,
            with_bias=bias,
            layout=layout,
            zeros_mode=zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning)

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        """Return the operator for ``config``, creating it if not cached.

        The global operator cache is loaded from the BitBLAS database the
        first time it is found empty. A newly created operator is tuned,
        added to the cache and saved to the database only when
        ``enable_tuning`` is true.
        """
        from bitblas import Matmul, auto_detect_nvidia_target
        from bitblas.cache import get_database_path, global_operator_cache
        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
                                                     BITBLAS_TARGET)

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is None:
            bitblas_matmul = Matmul(config,
                                    target=BITBLAS_TARGET,
                                    enable_tuning=False)
            if enable_tuning:
                TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...")
                logger.info(TUNING_MESSAGE)
                bitblas_matmul.hardware_aware_finetune(topk=20)
                global_operator_cache.add(config, bitblas_matmul)
                global_operator_cache.save_into_database(
                    BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
                TUNED_MESSAGE = (
                    f"BitBLAS Operator {config} tuned and saved to database.")
                logger.info(TUNED_MESSAGE)
            else:
                _message = f"BitBLAS Operator {config} created."
                logger.info(_message)
        else:
            _message = (
                f"BitBLAS Operator {config} found in global_operator_cache.")
            logger.info(_message)
        return bitblas_matmul

    def apply_gptq(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized matmul for ``x`` using the layer's parameters.

        ``x`` is flattened to 2D for the operator and the result is reshaped
        back to the leading dimensions of ``x``. ``bias``, when given, is
        added in place.
        """
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.zeros

        x_2d = x.view(-1, x.shape[-1])

        # Symmetric quantization has no zero points to pass to the operator.
        if self.quant_config.is_sym:
            output_2d = self.bitblas_matmul(x_2d, qweight, scales)
        else:
            output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output

    def apply(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Dispatch to the apply routine of the configured quant method.

        Raises:
            ValueError: If the configured quant method is not ``"gptq"``.
        """
        if self.quant_config.quant_method == "gptq":
            return self.apply_gptq(*args, **kwargs)
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 6cd08fae7e6091e39b9f18a84b68d3a28a9d261d..fd148b4f0dd953d3ed1a5ae38f50332157b6c7c1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -73,7 +73,7 @@ class CompressedTensorsConfig(QuantizationConfig): return 70 def get_name(self) -> str: - return "compressed_tensors" + return "compressed-tensors" def get_quant_method( self, @@ -303,14 +303,12 @@ class CompressedTensorsConfig(QuantizationConfig): def _is_wNa16_group_channel(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: input_quant_none = input_quant is None - is_symmetric = weight_quant.symmetric is_channel_group = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value or weight_quant.strategy == QuantizationStrategy.GROUP.value) is_static = not weight_quant.dynamic - return (is_channel_group and input_quant_none and is_symmetric - and is_static) + return (is_channel_group and input_quant_none and is_static) def _get_scheme_from_parts( self, weight_quant: BaseModel, @@ -320,6 +318,7 @@ class CompressedTensorsConfig(QuantizationConfig): if self._is_wNa16_group_channel(weight_quant, input_quant): if (self.quant_format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): + assert weight_quant.symmetric return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, @@ -329,6 +328,7 @@ class CompressedTensorsConfig(QuantizationConfig): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, group_size=weight_quant.group_size, actorder=weight_quant.actorder) diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 628724c5b7d6766ba087beec2e66e99482e46f31..721e36af2b28ff07f2a939501cf641148a83f340 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): else: return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) - and layer.activation == "silu" and layer.expert_map is None): + and layer.activation == "silu"): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) @@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) + + # Property to determine if AITER is used + if is_rocm_aiter_moe_enabled(): + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 + rocm_aiter_fused_experts, shuffle_weights) + + # reshaping weights is required for aiter moe kernel. 
+ shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + + self.fused_experts_func = rocm_aiter_fused_experts + else: + from vllm.model_executor.layers.fused_moe import fused_experts + self.fused_experts_func = fused_experts + def apply( self, layer: torch.nn.Module, @@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -282,10 +303,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, + return self.fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, @@ -489,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ) -> torch.Tensor: assert activation == "silu" - assert global_num_experts == layer.w13_weight.shape[0] - assert expert_map is None topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -521,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, out_dtype=x.dtype, + expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 
38df09ff39373fc33e1fe799c5feca2ecc2b4c25..3535dd3f3f14727ba51a9a3e790c271d0672c451 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_repeat_scales_on_all_ranks) +# yapf conflicts with isort for this block +# yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, + PackedColumnParameter, PackedvLLMParameter, RowvLLMParameter) +# yapf: enable from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, 8: scalar_types.uint8b128 } +WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8} WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) @@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): strategy: str, num_bits: int, group_size: Optional[int] = None, + symmetric: Optional[bool] = True, actorder: Optional[ActivationOrdering] = None): self.pack_factor = 32 // num_bits self.strategy = strategy + self.symmetric = symmetric self.group_size = -1 if group_size is None else group_size self.has_g_idx = actorder == ActivationOrdering.GROUP @@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): f"Unsupported num_bits = {num_bits}. 
" f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") - self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] + self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits] + if not self.symmetric else + WNA16_SUPPORTED_TYPES_MAP[num_bits]) @classmethod def get_min_capability(cls) -> int: @@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): weight_type=self.quant_type, act_type=params_dtype, group_size=self.group_size, - zero_points=False, + zero_points=not self.symmetric, has_g_idx=self.has_g_idx ) @@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): dtype=params_dtype, ) } + + zeros_args = { + "weight_loader": + weight_loader, + "data": + torch.zeros( + output_size_per_partition // self.pack_factor, + scales_and_zp_size, + dtype=torch.int32, + ) + } + if not partition_scales: weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) + + if not self.symmetric: + qzeros = PackedColumnParameter(output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) else: weight_scale = GroupQuantScaleParameter(output_dim=0, input_dim=1, **weight_scale_args) + if not self.symmetric: + qzeros = PackedvLLMParameter(input_dim=1, + output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) # A 2D array defining the original shape of the weights # before packing @@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) + if not self.symmetric: + layer.register_parameter("weight_zero_point", qzeros) + # group index (for activation reordering) if self.has_g_idx: weight_g_idx = RowvLLMParameter(data=torch.empty( @@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): self.kernel = kernel_type(mp_linear_kernel_config, w_q_param_name="weight_packed", w_s_param_name="weight_scale", - w_zp_param_name=None, + 
w_zp_param_name="weight_zero_point", w_gidx_param_name="weight_g_idx") # Checkpoints are serialized in compressed-tensors format, which is diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b7327f47733b33727a3005555d71560f3a8bc4c2..01056c37b86c8d20500d0116f1cbfb1ebc276693 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -140,6 +140,11 @@ class Fp8Config(QuantizationConfig): return name.replace(".k_proj.output_scale", ".attn.k_scale") if name.endswith(".output_scale") and ".v_proj" in name: return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") + # If no matches, return None return None @@ -575,8 +580,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - expand_weights, is_rocm_aiter_block_scaled_moe_enabled, - is_rocm_aiter_moe_enabled, shuffle_weights) + expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights) # TODO (rob): refactor block quant into separate class. if self.block_quant: @@ -603,7 +607,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight = Parameter(w2_weight, requires_grad=False) layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, requires_grad=False) - if is_rocm_aiter_block_scaled_moe_enabled(): + if is_rocm_aiter_moe_enabled(): # reshaping weights is required for aiter moe kernel. 
shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight.data, layer.w2_weight.data) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py new file mode 100644 index 0000000000000000000000000000000000000000..88cada4c61b83811df854a4c883e9498f41bae10 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -0,0 +1,438 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Optional, Set + +import torch +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + BitBLASLinearKernel, MPLinearLayerConfig) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks, + check_bitblas_supported, verify_bitblas_supported) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class GPTQBitBLASConfig(QuantizationConfig): + """Config class for GPTQ BitBLAS""" + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + + TORCH_DTYPE = torch.float16 + GPTQ_CKPT_STORAGE_DTYPE = ( + "int32" # 
GPTQ Default Checkpoints use int32 as storage dtype + ) + GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype + TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE) + # "original" or "rescale" or "quantized", + # the gptq_bitblas prefer "quantized" + ZEROS_MODE = "quantized" + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + quant_method: Optional[str], + lm_head_quantized: bool, + ) -> None: + + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError as e: + bitblas_import_exception = e + raise ValueError( + "Trying to use the bitblas backend, but could not import " + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + self.quant_method = quant_method + self.lm_head_quantized = lm_head_quantized + + # Verify + if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support is_sym = {self.is_sym}. 
" + f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.") + + self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE + + storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE + if c.isdigit())) + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = storage_nbit // weight_bits + self.nbits = weight_bits + + # Zeros type for the quantized weights. + self.zeros_mode = self.ZEROS_MODE + + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + + def __repr__(self) -> str: + return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})") + + @classmethod + def get_name(cls) -> str: + return "gptq_bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + quant_method = cls.get_from_keys(config, ["quant_method"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, quant_method, + lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "bitblas" + or user_quant == "gptq_bitblas") + + if 
can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_bitblas" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_bitblas for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQBitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): + return GPTQBitBLASLinearMethod(self) + return None + + @property + def torch_storage_dtype(self) -> torch.dtype: + return self.TORCH_BITBLAS_STORAGE_DTYPE + + @classmethod + def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies bitblas constraints. + return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits, + sym)], + group_size=group_size) + + +class GPTQBitBLASLinearMethod(LinearMethodBase): + """Linear method for GPTQ BitBLAS. + + Args: + quant_config: The GPTQ BitBLAS quantization config. 
+ """ + + kernel_type = BitBLASLinearKernel + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, quant_config: GPTQBitBLASConfig) -> None: + self.quant_config = quant_config + # Verify supported on platform. + verify_bitblas_supported(quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing + quantized weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_partition_sizes: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: + The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or + if the input size per partition is not divisible by the + group size in `quant_config`. + """ + if params_dtype != torch.float16: + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + if input_size_per_partition % group_size != 0: + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must " + f"be divisible by group size ({self.quant_config.group_size})." 
+ ) + + kernel_type = self.kernel_type + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQBitBLASLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Determine sharding + if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act, + self.quant_config.group_size, + is_row_parallel): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each GPU in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. + scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Init buffers + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + # Activation order + # Ignore warning from fused linear layers such as QKVParallelLinear. 
+ g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + + # Scales + scales = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + }, + ) + + # Quantized zero-points + qzeros_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx", + bitblas_quant_config=self.quant_config, + ) + + # Initialize or retrieve the BitBLAS matrix multiplication operator. 
+ self.kernel.configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + params_dtype=params_dtype, + bias=False, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + out = self.kernel.apply_gptq_bitblas_linear(layer, x) + if bias is not None: + out.add_(bias) + return out diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 0615bb4ab4df7f04634538eb5ad7964e13a659e1..52cd0a5b697577260451c51b0f47b267765f4714 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel) -from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.gptq_utils import ( get_linear_quant_method) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, marlin_moe_permute_scales, - marlin_repeat_scales_on_all_ranks, verify_marlin_supported) + check_marlin_supported, check_moe_marlin_supports_layer, + marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks, + verify_marlin_supported) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, PackedColumnParameter, @@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig): def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, 
FusedMoE): - if layer.local_num_experts > 32: - # For MoEs with many experts the moe_wna16 kernel is faster + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + if not check_moe_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + f"Layer '{prefix}' is not supported by GPTQMoeMarlin. " + "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - else: - return GPTQMarlinMoEMethod(self) + return GPTQMarlinMoEMethod(self) return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) @@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): torch.empty(num_experts, scales_size13, 2 * intermediate_size_per_partition, - dtype=torch.half), + dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w13_scales", w13_scales) @@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): torch.empty(num_experts, scales_size2, hidden_size, - dtype=torch.half), + dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w2_scales", w2_scales) @@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + device = layer.w13_qweight.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + layer.workspace = torch.zeros((sms * 4, ), + dtype=torch.int, + device=device, + requires_grad=False) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Process act_order @@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): "Apply router weight on input is not supported for" "fused Marlin MoE method.") - # The input must currently be float16 - orig_dtype = x.dtype - x = x.half() - topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): router_logits, topk_weights, 
topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, g_idx1=layer.w13_g_idx, g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, num_bits=self.quant_config.quant_type.size_bits, - is_k_full=self.is_k_full).to(orig_dtype) + workspace=layer.workspace, + is_k_full=self.is_k_full) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 520e1bc96721c991a619192a8a95a2a4100500d6..d144bb4361045a701d18f137628ee66d601927ec 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -5,6 +5,8 @@ from typing import List, Optional, Type import vllm.envs as envs from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 AllSparkLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501 + BitBLASLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 @@ -20,6 +22,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, + BitBLASLinearKernel, ExllamaLinearKernel, ] @@ -76,4 +79,4 @@ def choose_mp_linear_kernel( raise ValueError( "Failed to find a kernel that can implement the "\ "WNA16 linear layer. 
Reasons: \n" - + '\n'.join(failure_reasons)) + + '\n'.join(failure_reasons)) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py new file mode 100644 index 0000000000000000000000000000000000000000..21452d08b8a1cd34415be1ea917d136843a252d0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -0,0 +1,299 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES, + MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx, + check_bitblas_supports_shape, query_bitblas_supported_quant_types, + unpack_gptq_qweight, unpack_gptq_qzeros) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + +logger = init_logger(__name__) + + +class BitBLASLinearKernel(MPLinearKernel): + + OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES + ENABLE_TUNING: bool = True + MATMUL_LAYOUT: str = "nt" + BITBLAS_DTYPES: Dict[torch.dtype, str] = { + torch.float32: "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.half: "float16", + torch.int8: "int8", + } + bitblas_matmul: object = None + + def __init__( + self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None, + bitblas_quant_config: Optional[QuantizationConfig] = None, + ): + self.quant_config = bitblas_quant_config + super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name, + w_gidx_param_name) + + def 
repack_bitblas_from_gptq( + self, + b_q_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: Optional[torch.Tensor] = None, + ): + from bitblas.quantization.utils import general_compress + assert self.bitblas_matmul is not None, "bitblas_matmul is None" + + quant_config = self.quant_config + # qweight in gptq old quant linear stored with + # (outfeatures, infeatures), should be transposed. + qweight = b_q_weight.T.contiguous().view( + quant_config.torch_storage_dtype) # type: ignore[union-attr] + intweight = unpack_gptq_qweight( + qweight, + quant_config.weight_bits).contiguous() # type: ignore[union-attr] + if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined] + qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined] + intweight.cpu()).cuda() + # scales in gptq old quant linear stored with + # (infeatures // group_size, outfeatures), should be transposed. + scales = scales.T.contiguous() + + if qzeros is None: + return qweight, scales, None + + # qzeros should be de-quantized to int zeros. + weight_bits = quant_config.weight_bits # type: ignore[union-attr] + intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous() + zeros: Optional[torch.Tensor] = None + zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined] + if zeros_mode == "original": + zeros = intzeros.to(torch.float16).contiguous() + elif zeros_mode == "rescale": + # "rescale" mode stores zero-points pre-multiplied by scales. + zeros = (intzeros.to(torch.float16) * scales).contiguous() + elif zeros_mode == "quantized": + zeros = ( + torch.Tensor( + general_compress( + intzeros.T.contiguous().cpu().numpy(), + weight_bits, + )).to(qweight.device). 
+ to(quant_config.torch_storage_dtype # type: ignore[union-attr] + ).contiguous()) + else: + raise ValueError("Unsupported zeros type: {}".format(zeros_mode)) + + return qweight, scales, zeros + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + is_bitblas_installed = True + + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError: + is_bitblas_installed = False + + if not is_bitblas_installed: + return False, "bitblas is not installed. Please install bitblas "\ + "by running `pip install bitblas>="\ + f"{MINIMUM_BITBLAS_VERSION}`" + + quant_types = query_bitblas_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, (f"Quant type ({c.weight_type}) not supported by" + f" BitBLAS, supported types are: {quant_types}") + + if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES: + return False, (f"Group size ({c.group_size}) not supported by " + "BitBLAS, supported group sizes are: " + f"{BITBLAS_SUPPORTED_GROUP_SIZES}") + + return check_bitblas_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + quant_config = self.quant_config + + # Default names since bitblas requires empty parameters for these, + # TODO: remove this requirement from bitblas (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = 
"qzeros" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = bitblas_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device)) + layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device) + + if c.zero_points: + raise NotImplementedError("Zero points not supported by BitBLAS") + else: + setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device)) + + # Repack weights + bitblas_qweight, bitblas_scales, bitblas_qzeros = ( + self.repack_bitblas_from_gptq( + layer.qweight, + layer.scales, + None if quant_config.is_sym else # type: ignore[union-attr] + layer.qzeros, # type: ignore[union-attr] + )) + replace_parameter(layer, self.w_q_name, bitblas_qweight) + replace_parameter(layer, self.w_s_name, bitblas_scales) + if bitblas_qzeros is not None: + replace_parameter(layer, self.w_zp_name, bitblas_qzeros) + + def configure_bitblas_matmul( + self, + infeatures: int, + outfeatures: int, + params_dtype: torch.dtype, + bias: bool, + ) -> None: + enable_tuning = self.ENABLE_TUNING + layout = self.MATMUL_LAYOUT + bits = self.quant_config.weight_bits # type: ignore[union-attr] + self._configure_bitblas_matmul( + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + ) + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + ): + from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] + quant_config = self.quant_config + with_scaling = False + with_zeros = False + group_size = quant_config.group_size # type: ignore[union-attr] + zeros_mode = quant_config.zeros_mode # type: ignore[union-attr] + if quant_config.quant_method == "gptq": # type: ignore[union-attr] + with_scaling = True + with_zeros = True + W_dtype = f"uint{bits}" + if quant_config.is_sym: # type: 
ignore[union-attr] + with_zeros = False + W_dtype = f"int{bits}" + else: + raise ValueError( + f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr] + ) # type: ignore[union-attr] + + matmul_config = MatmulConfig( + M=self.OPT_FEATURES, + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype=bitblas_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=quant_config. # type: ignore[union-attr] + storage_dtype, # type: ignore[union-attr] + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + with_bias=bias, + layout=layout, + zeros_mode=zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul, auto_detect_nvidia_target + from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() + + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + TUNING_MESSAGE = ( + f"BitBLAS Operator {config} tuned and saved to database.") + logger.info(TUNING_MESSAGE) + else: + _message = f"BitBLAS Operator {config} created without tuning. " + logger.info(_message) + else: + _message = f"BitBLAS Operator {config} retrieved from cache." 
+ logger.info(_message) + return bitblas_matmul + + def apply_gptq_bitblas_linear( + self, + layer: torch.nn.Module, + x: torch.Tensor, + ) -> torch.Tensor: + output_size_per_partition = self.config.partition_weight_shape[1] + out_shape = x.shape[:-1] + (output_size_per_partition, ) + args = [x, layer.qweight, layer.scales] + if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined] + args.append(layer.qzeros) + output = self.bitblas_matmul(*args) # type: ignore[operator] + return output.view(out_shape) + + def apply_weights(self, layer, x, bias=None): + NOT_IMPLEMENT_MESSAGE = ( + f"{self.__class__.__name__}.apply_weights is not implemented. " + "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead") + raise NotImplementedError(NOT_IMPLEMENT_MESSAGE) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index 3f0586f6e30d6a02f95a36196b88f96d16a1a15e..b3ffeca4f100ea9c6a8dab96e7c39a6152f73e33 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel): @classmethod def can_implement(cls, c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ c.partition_weight_shape[0] != c.full_weight_shape[0]: return False, "Act reordering currently not supported by Machete, "\ "when the input features are partitioned across "\ "devices" - if c.zero_points: - return False, "Zero points currently not supported by "\ - " Compressed Tensors + Machete. 
(Kernel supports it"\ - " but CompressedTensorsWNA16 does not so support has"\ - " not been added to MacheteWNA16Kernel yet" + return False, "Zero points currently not supported by Machete" if c.weight_type not in query_machete_supported_quant_types( c.zero_points): diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index e21801cf6a7857ae700c30f2f0d15484993e2044..7bd824ff9e55151a6caa4bebfb136cdfe1840656 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, - query_marlin_supported_quant_types) + marlin_zero_points, query_marlin_supported_quant_types, unpack_cols) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) @@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel): @classmethod def can_implement(cls, c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: - if c.zero_points: - return False, "Zero points currently not supported by "\ - " MarlinLinearKernel. 
Will be added when AWQMarlin "\ - "is migrated over to using MPLinearKernel backend" quant_types = query_marlin_supported_quant_types(c.zero_points) if c.weight_type not in quant_types: @@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel): if self.w_zp_name is None: self.w_zp_name = "w_zp" - if c.has_g_idx: - g_idx, g_idx_sort_indices = marlin_sort_g_idx( - getattr(layer, self.w_gidx_name)) - self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - else: - setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - if c.zero_points: - pass - # TODO (lucas): add the following when AWQMarlin is migrated over to - # using MPLinearKernel backend - # self._transform_param(layer, self.w_zp_name, lambda x: \ - # marlin_zero_points( - # x, - # size_k=c.partition_weight_shape[0], - # size_n=c.partition_weight_shape[1], - # num_bits=c.weight_type.size_bits)) - else: - setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) - def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) @@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel): group_size=c.group_size) return x + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + grouped_k = (c.partition_weight_shape[0] // + c.group_size if c.group_size != -1 else 1) + self._transform_param(layer, self.w_zp_name, lambda x: \ + marlin_zero_points( + unpack_cols(x.t(), c.weight_type.size_bits, + grouped_k, + c.partition_weight_shape[1]), + size_k=grouped_k, + size_n=c.partition_weight_shape[1], + 
num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) @@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel): wtype=c.weight_type, input_size_per_partition=c.partition_weight_shape[0], output_size_per_partition=c.partition_weight_shape[1], + has_zp=self.config.zero_points, is_k_full=self.is_k_full, bias=bias) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 5d766c2c27ac9b6be139f48ea2e9bcfa64a6f779..5dff8b09693ce4cddc1257f20b9fb1048a3ceee1 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -38,6 +38,9 @@ class BaseKVCacheMethod(QuantizeMethodBase): requires_grad=False) layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + # Initialize P = softmax(QK^T) scales + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError( @@ -97,5 +100,38 @@ class BaseKVCacheMethod(QuantizeMethodBase): "may cause accuracy issues. 
Please make sure k/v_scale " "scaling factors are available in the fp8 checkpoint.") + if layer.q_scale > 0.0: + q_scale = layer.q_scale + if current_platform.is_fp8_fnuz(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale + if current_platform.is_fp8_fnuz(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + is_singleton_float = lambda x: isinstance(x, float) or isinstance( + x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() + if not is_singleton_float(q_scale) or not is_singleton_float( + prob_scale): + raise ValueError("Only support per-tensor scaling factor" + "for fp8-quantized Q/prob") + + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._prob_scale.copy_(prob_scale) + if q_scale == 1.0 or prob_scale == 1.0: + logger.warning_once( + f"Using Q scale {q_scale} and prob scale {prob_scale} " + "with fp8 attention. This may cause accuracy issues. " + "Please make sure Q/prob scaling factors are " + "available in the fp8 checkpoint.") + del layer.k_scale del layer.v_scale + del layer.q_scale + del layer.prob_scale diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index ca71da8b736a5e2d435a0dcffd8c8fda9da2ba3c..cf9108ea72c3c61db2d0833ceb9a74d524271433 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import fnmatch -import re from typing import Any, Dict, List, Optional, cast import torch @@ -125,6 +124,13 @@ class QuarkConfig(QuantizationConfig): for q_config in q_configs: q_config["output_tensors"] = None + # In case q_proj output is also quantized, remove the configuration + # to keep qkv consistency. 
+ q_proj_q_config = cast(Dict[str, Any], + layer_quant_config.get("*q_proj")) + if q_proj_q_config is not None: + q_proj_q_config["output_tensors"] = None + return cls(quant_config=config, kv_cache_group=kv_cache_group, kv_cache_config=kv_cache_config, @@ -289,25 +295,14 @@ class QuarkConfig(QuantizationConfig): :param name: param name :return: matching param name for KV cache scale in vLLM """ - if self.kv_cache_group is None or len(self.kv_cache_group) == 0: - return None - - kv_proj_names = [ - re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group - ] - if name.endswith(".output_scale"): - if len(kv_proj_names) == 1 and kv_proj_names[0] in name: - kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale" - return name.replace(kv_output_scale_name, ".attn.k_scale") - - elif len(kv_proj_names) == 2: - for kv_proj_name in kv_proj_names: - if kv_proj_name in name and kv_proj_name == "k_proj": - return name.replace(".k_proj.output_scale", - ".attn.k_scale") - elif kv_proj_name in name and kv_proj_name == "v_proj": - return name.replace(".v_proj.output_scale", - ".attn.v_scale") + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5d28d327e8a2f4f4689f94366523aab69719e2ac --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py @@ -0,0 +1,198 @@ +# 
SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch + +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +MINIMUM_BITBLAS_VERSION = "0.1.0" + +BITBLAS_MIN_WEIGHT_SIZE_N = 16 +BITBLAS_MIN_WEIGHT_SIZE_K = 16 +GPTQ_BITBLAS_MAX_PARALLEL = 16 + +BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# For dynamic shape code generation +BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024] +# If want to enable high performance for contiguous batching +# Please use the following values +BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024] + +BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] +BITBLAS_SUPPORTED_SYM = [False, True] + + +# Determines the supported quantization types for BitBLAS based on the +# device's capability and whether zero-point (zp) is used. +def query_bitblas_supported_quant_types(has_zp: bool, + device_capability: Optional[int] = None + ): + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + if device_capability < 70: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_bitblas_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + supported_types = query_bitblas_supported_quant_types( + has_zp, device_capability) + + if 
quant_type not in supported_types: + return (False, f"BitBLAS does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES): + return (False, f"BitBLAS does not support group_size = {group_size}. " + f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_bitblas_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp, + device_capability) + return cond + + +def verify_bitblas_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_bitblas_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + # Validate input_size_per_partition + if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. 
" + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + +def check_bitblas_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_bitblas_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + +def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def bitblas_sort_g_idx( + g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor: + qzeros = qzeros.view(torch.int32) + elems_per_int32 = 32 // bits + unpacked_zeros = torch.zeros( + (qzeros.shape[0], qzeros.shape[1] * elems_per_int32), + 
dtype=torch.int8, + device=qzeros.device, + requires_grad=False, + ) + + for col in range(unpacked_zeros.shape[1]): + i = col % elems_per_int32 + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> + (bits * i)) & 0xF + if not is_gptq_v2: + return unpacked_zeros + 1 + return unpacked_zeros + + +def unpack_gptq_qweight(qweight, bits): + qweight = qweight.view(torch.int8) + elems_per_int8 = 8 // bits + unpacked_weight = torch.zeros( + (qweight.shape[0], qweight.shape[1] * elems_per_int8), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + for col in range(unpacked_weight.shape[1]): + i = col % elems_per_int8 + unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> + (bits * i)) + + return torch.bitwise_and(unpacked_weight, 2**bits - 1) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 5b2e3ca2c799dbdeb87ef23e16b695f3169969fd..4a190480d35b6c7effa3b90ac9f650267f8f2cae 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ group_size=group_size)[0] +def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ + -> bool: + hidden_size = layer.hidden_size + intermediate_size_per_partition = layer.intermediate_size_per_partition + + # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) + # down: (n, k) = (hidden_size, intermediate_size_per_partition) + # moe marlin requires n % 128 == 0 and k % 64 == 0 + return hidden_size % 128 == 0 and \ + intermediate_size_per_partition % max(64, group_size) == 0 and \ + group_size in [-1, 32, 64, 128] + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition // @@ -319,6 +332,7 @@ def 
apply_gptq_marlin_linear( wtype: ScalarType, output_size_per_partition: int, input_size_per_partition: int, + has_zp: bool, is_k_full: bool, bias: Optional[torch.Tensor] = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: @@ -343,8 +357,8 @@ def apply_gptq_marlin_linear( size_n=output_size_per_partition, size_k=input_size_per_partition, is_k_full=is_k_full, - has_zp=False, use_atomic_add=use_atomic_add, + has_zp=has_zp, use_fp32_reduce=use_fp32_reduce, is_zp_float=False) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index c0a4c4dc461a89ebabb27a43c90c34e702374e6b..e6576afa396e0f9d804170feaeb2b0826bd26d62 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from vllm import _custom_ops as ops +from vllm import envs from vllm.config import CompilationLevel, get_current_vllm_config from vllm.platforms import current_platform from vllm.utils import W8a8GetCacheJSON @@ -19,6 +20,7 @@ W8A8_TRITONJSON=W8a8GetCacheJSON() # The condition is determined once as the operations # are time consuming. 
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() + and torch.__version__[0:3] >= "2.7" and current_platform.has_device_capability(94)) def sparse_cutlass_supported() -> bool: @@ -132,6 +134,160 @@ def maybe_create_device_identity(): TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) +def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, + out_dtype: torch.dtype, scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + output_shape: List, **kwargs) -> torch.Tensor: + + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + return output.view(*output_shape) + + +def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List) -> torch.Tensor: + from vllm.platforms.rocm import on_mi250_mi300 + if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300( + ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: + output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, + current_platform.get_cu_count()) + else: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + +def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List) -> torch.Tensor: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return 
torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + +def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List) -> torch.Tensor: + # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM + # when using it. + # For now it has only been validated on ROCm platform. + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. + # + # For CUDA platform please validate if the torch._scaled_mm supports + # rowwise scaled GEMM before using it + + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b.t(), + bias=bias) + + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = output.view(*output_shape) + return output + + +def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List, + **kwargs) -> torch.Tensor: + # Use unfused DQ due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # GEMM + # This computes C = (X * W). 
+ # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * scale_b.t() + if bias is not None: + output = output + bias + return output.to(out_dtype).view(*output_shape) + + +def dispatch_w8a8_scaled_mm( + cutlass_fp8_supported: bool, per_tensor_weights: bool, + per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool] +) -> Callable[..., torch.Tensor]: + + if cutlass_fp8_supported: + return cutlass_w8a8_scaled_mm + if per_tensor_weights and per_tensor_activations: + if current_platform.is_rocm(): + return rocm_per_tensor_w8a8_scaled_mm + return torch_per_tensor_w8a8_scaled_mm + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + if (use_per_token_if_dynamic and not per_tensor_weights + and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + return torch_per_token_w8a8_scaled_mm + return torch_channelwise_w8a8_scaled_mm + + # TODO(luka): follow similar pattern for marlin and block-fp8-linear # https://github.com/vllm-project/vllm/issues/14397 class Fp8LinearOp: @@ -157,7 +313,8 @@ class Fp8LinearOp: if pad_output is None: config = get_current_vllm_config().compilation_config pad_output = config.level < CompilationLevel.PIECEWISE - self.output_padding = 17 if pad_output else None + self.output_padding = 17 if ( + pad_output and not current_platform.is_rocm()) else None def apply( self, @@ -196,18 +353,6 @@ class Fp8LinearOp: input_scale, 
scale_ub=input_scale_ub, use_per_token_if_dynamic=use_per_token_if_dynamic) - - # Fused GEMM_DQ - output = ops.cutlass_scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - return output.view(*output_shape) - - # torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token else: if input.dtype != current_platform.fp8_dtype(): # Maybe apply padding to output, see comment in __init__ @@ -219,84 +364,21 @@ class Fp8LinearOp: else: qinput, x_scale = input_2d, input_scale - per_tensor_weights = (weight_scale.numel() == 1) - per_tensor_activations = (x_scale.numel() == 1) - - if per_tensor_weights and per_tensor_activations: - # Fused GEMM_DQ - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - - return torch.narrow(output, 0, 0, - input_2d.shape[0]).view(*output_shape) - - elif (use_per_token_if_dynamic and not per_tensor_weights - and not per_tensor_activations - and USE_ROWWISE_TORCH_SCALED_MM): - # For now validated on ROCm platform - # fp8 rowwise scaling in torch._scaled_mm is introduced in - # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt - # and ROCm 6.3, which only exists in torch 2.7 and above. 
- # For CUDA platform please validate if the - # torch._scaled_mm support rowwise scaled GEMM - # Fused GEMM_DQ Rowwise GEMM - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale.t(), - bias=bias) - - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - output = output.view(*output_shape) - return output - - else: - # Fallback for channelwise case, where we use unfused DQ - # due to limitations with scaled_mm - - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # GEMM - # This computes C = (X * W). - # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm(qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) - - # DQ - # C = sw * sx * (X * W) + bias - output = output * x_scale * weight_scale.t() - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + per_tensor_weights = (weight_scale.numel() == 1) + per_tensor_activations = (x_scale.numel() == 1) + + w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( + self.cutlass_fp8_supported, per_tensor_weights, + per_tensor_activations, use_per_token_if_dynamic) + + return w8a8_scaled_mm_func(qinput=qinput, + 
weight=weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + input_2d=input_2d, + output_shape=output_shape) def apply_int8_linear( diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 624ed63ab8b4bb5f41463acd3210c08d4db29f93..c5970c71c539f4a4ec2eccfc9024a46d6bb04270 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: return x.flatten(-2) -def _apply_rotary_emb( +def _apply_rotary_emb_torch( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool, ) -> torch.Tensor: - """ - Args: - x: [num_tokens, num_heads, head_size] - cos: [num_tokens, head_size // 2] - sin: [num_tokens, head_size // 2] - is_neox_style: Whether to use the Neox-style or GPT-J-style rotary - positional embeddings. - """ cos = cos.unsqueeze(-2).to(x.dtype) sin = sin.unsqueeze(-2).to(x.dtype) if is_neox_style: @@ -75,6 +67,24 @@ def _apply_rotary_emb( return torch.stack((o1, o2), dim=-1).flatten(-2) +def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, + is_neox_style: bool) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + if current_platform.is_cuda_alike(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + return apply_rotary_emb(x.unsqueeze(0), cos, sin, + not is_neox_style).squeeze(0) + else: + return _apply_rotary_emb_torch(x, cos, sin, is_neox_style) + + @CustomOp.register("rotary_embedding") class RotaryEmbedding(CustomOp): """Original rotary positional embedding.""" @@ -141,14 +151,16 @@ class RotaryEmbedding(CustomOp): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, + self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., :self.rotary_dim] key_pass = key[..., self.rotary_dim:] - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, + self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -988,8 +1000,9 @@ class MRotaryEmbedding(RotaryEmbedding): key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key - @staticmethod + @classmethod def get_input_positions( + cls, input_tokens: List[int], hf_config: PretrainedConfig, image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], @@ -997,6 +1010,8 @@ class MRotaryEmbedding(RotaryEmbedding): second_per_grid_ts: Optional[List[float]], context_len: int = 0, seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" @@ -1006,7 +1021,7 @@ class MRotaryEmbedding(RotaryEmbedding): second_per_grid_ts llm_positions, mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( + 
cls.get_input_positions_tensor( input_tokens=input_tokens, hf_config=hf_config, image_grid_thw=image_grid_thw, @@ -1014,12 +1029,52 @@ class MRotaryEmbedding(RotaryEmbedding): second_per_grid_ts=second_per_grid_ts, context_len=context_len, seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) return llm_positions.tolist(), mrope_position_delta - @staticmethod + @classmethod def get_input_positions_tensor( + cls, + input_tokens: List[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + second_per_grid_ts: List[float], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> Tuple[torch.Tensor, int]: + from vllm.transformers_utils.config import thinker_uses_mrope + if thinker_uses_mrope(hf_config): + return cls._omni_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + else: + return cls._vl_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + ) + + @classmethod + def _vl_get_input_positions_tensor( + cls, input_tokens: List[int], hf_config: PretrainedConfig, image_grid_thw: Union[List[List[int]], torch.Tensor], @@ -1037,11 +1092,6 @@ class MRotaryEmbedding(RotaryEmbedding): tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - if isinstance(video_grid_thw, 
torch.Tensor): - video_grid_thw = video_grid_thw.tolist() - input_tokens_tensor = torch.tensor(input_tokens) vision_start_indices = torch.argwhere( input_tokens_tensor == vision_start_token_id).squeeze(1) @@ -1121,6 +1171,224 @@ class MRotaryEmbedding(RotaryEmbedding): return llm_positions, mrope_position_delta + @classmethod + def _omni_get_input_positions_tensor( + cls, + input_tokens: List[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + second_per_grid_ts: Optional[List[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> Tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. 
+ + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [ + audio_token_id, video_token_id, image_token_id + ]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and \ + src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and \ + src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], + dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num 
= (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 
+ 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: List[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len( + t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * + vision_ntoken_per_chunk) + vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_chunk, + grid_hs, grid_ws).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len) * [audio_token_id]) + audio_start_idx = start_idx if len( + audio_llm_pos_ids_list + ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + if min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = (torch.arange( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len)).expand(3, -1) + + audio_start_idx).split(1, + dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend( + (pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand( + 3, -1) + llm_pos_ids_list[-1].max() + 1).split( + 1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, + dim=1).max() + 1 - len(src_item) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @staticmethod + def _get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: List[int], + grid_hs: torch.Tensor, + grid_ws: 
torch.Tensor, + ) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand( + len(t_index), -1, llm_grid_w).flatten()) + w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( + len(t_index), llm_grid_h, -1).flatten()) + t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view( + -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + @staticmethod + def _split_list_into_ranges(lst: torch.Tensor, + interval: int) -> List[List[int]]: + ranges: List[List[int]] = [[] + for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges + @staticmethod def get_next_input_positions( mrope_position_delta: int, @@ -1144,6 +1412,58 @@ class MRotaryEmbedding(RotaryEmbedding): mrope_position_delta + seq_len, ).expand(3, -1) + @classmethod + def omni_get_updates_use_audio_in_video( + cls, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[List[int], torch.Tensor], + video_second_per_grid_t: float, + ) -> List[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... 
chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * video_second_per_grid_t * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( + spatial_merge_size**2) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, + audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index a9ef973917e19fdf0b43813e0abb62e09171f464..cc4a69a26f274f644cd22a9cab544db73277f9bf 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -1,9 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 """Utility methods for model layers.""" -from typing import Tuple +from typing import Callable, Optional, Tuple import torch +from vllm import _custom_ops as ops +from vllm import envs +from vllm.platforms import current_platform + def 
def rocm_unquantized_gemm(x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None):
    """Compute ``linear(x, weight, bias)``, using ROCm custom kernels when eligible.

    The custom ("skinny") kernels are only considered when all of the
    following hold: ``envs.VLLM_ROCM_USE_SKINNY_GEMM`` is set,
    ``on_mi250_mi300()`` is true, ``x`` is fp16 or bf16, the weight's second
    dimension is a multiple of 8, and no bias is given. In every other case
    the call falls back to ``torch.nn.functional.linear``.

    Args:
        x: Input activations; flattened to 2-D over the last dimension
            before being handed to a custom kernel.
        weight: Weight matrix, indexed as ``(m, k)`` below.
        bias: Optional bias. Supplying one disables the custom kernels.

    Returns:
        The result tensor with shape ``(*x.shape[:-1], weight.shape[0])``.
    """
    from vllm.platforms.rocm import on_mi250_mi300

    k = weight.shape[1]
    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300()
                  and x.dtype in (torch.float16, torch.bfloat16)
                  and k % 8 == 0 and bias is None)
    # Plain truthiness test: the chained ``and`` may evaluate to a falsy
    # non-bool value, so an identity comparison against True is fragile.
    if not use_skinny:
        return torch.nn.functional.linear(x, weight, bias)

    x_view = x.view(-1, x.size(-1))
    n = x_view.shape[0]
    m = weight.shape[0]

    if m > 8 and 0 < n < 4:
        # Only this kernel needs the compute-unit count, so query it here.
        cu_count = current_platform.get_cu_count()
        out = ops.wvSplitK(weight, x_view, cu_count)
        return out.view(*x.shape[:-1], m)
    if m % 4 == 0 and n == 1 and k <= 8192:
        # NOTE(review): the literal 4 is passed through unchanged; its
        # meaning is defined by ops.LLMM1 - confirm against that op.
        out = ops.LLMM1(weight, x_view, 4)
        return out.view(*x.shape[:-1], m)
    return torch.nn.functional.linear(x, weight, bias)
torch.Tensor]: + # if current_platform.is_rocm(): + # return rocm_unquantized_gemm + return torch.nn.functional.linear diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 69e827ebf4ea0092bf2da42f3ee8a91cab0f70c7..64283b66679e32181fa5e54e4c61328824ed540a 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -13,6 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.utils import set_weight_attrs @@ -55,8 +56,8 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): else: return torch.matmul(x, layer.weight) else: - return F.linear(x, layer.weight, bias) - + return dispatch_unquantized_gemm()(x, layer.weight, bias) + def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: return F.embedding(input_, layer.weight) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 3d9ebd9919b1ef65a31a0e67279e44d6386e8463..8a5e9346fce4be673a5802b1844f56d49e75bbd0 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -613,8 +613,12 @@ class ShardedStateLoader(BaseModelLoader): DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" - def __init__(self, load_config: LoadConfig): + def __init__(self, + load_config: LoadConfig, + runai_model_streamer: bool = False): super().__init__(load_config) + + self.runai_model_streamer = runai_model_streamer extra_config = ({} if load_config.model_loader_extra_config is None else 
load_config.model_loader_extra_config.copy()) self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) @@ -661,7 +665,7 @@ class ShardedStateLoader(BaseModelLoader): def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]): - if os.path.isdir(model_name_or_path): + if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): return model_name_or_path else: allow_patterns = ["*.safetensors"] @@ -680,12 +684,13 @@ class ShardedStateLoader(BaseModelLoader): device_config = vllm_config.device_config model_config = vllm_config.model_config target_device = torch.device(device_config.device) - from safetensors.torch import safe_open from vllm.distributed import get_tensor_model_parallel_rank - local_model_path = self._prepare_weights(model_config.model, - model_config.revision) + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + local_model_path = model_weights with set_default_torch_dtype(model_config.dtype): with target_device: @@ -697,40 +702,56 @@ class ShardedStateLoader(BaseModelLoader): local_model_path, self.pattern.format(rank=rank, part="*"), ) - filepaths = glob.glob(pattern) + + filepaths = [] + if is_s3(local_model_path): + file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" + filepaths = s3_glob(path=local_model_path, + allow_pattern=[file_pattern]) + else: + filepaths = glob.glob(pattern) if not filepaths: # TODO: support un-sharded checkpoints too raise ValueError( f"Could not find checkpoint files '{pattern}', only " f"pre-sharded checkpoints are currently supported!") state_dict = self._filter_subtensors(model.state_dict()) - for path in filepaths: - with safe_open(path, framework="pt") as f: - for key in f.keys(): # noqa: SIM118 - tensor = f.get_tensor(key) - # If loading with LoRA enabled, additional padding may - # be added to certain parameters. We only load into a - # narrowed view of the parameter data. 
def iterate_over_files(
        self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(key, tensor)`` pairs for every tensor stored in ``paths``.

    When ``self.runai_model_streamer`` is set, iteration is delegated to
    ``runai_safetensors_weights_iterator``. Otherwise each path is opened
    with ``safetensors`` and its tensors are read one key at a time.
    """
    if self.runai_model_streamer:
        # NOTE(review): the second positional argument is forwarded as
        # True; its meaning is defined by the iterator - confirm there.
        yield from runai_safetensors_weights_iterator(paths, True)
        return

    from safetensors.torch import safe_open

    for checkpoint_path in paths:
        with safe_open(checkpoint_path, framework="pt") as shard:
            for name in shard.keys():  # noqa: SIM118
                yield name, shard.get_tensor(name)
b/vllm/model_executor/model_loader/utils.py index 6a781b5e5ed24b611ca21b8765320f001cb60c44..b9c44cdb3097b0550d36b237a56bd2b6b570a14c 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -31,15 +31,6 @@ def set_default_torch_dtype(dtype: torch.dtype): torch.set_default_dtype(old_dtype) -def is_transformers_impl_compatible( - arch: str, - module: Optional["transformers.PreTrainedModel"] = None) -> bool: - mod = module or getattr(transformers, arch, None) - if mod is None: - return False - return mod.is_backend_compatible() - - def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]): for i, arch in enumerate(architectures): @@ -56,20 +47,32 @@ def resolve_transformers_arch(model_config: ModelConfig, # "AutoModelFor": "--", # }, auto_modules = { - name: get_class_from_dynamic_module(module, model_config.model) + name: + get_class_from_dynamic_module(module, + model_config.model, + revision=model_config.revision) for name, module in sorted(auto_map.items(), key=lambda x: x[0]) } - custom_model_module = auto_modules.get("AutoModel") + model_module = getattr(transformers, arch, None) + if model_module is None: + if "AutoModel" not in auto_map: + raise ValueError( + f"Cannot find model module. '{arch}' is not a registered " + "model in the Transformers library (only relevant if the " + "model is meant to be in Transformers) and 'AutoModel' is " + "not present in the model config's 'auto_map' (relevant " + "if the model is custom).") + model_module = auto_modules["AutoModel"] # TODO(Isotr0py): Further clean up these raises. # perhaps handled them in _ModelRegistry._raise_for_unsupported? 
if model_config.model_impl == ModelImpl.TRANSFORMERS: - if not is_transformers_impl_compatible(arch, custom_model_module): + if not model_module.is_backend_compatible(): raise ValueError( f"The Transformers implementation of {arch} is not " "compatible with vLLM.") architectures[i] = "TransformersForCausalLM" if model_config.model_impl == ModelImpl.AUTO: - if not is_transformers_impl_compatible(arch, custom_model_module): + if not model_module.is_backend_compatible(): raise ValueError( f"{arch} has no vLLM implementation and the Transformers " "implementation is not compatible with vLLM. Try setting " @@ -132,10 +135,10 @@ def get_model_architecture( architectures = ["QuantMixtralForCausalLM"] vllm_supported_archs = ModelRegistry.get_supported_archs() - is_vllm_supported = any(arch in vllm_supported_archs - for arch in architectures) - if (not is_vllm_supported - or model_config.model_impl == ModelImpl.TRANSFORMERS): + vllm_not_supported = not any(arch in vllm_supported_archs + for arch in architectures) + if (model_config.model_impl == ModelImpl.TRANSFORMERS or + model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): architectures = resolve_transformers_arch(model_config, architectures) model_cls, arch = ModelRegistry.resolve_model_cls(architectures) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 065715cbde4eeeb3093a6f91ec2219555e5cfc1b..dfe8f20c70d6249c5291b45e5ea05bca7494936b 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig, DeepSpeedFPParameter) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) 
from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -435,7 +434,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): self.unpadded_vocab_size = config.vocab_size self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -462,14 +460,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index edf67c860e9770ba77e1e925f4935398371d8dde..7c716efab8ef1aeb1210fc813a546ccda1587915 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -15,11 +15,10 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import (SamplerOutput, - SamplingMetadata, get_sampler) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs) @@ -527,7 +526,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor 
= LogitsProcessor(self.unpadded_vocab_size, self.vocab_size, logit_scale) - self.sampler = get_sampler() def _validate_image_sizes( self, images: List[torch.Tensor]) -> List[torch.Tensor]: @@ -653,14 +651,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 8700f24d2bd256cd41689a3fecb1e199af7cb512..d152287e8fa397ceee19548824cf1da055bb0deb 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 Adapted from # https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision -from functools import cached_property from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple, TypedDict, Union, cast) @@ -17,7 +16,6 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( from vllm.config import VllmConfig from vllm.jsontree import json_map_leaves -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs @@ -461,17 +459,3 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states, sampling_metadata) - - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return 
self.language_model.sampler - - return get_sampler() - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index e9fa372bd1ebf748af745aff02afc3deeacbcaac..61dcd938655ed831d28f7ce2fc3ffa7ab55058b1 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -43,7 +43,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -504,7 +503,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -532,14 +530,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 
dfb8f49cc0145ce5cb99e5033961376b66c900bc..16dac6123d663e86fb28f0133d7e1e2e4363caf7 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import ( MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -462,7 +461,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -538,14 +536,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 04d6cde555e28d36401959259be0b3e8c2dbb10c..bcfbe92c3a11e0cfcfea9fb978f593efd97e03aa 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from 
vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -791,7 +790,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() def forward( self, @@ -828,14 +826,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - stacked_params_mapping = { "q_proj": { "param_name": "qkv_proj", diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index e1d77646f47e8381c8c98f14b1efe97477230440..76a529c93343fe5c73674c25fed3760154523023 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -11,8 +11,10 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.activation import (get_act_and_mul_fn, + get_act_fn) from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, @@ -108,6 +110,7 @@ class BertEncoder(nn.Module): def __init__(self, vllm_config: VllmConfig, + bias: bool = True, rotary_kwargs: Optional[dict] = None, prefix: str = ""): super().__init__() @@ -118,6 +121,7 @@ class 
BertEncoder(nn.Module): BertLayer(config=config, cache_config=cache_config, quant_config=quant_config, + bias=bias, rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.layer.{layer_idx}") for layer_idx in range(config.num_hidden_layers) @@ -139,6 +143,7 @@ class BertLayer(nn.Module): config: BertConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + bias: bool = True, rotary_kwargs: Optional[dict] = None, prefix: str = ""): super().__init__() @@ -149,19 +154,31 @@ class BertLayer(nn.Module): layer_norm_eps=config.layer_norm_eps, cache_config=cache_config, quant_config=quant_config, + bias=bias, rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.attention") - self.intermediate = BertIntermediate( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.intermediate") + if config.hidden_act in ["silu", "gelu_and_mul"]: + self.intermediate = BertGatedIntermediate( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + else: + self.intermediate = BertIntermediate( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") self.output = BertOutput(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, layer_norm_eps=config.layer_norm_eps, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.output") @@ -181,6 +198,7 @@ class BertAttention(nn.Module): layer_norm_eps: float, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + bias: bool = True, rotary_kwargs: Optional[dict] = None, prefix: str = "", ): @@ -190,11 +208,13 @@ class BertAttention(nn.Module): num_attention_heads=num_attention_heads, 
class BertGatedIntermediate(nn.Module):
    """Gated variant of the BERT intermediate layer.

    Used for NomicBert and GteModel. The hidden states go through a single
    merged column-parallel projection with two outputs of size
    ``intermediate_size`` each, and the result is passed to the activation
    returned by ``get_act_and_mul_fn(hidden_act)``.
    """

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 bias: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.act_fn = get_act_and_mul_fn(hidden_act)
        # Gate and up projections share one merged linear layer.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` and apply the gated activation."""
        # The linear layer returns a pair; only the first item is used.
        projected, _ = self.gate_up_proj(hidden_states)
        return self.act_fn(projected)
+ """ + config = vllm_config.model_config.hf_config self.embeddings = embedding_class(config) self.encoder = BertEncoder(vllm_config=vllm_config, + bias=bias, rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.encoder") self.pooler = BertPooler(config) if add_pooling_layer else None @@ -387,6 +450,8 @@ class BertModel(nn.Module, SupportsQuant): ("qkv_proj", "query", "q"), ("qkv_proj", "key", "k"), ("qkv_proj", "value", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) @@ -546,3 +611,115 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors, token_type_ids=token_type_ids) + + +class NomicBertEmbeddingModel(BertEmbeddingModel): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "emb_ln": "embeddings.LayerNorm", + "layers": "layer", + "attn.Wqkv": "attention.self.qkv_proj", + "attn.out_proj": "attention.output.dense", + 'norm1': "attention.output.LayerNorm", + 'mlp.fc11': "intermediate.up_proj", + 'mlp.fc12': "intermediate.gate_proj", + 'mlp.fc2': "output.dense", + 'norm2': "output.LayerNorm", + }) + + def _build_model(self, + vllm_config: VllmConfig, + prefix: str = "") -> BertModel: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "NomicBertConfig" + assert config.activation_function == "swiglu" + + # Assume NomicBertModel all linear layers do not have bias + assert not config.mlp_fc1_bias + assert not config.mlp_fc2_bias + assert not config.qkv_proj_bias + + config.layer_norm_eps = config.layer_norm_epsilon + config.position_embedding_type = "rotary" + config.intermediate_size = config.n_inner + config.hidden_act = "silu" + config.hidden_size = config.n_embd + config.num_hidden_layers = config.n_layer + + head_dim = config.hidden_size // config.num_attention_heads + rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", 
head_dim), + "max_position": config.max_trained_positions, + "base": config.rotary_emb_base, + "rope_scaling": { + "rope_type": "dynamic", + "factor": config.rotary_scaling_factor + } + } + + return BertModel(vllm_config=vllm_config, + prefix=prefix, + bias=False, + rotary_kwargs=rotary_kwargs, + embedding_class=BertEmbedding) + + +class GteEmbeddingModel(BertEmbeddingModel): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "attention.qkv_proj": "attention.self.qkv_proj", + "attention.o_proj": "attention.output.dense", + 'attn_ln': "attention.output.LayerNorm", + 'mlp.down_proj': "output.dense", + 'mlp_ln': "output.LayerNorm", + }) + + def _build_model(self, + vllm_config: VllmConfig, + prefix: str = "") -> BertModel: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "GteConfig" + assert config.position_embedding_type == "rope" + assert config.hidden_act == "gelu" + + config.position_embedding_type = "rotary" + config.hidden_act = "gelu_and_mul" + + head_dim = config.hidden_size // config.num_attention_heads + rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": config.rope_theta, + } + + model = BertModel(vllm_config=vllm_config, + prefix=prefix, + rotary_kwargs=rotary_kwargs, + embedding_class=BertEmbedding) + + # GteModel only gate_up_proj does not have bias. 
+ # Hack method learned from vllm/model_executor/models/glm.py + for layer in model.encoder.layer: + layer.intermediate.gate_up_proj.bias = None + layer.intermediate.skip_bias_add = True + return model + + def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]): + n = "mlp.up_gate_proj" + for name, weight in weights: + if n in name: + up, gate = weight.chunk(2, dim=0) + yield name.replace(n, "intermediate.up_proj"), up + yield name.replace(n, "intermediate.gate_proj"), gate + else: + yield name, weight + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + weights = self.hf_to_vllm_mapper.apply(weights) + weights = self.split_up_gate_proj(weights) + self.model.load_weights(weights) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index a6f00f9997730c797b13228abac922a3885adca0..eed49e74ac9f2f6163b8a807308bec217213f3bf 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -12,7 +11,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, from vllm.config import CacheConfig, VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -62,6 +60,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module): quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], is_cross_attention: bool = False, + prefix: str = "", ) -> None: super().__init__() @@ 
-141,7 +140,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module): class Blip2QFormerSelfOutput(nn.Module): - def __init__(self, config: Blip2QFormerConfig) -> None: + def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -169,6 +168,7 @@ class Blip2QFormerAttention(nn.Module): quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], is_cross_attention: bool = False, + prefix: str = "", ) -> None: super().__init__() @@ -177,9 +177,10 @@ class Blip2QFormerAttention(nn.Module): quant_config=quant_config, cache_config=cache_config, is_cross_attention=is_cross_attention, + prefix=f"{prefix}.attention", ) - self.output = Blip2QFormerSelfOutput(config) + self.output = Blip2QFormerSelfOutput(config, prefix=f"{prefix}.output") def forward( self, @@ -197,7 +198,7 @@ class Blip2QFormerAttention(nn.Module): class Blip2QFormerIntermediate(nn.Module): - def __init__(self, config: Blip2QFormerConfig) -> None: + def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -211,7 +212,7 @@ class Blip2QFormerIntermediate(nn.Module): class Blip2QFormerOutput(nn.Module): - def __init__(self, config: Blip2QFormerConfig) -> None: + def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -239,6 +240,7 @@ class Blip2QFormerLayer(nn.Module): quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], layer_idx: int, + prefix: str = "", ) -> None: super().__init__() @@ -246,7 +248,8 @@ class Blip2QFormerLayer(nn.Module): self.seq_len_dim = 1 self.attention = Blip2QFormerAttention(config, quant_config=quant_config, - cache_config=cache_config) + cache_config=cache_config, + prefix=f"{prefix}.attention") self.layer_idx = layer_idx @@ 
-255,13 +258,16 @@ class Blip2QFormerLayer(nn.Module): config, quant_config=quant_config, cache_config=cache_config, - is_cross_attention=True) + is_cross_attention=True, + prefix=f"{prefix}.crossattention") self.has_cross_attention = True else: self.has_cross_attention = False - self.intermediate_query = Blip2QFormerIntermediate(config) - self.output_query = Blip2QFormerOutput(config) + self.intermediate_query = Blip2QFormerIntermediate( + config, prefix=f"{prefix}.intermediate_query") + self.output_query = Blip2QFormerOutput(config, + prefix=f"{prefix}.output_query") def forward( self, @@ -327,6 +333,7 @@ class Blip2QFormerEncoder(nn.Module): *, quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], + prefix: str = "", ) -> None: super().__init__() @@ -336,7 +343,8 @@ class Blip2QFormerEncoder(nn.Module): Blip2QFormerLayer(config, quant_config=quant_config, cache_config=cache_config, - layer_idx=layer_idx) + layer_idx=layer_idx, + prefix=f"{prefix}.layer.{layer_idx}") for layer_idx in range(config.num_hidden_layers) ]) @@ -367,6 +375,7 @@ class Blip2QFormerModel(nn.Module): *, quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], + prefix: str = "", ) -> None: super().__init__() @@ -378,7 +387,8 @@ class Blip2QFormerModel(nn.Module): self.encoder = Blip2QFormerEncoder(config, quant_config=quant_config, - cache_config=cache_config) + cache_config=cache_config, + prefix=f"{prefix}.encoder") def forward( self, @@ -513,7 +523,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, self.qformer = Blip2QFormerModel(config.qformer_config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.qformer") self.language_projection = nn.Linear( config.qformer_config.hidden_size, @@ -530,13 +541,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, self.make_empty_intermediate_tensors = ( 
self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -649,7 +653,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ) -> Union[SamplerOutput, IntermediateTensors]: + ) -> IntermediateTensors: """Run forward pass for BLIP-2. One key thing to understand is the `input_ids` already accounts for the @@ -707,13 +711,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 73ddef125061bc4523f2288ad12704515cc77a31..124e193738a9996f00dea13391963d22e886c37d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader 
@@ -308,8 +307,6 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) - - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -345,14 +342,6 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 0ad5e89df2e256dbf130fd0bd661b9a506054e4a..e2c275300f8c1955b6ece4f68c9766ab2d97f3d5 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -22,7 +22,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -950,7 +949,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -1054,14 +1052,6 @@ class 
ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 20d4aae394b253743554c7233bbec24adb3fca4f..209b2e97cc9ef67bc9d3c0f64c5dbd554a28a27f 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -494,8 +493,6 @@ class ChatGLMBaseModel(nn.Module): self.transformer.embedding.weight) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) - - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -508,14 +505,6 @@ class ChatGLMBaseModel(nn.Module): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) 
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index bb8d9bf8a03c500ebcc88195e0c3796aa7c70974..25b1d5a1955f5ec40f623f221230eac6005c1247 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -89,6 +88,7 @@ class CohereMLP(nn.Module): self, config: CohereConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -99,12 +99,14 @@ class CohereMLP(nn.Module): [self.intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) self.act_fn = SiluAndMul() @@ -158,12 +160,14 @@ class CohereAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( self.head_dim, @@ -244,7 +248,9 @@ class CohereDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.self_attn") - self.mlp = CohereMLP(config, quant_config=quant_config) + self.mlp = CohereMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") self.input_layernorm = 
LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) @@ -365,7 +371,6 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): scale=config.logit_scale) self.model = CohereModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -399,14 +404,6 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index b66529860bc23ca80443bd0061bd6fbb07afd803..40c0a73f52d50c5141948a3823d7d0de5ba14111 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -16,7 +16,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -390,7 +389,6 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -417,14 +415,6 @@ class DbrxForCausalLM(nn.Module, SupportsPP): sampling_metadata) 
return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: expert_params_mapping = [( diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 5e036d049a8a58f79839589da0daabe901d71590..c6421143dd6855785c20fd9b55fe83b3d421bd33 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -43,7 +43,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -453,7 +452,6 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -480,14 +478,6 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/deepseek_mtp.py 
b/vllm/model_executor/models/deepseek_mtp.py index 6f0e7cd31fa51733a8c79e05496376fb245782a7..cad9146c3bba194f36bffee6948dc9610528073d 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -165,10 +164,9 @@ class DeepSeekMTP(nn.Module): self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "model")) - - self.sampler = get_sampler() self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + def forward( self, input_ids: torch.Tensor, @@ -192,14 +190,6 @@ class DeepSeekMTP(nn.Module): return self.model.compute_logits(hidden_states, sampling_metadata, spec_step_idx) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 40105a9bd2368eef5fb699d95cf02c3b104f82ec..0361f6dff119d3abb23b43323fae5032f251bb37 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization 
import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -712,7 +711,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' @@ -741,14 +739,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device) -> IntermediateTensors: diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index c3dbadb292769ccdd3e26e425370cde8eb610671..ac136698ee174ad546ac9f01920d909997b74087 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -4,7 +4,6 @@ """Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" import math from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -16,7 +15,6 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from 
vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -393,13 +391,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): model = model.to(dtype=torch.get_default_dtype()) return model - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -647,13 +638,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 3e4a5040b7c895aade6875163b5d28ace8c28a3d..4ff1e785494f7a1726575bd6ef73a9b8ee1c4c3e 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -9,7 +9,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -131,10 +130,6 @@ class EAGLE(nn.Module): # checkpoint file has token_map tensor. 
self.token_map = None - @property - def sampler(self): - return self.model.sampler - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.model.get_input_embeddings(input_ids) @@ -188,14 +183,6 @@ class EAGLE(nn.Module): return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B # due to missing lm_head weights and its config being that of a diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 553c524ebc370ff9879044b273a09eaf3d471a07..4a6490cd127a59c2a4153b8bf51677ccad3b6348 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -510,8 +509,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -538,14 +535,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> 
Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 4398e0ba9b1f2ce5e96f9507e2a89e8ca05ce9e3..77330bdc68ccb5c4ef723a8cdfb6b5db80352da4 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -527,7 +526,6 @@ class FalconForCausalLM(nn.Module, SupportsPP): quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -554,14 +552,6 @@ class FalconForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 359cc7f377310bafa024b1ed7941f88327fa87c2..d1a36c3f481a19b08f7f6db3996a2e15c11f3995 100644 --- a/vllm/model_executor/models/florence2.py +++ 
b/vllm/model_executor/models/florence2.py @@ -3,7 +3,6 @@ import math from collections import OrderedDict from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -14,7 +13,6 @@ from transformers import BartTokenizer, BatchFeature, PretrainedConfig from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, BartParallelLMHead, @@ -673,7 +671,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): self.logits_processor = LogitsProcessor(self.vocab_size, config.vocab_size) - self.sampler = get_sampler() def forward( self, @@ -716,11 +713,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): sampling_metadata) return logits - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> SamplerOutput: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ @@ -929,12 +921,6 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, raise NotImplementedError( 'Florence2 only supports COSINE as temporal embedding.') - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - return get_sampler() - def _validate_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -1110,13 +1096,6 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( 
- self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 27cd8d0986a555bbd731acd992e64fdb823c11b1..d6bd6155a447e64b340bc0f5a350b27e6b289111 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -27,7 +27,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ColumnParallelLinear -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -270,10 +269,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @property - def sampler(self): - return self.language_model.sampler - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.patch_size @@ -387,14 +382,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): self.language_model.lm_head, hidden_states, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.language_model.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 92d99883c7743a264327588d4dee088b0ebae3e5..c1cc0df11178d5e1abaed8068480f0c2932523a1 100644 
--- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -388,7 +387,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.model = GemmaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -415,14 +413,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index d125c666f3cd1301edb0121f1bc11554cf8b2078..7fb2e9948c068834739ade7ce3d3fcd59359ec26 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from 
vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -146,8 +145,8 @@ class Gemma2Attention(nn.Module): # reference: # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa layer_idx = extract_layer_index(prefix) - use_sliding_window = (layer_idx % 2 == 0 and - config.interleaved_sliding_window is not None) + use_sliding_window = (layer_idx % 2 == 0 and getattr( + config, "interleaved_sliding_window", None) is not None) sliding_window = config.interleaved_sliding_window if \ use_sliding_window else None self.attn = Attention(self.num_heads, @@ -388,7 +387,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor( config.vocab_size, soft_cap=config.final_logit_softcapping) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -415,14 +413,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index fb8eccc55078aa3d8faff7a303218e0f832d2245..4e0d4f84ca6bd5514ad4cf9e567f74e2e36f74e3 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -147,7 +146,9 @@ class Gemma3Attention(nn.Module): # TODO(woosuk): Add reference to the original HF implementation. layer_idx = extract_layer_index(prefix) - self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + self.is_sliding = (getattr( + config, "interleaved_sliding_window", None) is not None and bool( + (layer_idx + 1) % config.sliding_window_pattern)) # Initialize the rotary embedding. if self.is_sliding: # Local attention. Override the values in config.json. @@ -493,7 +494,6 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor( config.vocab_size, soft_cap=config.final_logit_softcapping) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -521,14 +521,6 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index e5a3d6762fff2a33af053567ba9c6b44ba65a85a..65c177f8c5ade6f9e675e71498df4498171896ce 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import math from collections.abc 
import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Any, Literal, Optional, Set, Tuple, TypedDict import torch from torch import nn @@ -12,7 +12,6 @@ import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -479,7 +478,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config - self.sliding_window = config.text_config.interleaved_sliding_window + self.sliding_window = getattr(config.text_config, + "interleaved_sliding_window", None) self.vision_tower = SiglipVisionModel(config.vision_config, quant_config, @@ -503,10 +503,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, def dtype(self): return next(self.parameters()).dtype - @property - def sampler(self): - return self.language_model.sampler - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -607,7 +603,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + **kwargs: object) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -685,13 +681,14 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) 
global_attn_masks.append(global_attn_mask) - # Create a local causal mask with sliding window (1024). - local_attn_mask = torch.ones_like(global_attn_mask) - local_attn_mask = torch.tril(local_attn_mask, - diagonal=-self.sliding_window) - local_attn_mask = torch.where(local_attn_mask == 0, - global_attn_mask, float("-inf")) - local_attn_masks.append(local_attn_mask) + if self.sliding_window is not None: + # Create a local causal mask with sliding window (1024). + local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, + diagonal=-self.sliding_window) + local_attn_mask = torch.where(local_attn_mask == 0, + global_attn_mask, float("-inf")) + local_attn_masks.append(local_attn_mask) kwargs["global_attn_masks"] = global_attn_masks kwargs["local_attn_masks"] = local_attn_masks return kwargs @@ -704,13 +701,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index 306caa6647fd02193ff8ecb47955dabad5579247..49d75908fa9b375ad976239dbb889547514e18a0 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from 
vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -200,8 +199,8 @@ class Glm4DecoderLayer(nn.Module): hidden_states = self.post_self_attn_layernorm(hidden_states) # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) hidden_states = self.post_mlp_layernorm(hidden_states) @@ -267,7 +266,6 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -295,14 +293,6 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 776c03f652bdccf0871115e51df81727262dc141..e3219333915e94c03799955f72258eb2ae8721ca 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -255,7 +254,6 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): self.lm_head = self.lm_head.tie_weights(self.transformer.wte) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -282,14 +280,6 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 43f3d4f6dc9ccb3c953a3cd0f922f14fa2161fcb..def6b1544d8c2d606df1912bcd5430109ece1a29 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -43,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -244,6 +243,30 @@ class 
GPTBigCodeModel(nn.Module): hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method + if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: + weight_loader(param, loaded_weight, 'q') + weight_loader(param, loaded_weight, 'k') + weight_loader(param, loaded_weight, 'v') + else: + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = {"c_attn": ["c_attn"]} @@ -278,7 +301,6 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -305,36 +327,10 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "lm_head.weight" in name: - continue - if 
".attn.bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method - if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: - weight_loader(param, loaded_weight, 'q') - weight_loader(param, loaded_weight, 'k') - weight_loader(param, loaded_weight, 'v') - else: - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."]), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 752aec0b223dd545c564ac65eda48066e671a751..3db96fb8e187cf5227b42160398bce9dcfbe9984 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -43,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -188,6 +187,7 @@ class GPTJModel(nn.Module): quant_config = vllm_config.quant_config 
self.config = config + self.quant_config = quant_config self.embed_dim = config.n_embd self.wte = VocabParallelEmbedding( config.vocab_size, @@ -228,61 +228,6 @@ class GPTJModel(nn.Module): hidden_states = self.ln_f(hidden_states) return hidden_states - -class GPTJForCausalLM(nn.Module, SupportsPP): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - assert not config.tie_word_embeddings - self.transformer = GPTJModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.n_embd, - bias=True, - quant_config=quant_config, - ) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.transformer.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata, self.lm_head.bias) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> 
Set[str]: stacked_params_mapping = [ @@ -339,3 +284,54 @@ class GPTJForCausalLM(nn.Module, SupportsPP): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +class GPTJForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + assert not config.tie_word_embeddings + self.transformer = GPTJModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.n_embd, + bias=True, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 
9c69aa846e579dfa07ca646f3c8df66bee70e898..de8ff7b04559d9a6ba239ac39b2384f49ffb88d7 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -356,7 +355,6 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.gpt_neox.make_empty_intermediate_tensors) @@ -383,14 +381,6 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 3bd6332c11ca00bb4e7c7f517bd3898522fb3f24..0696a7245c22427bfd842735cf7e7533e873e6bc 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import 
get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -441,8 +440,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -464,14 +461,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device) -> IntermediateTensors: diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..b43b59da6d1118dc8fd2a80e73ee32b1ad802741 --- /dev/null +++ b/vllm/model_executor/models/granite_speech.py @@ -0,0 +1,777 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only IBM Granite speech model.""" +import math +from typing import Iterable, Mapping, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import BatchFeature, PretrainedConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .blip2 import Blip2QFormerModel +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, embed_multimodal, + init_vllm_registered_model, maybe_prefix) + + +### Audio Input +class
GraniteSpeechAudioInputs(TypedDict): + + input_features: torch.Tensor + """Shape: `(bsz, num_features, 160)`""" + + input_features_mask: torch.Tensor + """Shape: `(bsz, num_features)`""" + + audio_embed_sizes: list[int] + """List of length `bsz`""" + + +class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": 1} + + # There is no limit to the maximum number of audio tokens that can be + # encoded as features; we pick ~5000 as a number that is probably higher + # than we would expect to encounter. The sequence of length + # get_max_audio_len() produces get_max_audio_tokens(). + def get_max_audio_tokens(self): + return 5001 + + def get_max_audio_len(self): + return 8000000 + + +### Input Processing & Multimodal utils +class GraniteSpeechMultiModalProcessor( + BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_hf_processor().audio_processor + sampling_rate = feature_extractor.melspec_kwargs["sample_rate"] + return MultiModalDataParser(target_sr=sampling_rate) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + audio_embed_sizes=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + feature_extractor = processor.audio_processor + vocab = tokenizer.get_vocab() + + # Use getattr with default to be compatible with transformers<4.48 + audio_token = getattr(processor, "audio_token", "<|audio|>") + audio_token_id = 
vocab[audio_token] + + def get_replacement(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + audio_length = audio.shape[-1] + num_projector_features = feature_extractor._get_num_audio_features( + [audio_length])[0] + return [audio_token_id] * num_projector_features + + return [ + PromptReplacement( + modality="audio", + target=[audio_token_id], + replacement=get_replacement, + ) + ] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + if audios: + # GraniteSpeechFeatureExtractor accepts "audio" + mm_data["audio"] = audios + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + if "audio" in mm_data: + # Calculate the number of audio tokens per entry in the batch; + # This is used to split the batch back out after padding. 
+ audio_token_index = self.info.get_hf_config().audio_token_index + processed_outputs["audio_embed_sizes"] = [ + torch.sum(indices == audio_token_index).item() + for indices in processed_outputs["input_ids"] + ] + + return processed_outputs + + +class GraniteSpeechDummyInputsBuilder( + BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]): + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + return { + "audio": + self._get_dummy_audios( + length=self.info.get_max_audio_len(), + num_audios=num_audios, + ) + } + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + hf_processor = self.info.get_hf_processor() + audio_token = getattr(hf_processor, "audio_token", "<|audio|>") + return audio_token * num_audios + + +### QFormer Projector +class GraniteSpeechEncoderProjector(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: CacheConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.projector_config.hidden_size + self.downsample_rate = config.downsample_rate + self.window_size = config.window_size + self.num_queries = config.window_size // config.downsample_rate + + self.query = nn.Parameter( + torch.zeros(1, self.num_queries, + config.projector_config.hidden_size)) + + # NOTE - this is implemented generically in transformers, + # but for now we create the QFormer model directly since + # all existing models use this for the projector. 
+ self.qformer = Blip2QFormerModel( + config.projector_config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.qformer", + ) + self.linear = nn.Linear(config.projector_config.hidden_size, + config.text_config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = hidden_states.size() + nblocks = math.ceil(seq_len / self.window_size) + pad = nblocks * self.window_size - seq_len + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), + "constant", 0) + hidden_states = hidden_states.view(batch_size * nblocks, + self.window_size, dim) + + last_hidden_state = self.qformer( + query_embeds=self.query.data, + encoder_hidden_states=hidden_states, + ) + + query_proj = self.linear( + last_hidden_state.view( + batch_size, + nblocks * self.window_size // self.downsample_rate, + -1, + )) + return query_proj + + +# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git +# NOTE - it would be nice to see if we can align this with other models using +# conformer in vLLM, e.g., phi4mm audio. 
+class GraniteSpeechConformerFeedForward(nn.Module): + """Feedforward module for conformer encoder blocks.""" + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.pre_norm = nn.LayerNorm(config.hidden_dim) + + self.up_proj = ColumnParallelLinear( + input_size=config.hidden_dim, + output_size=config.hidden_dim * config.feedforward_mult, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.silu = nn.SiLU() + + self.down_proj = RowParallelLinear( + input_size=config.hidden_dim * config.feedforward_mult, + output_size=config.hidden_dim, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + hidden_states, _ = self.up_proj(hidden_states) + hidden_states = self.silu(hidden_states) + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class GraniteSpeechConformerAttention(nn.Module): + """Attention for conformer blocks using Shaw's relative positional + embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155) + for more details. 
+ """ + + def __init__(self, config: PretrainedConfig, prefix: str = ""): + super().__init__() + + inner_dim = config.dim_head * config.num_heads + self.max_pos_emb = config.max_pos_emb + self.context_size = config.context_size + self.num_heads = config.num_heads + self.dim_head = config.dim_head + self.scale = self.dim_head**-0.5 + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, config.hidden_dim) + self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, + self.dim_head) + + if self.context_size <= 0 or self.context_size > self.max_pos_emb: + raise ValueError( + "Context size is either less than 0 or exceeds the max_pos_emb" + ) + + def forward(self, hidden_states: torch.Tensor, + attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + bsz, num_features, _ = hidden_states.shape + + num_blocks = math.ceil(num_features / self.context_size) + remainder = num_features % self.context_size + if remainder > 0: + # right padding to reach block size + hidden_states = torch.nn.functional.pad( + hidden_states, (0, 0, 0, self.context_size - remainder)) + + # NOTE: would be nice to try to use qkvparallellinear + # here for this block attention implementation if possible + query_states = self.to_q(hidden_states) + key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) + + query_states = query_states.reshape(bsz, num_blocks, self.context_size, + self.num_heads, + -1).transpose(2, 3) + key_states = key_states.reshape(bsz, num_blocks, self.context_size, + self.num_heads, -1).transpose(2, 3) + value_states = value_states.reshape(bsz, num_blocks, self.context_size, + self.num_heads, + -1).transpose(2, 3) + + # shaw's relative positional embedding + dist = attention_dists.to(hidden_states.device) + rel_pos_emb = self.rel_pos_emb(dist) + 
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + + list(rel_pos_emb.shape)) + pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, + dim=-1) * self.scale + + if remainder > 0: + # masked attention in the extended block + mask = torch.ones(self.context_size, + self.context_size, + dtype=bool, + device=hidden_states.device) + mask[:remainder, :remainder] = 0 + mask_value = -torch.finfo(pos_attn.dtype).max + pos_attn[:, -1, :].masked_fill_(mask, mask_value) + + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH): + out = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + attn_mask=pos_attn, + scale=self.scale) + out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) + return self.to_out(out[:, :num_features, :]) + + +class GraniteSpeechConformerDepthWiseConv1d(nn.Module): + """Wrapper for padded 1D depthwise convolution.""" + + def __init__(self, + chan_in: int, + chan_out: int, + kernel_size: int, + prefix: str = ""): + super().__init__() + # Padding for the 1D conv is symmetric or close (i.e., offset by one). + pad = kernel_size // 2 + pad_offset = (kernel_size + 1) % 2 + self.padding = (pad, pad - pad_offset) + + self.conv = nn.Conv1d(chan_in, + chan_out, + kernel_size, + groups=chan_in, + bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.padding) + return self.conv(hidden_states) + + +class GraniteSpeechConformerConvModule(nn.Module): + """Conformer conv module consisting of several 1D/depthwise 1D + convolutional layers.
+ """ + + def __init__(self, config: PretrainedConfig, prefix: str = ""): + super().__init__() + inner_dim = config.hidden_dim * config.conv_expansion_factor + + self.norm = nn.LayerNorm(config.hidden_dim) + self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1) + self.glu = nn.GLU(dim=1) + self.depth_conv = GraniteSpeechConformerDepthWiseConv1d( + inner_dim, + inner_dim, + kernel_size=config.conv_kernel_size, + prefix=f"{prefix}.depth_conv", + ) + self.silu = nn.SiLU() + self.batch_norm = nn.BatchNorm1d(inner_dim) + self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + hidden_states = self.up_conv(hidden_states.permute(0, 2, 1)) + hidden_states = self.glu(hidden_states) + hidden_states = self.depth_conv(hidden_states) + hidden_states = self.silu(self.batch_norm(hidden_states)) + hidden_states = self.down_conv(hidden_states).permute(0, 2, 1) + return hidden_states + + +class GraniteSpeechConformerBlock(nn.Module): + """Conformer block, consisting largely of linear layers, + attention, and convolutional layers.""" + + def __init__(self, config: PretrainedConfig, prefix: str = ""): + super().__init__() + self.ff1 = GraniteSpeechConformerFeedForward(config, + prefix=f"{prefix}.ff1") + self.attn = GraniteSpeechConformerAttention(config, + prefix=f"{prefix}.attn") + self.conv = GraniteSpeechConformerConvModule(config, + prefix=f"{prefix}.conv") + self.ff2 = GraniteSpeechConformerFeedForward(config, + prefix=f"{prefix}.ff2") + self.post_norm = nn.LayerNorm(config.hidden_dim) + + def forward(self, hidden_states: torch.Tensor, + attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states + hidden_states = self.attn( + hidden_states, attention_dists=attention_dists) + hidden_states + hidden_states = self.conv(hidden_states) + hidden_states + hidden_states = 0.5 * self.ff2(hidden_states) + 
hidden_states + hidden_states = self.post_norm(hidden_states) + return hidden_states + + +class GraniteSpeechCTCEncoder(nn.Module): + """CTC Encoder comprising conformer blocks and additional linear layers.""" + + def __init__(self, + config: PretrainedConfig, + prefix: str, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + + # Precompute clamped relative positional encoding distances + seq = torch.arange(config.context_size) + relpos_dist = seq.view(-1, 1) - seq.view(1, -1) + self.attention_dists = torch.clamp( + relpos_dist, -config.context_size, + config.context_size) + config.max_pos_emb + + self.input_linear = nn.Linear(config.input_dim, + config.hidden_dim, + bias=True) + self.layers = nn.ModuleList([ + GraniteSpeechConformerBlock( + config, + prefix=f"{prefix}.layers.{idx}", + ) for idx in range(config.num_layers) + ]) + + self.out = ColumnParallelLinear( + input_size=config.hidden_dim, + output_size=config.output_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out", + ) + + self.out_mid = RowParallelLinear( + input_size=config.output_dim, + output_size=config.hidden_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_mid", + ) + self.softmax = nn.Softmax(dim=-1) + self.num_layers = config.num_layers + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.input_linear(hidden_states) + for idx, layer in enumerate(self.layers, start=1): + hidden_states = layer(hidden_states, + attention_dists=self.attention_dists) + + if idx == self.num_layers // 2: + hidden_states_mid = hidden_states.clone() + hidden_states_mid, _ = self.out(hidden_states_mid) + hidden_states_mid = self.softmax(hidden_states_mid) + hidden_states_mid, _ = self.out_mid(hidden_states_mid) + hidden_states += hidden_states_mid + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + GraniteSpeechMultiModalProcessor, + info=GraniteSpeechMultiModalProcessingInfo, + 
dummy_inputs=GraniteSpeechDummyInputsBuilder) +class GraniteSpeechForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsLoRA, +): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + + self.config = config + self.quant_config = quant_config + self.cache_config = cache_config + self.sampler = get_sampler() + + # The language model is typically a Granite LLM + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + # Conformer encoder + self.encoder = GraniteSpeechCTCEncoder( + config=config.encoder_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + ) + + # Blip2 QFormer + self.projector = GraniteSpeechEncoderProjector( + config=config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.projector", + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_audio_input( + self, + **kwargs: object, + ) -> Optional[GraniteSpeechAudioInputs]: + input_features = kwargs.pop("input_features", None) + input_features_mask = kwargs.pop("input_features_mask", None) + audio_embed_sizes = kwargs.pop("audio_embed_sizes", None) + if input_features is None: + return None + + # If we have a batch of variable feature length audio clips, we need + # to mask the features; usually we would get an input_features_mask + # from the processor, but we handle rebuilding it here since + # vLLM generally processes everything independently + batches. 
+ if input_features_mask is None: + input_features_mask = self._build_input_features_mask( + audio_embed_sizes) + + if not isinstance(input_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. " + f"Got type: {type(input_features)}") + + if input_features_mask is not None and not isinstance( + input_features_mask, torch.Tensor): + raise ValueError("Incorrect type of audio input features mask. " + f"Got type: {type(input_features_mask)}") + + if isinstance(input_features, torch.Tensor): + # Granite speech currently only allows one audio token per instance + # and features are already unsqueezed in the processor, so one + # instance will have shape [1, {num_features}, 160]. As such, + # input features will usually be of shape + # [bsz, 1, num_features, 160], which we squeeze to be 3D here. + if len(input_features.shape) == 4: + input_features = input_features.squeeze(1) + if len(input_features.shape) != 3: + raise ValueError( + "Squeezed input features should be 3D but are of shape " + f"{input_features.shape}") + input_features = input_features.to( + self.encoder.input_linear.weight.dtype) + + else: + # Otherwise we have a list of tensors, which are almost certainly + # differing in their respective numbers of audio features; + # stack them into a 3D tensor of size [bsz, most_num_features, 160]. + input_features = self._pad_and_stack_input_features( + input_features, ).to(self.encoder.input_linear.weight.dtype) + + return GraniteSpeechAudioInputs( + input_features=input_features, + input_features_mask=input_features_mask, + audio_embed_sizes=audio_embed_sizes.flatten().tolist(), + ) + + def _build_input_features_mask( + self, + audio_embed_sizes: torch.Tensor, + ) -> torch.Tensor: + """Calculate the input features mask, which will generally be used + to mask the padded features for all entries in the batch except + for those with the most audio features.
+ + Args: + audio_embed_sizes: torch.Tensor + Tensor of num features in each seq in the batch. + Returns: + torch.Tensor: Mask of shape (bsz, num_features) to be applied to + the audio features prior to splitting the audio embeddings. + """ + most_audio_features = torch.max(audio_embed_sizes).item() + mask_indices = torch.arange( + most_audio_features, + device=audio_embed_sizes.device, + ).view(1, -1) + input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1) + return input_features_mask + + def _pad_and_stack_input_features( + self, + input_features: list[torch.Tensor], + ) -> torch.Tensor: + """Given a list of input features of varying length, pad them to the + same length and stack them into a torch.Tensor. + + NOTE: Usually, padding is done in the input processor/feature extractor + and zero padded prior to the computation of the Mel features; the + resulting values are only constant within a batch and generally nonzero + (i.e., slightly negative nums); we should validate that this is okay + since we don't use a feature attention mask, but the more important + thing is that we apply the input_features_mask with variable len + batches. + + Args: + input_features: list[torch.Tensor] + Input features to be coerced into a tensor. + Returns: + torch.Tensor: Tensor of shape [bsz, num_features, 160], where + num_features is the max number of features of any entry in the + batch. + """ + # Input features are of shape [bsz, num_features, 160] + feat_lens = [feats.shape[1] for feats in input_features] + padding = [max(feat_lens) - length for length in feat_lens] + # TODO (Alex) - Validate that it's okay to zero pad like this; + # in transformers we zero pad prior to calculating the speech features, + # so the value is not zero and is dependent on the batched features. 
+ padded = [ + torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0)) + for feats, pad in zip(input_features, padding) + ] + stacked_features = torch.cat(padded, dim=0).to(input_features[0]) + return stacked_features + + def _process_audio_input( + self, + audio_input: GraniteSpeechAudioInputs, + ) -> tuple[torch.Tensor]: + """Compute the audio features to be merged into the LLM embeddings. + + Args: + audio_input: GraniteSpeechAudioInputs + Audio inputs object containing Mel features, an input features + mask, and the (flattened) number of audio tokens per instance. + Returns: + tuple[torch.Tensor]: List of length bsz. + """ + # TODO (Alex) - support embedding inputs + encoder_embeds = self.encoder(audio_input["input_features"]) + # [bsz, , 4096] + projected_embeds = self.projector(encoder_embeds) + # Apply mask on variable length audio features + masked_embeds = projected_embeds[audio_input["input_features_mask"]] + # Split variable length features into a tuple + return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) + + def get_multimodal_embeddings( + self, + **kwargs: object, + ) -> Optional[MultiModalEmbeddings]: + """Compute the audio embeddings if audio inputs are present.""" + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return None + audio_features = self._process_audio_input(audio_input) + return audio_features + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + """Compute the merged LLM / audio embeddings.""" + if multimodal_embeddings is None: + return self.language_model.get_input_embeddings(input_ids) + + inputs_embeds = embed_multimodal( + input_ids, + self.config.audio_token_index, + self.language_model.model.get_input_embeddings, + multimodal_embeddings, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: 
Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + audio_embeds = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds) + input_ids = None + + model_output = self.language_model(input_ids, positions, + intermediate_tensors, inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits( + hidden_states, + sampling_metadata, + ) + + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """Get the module prefix in multimodal models.""" + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="projector", + tower_model="encoder", + ) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 367722126e569be0ca3fe7901771271b96ed9da2..7fff14cb9f120eb4f47056dd1fe8caebb94396f0 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from 
vllm.model_executor.sampling_metadata import SamplingMetadata @@ -391,8 +390,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): scale=1 / self.config.logits_scaling) - self.sampler = get_sampler() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -428,14 +425,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): device=device), }) - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index cf8c969e118fe3d2a7cb567bc5d90219f069a413..4e660cbf667b2a61672ccdc911b3809ca55b94e2 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -20,7 +20,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -295,8 +294,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): scale=1 / self.config.logits_scaling) - self.sampler = get_sampler() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -332,14 +329,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): device=device), }) 
- def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 2984f224128642c4a3e98365ac815f58b01076b7..e4692c45808878229074cc8e6f6ee30fe196e335 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -170,7 +170,8 @@ class GritLMPooler(nn.Module): mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( 1) - pooled_data = self.head(mean_embeddings) + pooled_data = self.head(mean_embeddings, + pooling_metadata=pooling_metadata) pooled_outputs = [ PoolingSequenceGroupOutput(data) for data in pooled_data diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index ef96257ba4bbc9dbc0f5904414c2eff6da232ae6..c48cb157084d4d6556b806b477fa37d563c34dbe 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -39,7 +39,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -521,7 +520,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): config.vocab_size, self.output_multiplier_scale) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -551,14 
+549,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: skip_prefixes = ["rotary_emb.inv_freq"] diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index c31870461b4c215717653eac04005deb213c9edd..961954c2b584ffc96e1290a6867550e4a8f9fa69 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -28,7 +28,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -603,7 +602,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, if self.config.text_config.tie_word_embeddings: self.lm_head.weight = self.model.text_model.wte.weight self.logits_processor = LogitsProcessor(config.text_config.vocab_size) - self.sampler = get_sampler() def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size @@ -754,14 +752,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def 
load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 22c9287509ed7897fef4129ffd81a623604af660..f141dcf3cd4fc4d29bf2d9f891df57b365203c26 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -13,7 +13,6 @@ from vllm.utils import supports_kw if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import PoolerOutput - from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -103,14 +102,6 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): """Return `None` if TP rank > 0.""" ... - def sample( - self, - logits: T, - sampling_metadata: "SamplingMetadata", - ) -> "SamplerOutput": - """Only called on TP rank 0.""" - ... 
- @overload def is_text_generation_model( diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 520b85c0cdfbc97fd6f8d299c46a30cf46ace36a..c3d7cbfcddbb9fa1abf5d1632fbf3ac89487b1f1 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -23,7 +23,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -336,7 +335,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): if self.config.tie_word_embeddings: self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -363,14 +361,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ @@ -423,7 +413,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): prefix=prefix, model_type=model_type) - for attr in ("output", "logits_processor", "sampler"): + for attr in ("output", "logits_processor"): delattr(self, attr) config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/internvl.py 
b/vllm/model_executor/models/internvl.py index 8f5f454cbf6076849636245fe133b5a356d066e3..23b92ad2bbf664c7d70fca106cd30496d5023485 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -8,7 +8,6 @@ # -------------------------------------------------------- from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union import torch @@ -20,7 +19,6 @@ from transformers import BatchEncoding, PretrainedConfig, TensorType from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -698,13 +696,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): (llm_quant_config is not None): quant_config.modules_to_not_convert.append("vision_model") - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _init_vision_model( self, config: PretrainedConfig, @@ -903,7 +894,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ) -> Union[SamplerOutput, IntermediateTensors]: + ) -> IntermediateTensors: if intermediate_tensors is not None: input_ids = None @@ -941,13 +932,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - 
) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 78fe6588eddcea84f2b83bbf64d2f8753b6ada83..e1e3f0f199c5fd399298c63541596c717172ad6c 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -308,7 +307,6 @@ class JAISLMHeadModel(nn.Module, SupportsPP): config.mup_width_scale) self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, scale=self.output_logits_scale) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -335,14 +333,6 @@ class JAISLMHeadModel(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 6fabc8228e18771c1f0e59881dce100192a196b1..46335c2b3930f4649b495ddc12a7157f752201cd 100644 --- a/vllm/model_executor/models/jamba.py +++ 
b/vllm/model_executor/models/jamba.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -409,7 +408,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -466,14 +464,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..0629266860fd3d6f84cfaf930d1e7e78da60a85a --- /dev/null +++ b/vllm/model_executor/models/kimi_vl.py @@ -0,0 +1,577 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: E501 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py +# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved. 
+# +# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL. +# +# Licensing Information: +# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0. +# - Other parts of the code are licensed under the MIT License. +# +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import copy +import math +from collections.abc import Mapping +from dataclasses import dataclass +from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple, + TypedDict, Union) + +import torch +from torch import nn +from transformers import BatchFeature +from transformers.activations import GELUActivation + +from vllm.config import VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.moonvit import MoonVitPretrainedModel +from vllm.model_executor.models.utils import merge_multimodal_embeddings +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import 
KimiVLConfig, MoonViTConfig +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config + +from .utils import is_pp_missing_parameter, maybe_prefix + + +# For dummy input only +@dataclass +class MaxImageTokenMeta: + width: int = 1024 + height: int = 1024 + + +class KimiVLMultiModalProjector(nn.Module): + + def __init__(self, config: KimiVLConfig): + super().__init__() + + self.hidden_size = (config.vision_config.hidden_size * + config.vision_config.merge_kernel_size[0] * + config.vision_config.merge_kernel_size[1]) + + self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, + eps=1e-5) + self.linear_1 = nn.Linear(self.hidden_size, + self.hidden_size, + bias=True) + self.act = GELUActivation() + self.linear_2 = nn.Linear(self.hidden_size, + config.text_config.hidden_size, + bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(image_features).view( + -1, self.hidden_size) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class KimiVLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape:`(num_patches, num_channels, patch_size, patch_size)` + """ + + image_grid_hws: torch.Tensor + """Shape:`(num_images, 2)`""" + + +# TODO: support embeds too +# We only support pixel input for kimi-vl now +KimiVLImageInputs = KimiVLImagePixelInputs + + +class KimiVLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(KimiVLConfig) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_processor = self.get_hf_processor() + patch_size = hf_processor.image_processor.patch_size + kernel_size = 
hf_processor.image_processor.merge_kernel_size + in_token_limit = hf_processor.image_processor.in_token_limit + height = image_height + width = image_width + assert isinstance(height, + int), f"height must be int, current height {height}" + assert isinstance(width, + int), f"width must be int, current width {width}" + assert kernel_size is not None, "kernel_size must be specified" + + if (width // patch_size) * (height // patch_size) > in_token_limit: + scale = math.sqrt(in_token_limit / ((width // patch_size) * + (height // patch_size))) + new_w, new_h = int(width * scale), int(height * scale) + width, height = new_w, new_h + + kernel_height, kernel_width = kernel_size + + pad_height = (kernel_height * patch_size - height % + (kernel_height * patch_size)) % (kernel_height * + patch_size) + pad_width = (kernel_width * patch_size - width % + (kernel_width * patch_size)) % (kernel_width * patch_size) + + # Calculate new dimensions after padding and patching + token_height = (height + pad_height) // (kernel_size[0] * patch_size) + token_width = (width + pad_width) // (kernel_size[1] * patch_size) + return int(token_height * token_width) + + @property + def image_token_id(self) -> int: + return self.get_hf_config().media_placeholder_token_id + + +class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=MaxImageTokenMeta.width, + height=MaxImageTokenMeta.height, + num_images=num_images) + } + + +class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: 
BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2))) + image_grid_sizes = image_grid_hws.prod(-1) + + # pixel_values is merged as a single large tensor + # image_grid_hws is shapes for each subtensor in pixel_values + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_hws=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + image_token_id = self.info.image_token_id + + def get_replacement(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] + + +@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor, + info=KimiVLProcessingInfo, + dummy_inputs=KimiVLDummyInputsBuilder) +class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + model_config = vllm_config.model_config + config: KimiVLConfig = model_config.hf_config + self.config = config + quant_config = vllm_config.quant_config + + assert isinstance(config.vision_config, MoonViTConfig) + + self.vision_tower = MoonVitPretrainedModel(config.vision_config) + + self.multi_modal_projector = KimiVLMultiModalProjector(config=config) + + self.quant_config 
= quant_config + sub_vllm_config = copy.deepcopy(vllm_config) + sub_vllm_config.model_config.hf_config = sub_vllm_config.model_config.hf_config.text_config + self.language_model = DeepseekV2Model( + vllm_config=sub_vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.config.text_config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.media_placeholder: int = self.config.media_placeholder_token_id + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_world_size = get_tensor_model_parallel_world_size() + + # ref: qwen2_vl.py + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return mm_input.reshape(-1, mm_input.shape[-1]) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[KimiVLImageInputs]: + # image input type must be pixel values now + pixel_values = kwargs.pop("pixel_values", None) + image_grid_hws = kwargs.pop("image_grid_hws", None) + + if pixel_values is None: + return None + + image_grid_hws = self._validate_and_reshape_mm_tensor( + image_grid_hws, "image grid hws") + # pixel_values may have complex shapes + num_channels = 3 + patch_size = self.config.vision_config.patch_size + if isinstance(pixel_values, list): + pixel_values = torch.cat([ + x.reshape(-1, num_channels, patch_size, patch_size) + for x in pixel_values + ]) + else: + pixel_values = pixel_values.reshape(-1, num_channels, patch_size, + patch_size) + pixel_values = pixel_values.to(self.vision_tower.dtype) + # image_grid_hws.shape = (N, 2) + assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}" + + return KimiVLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_hws=image_grid_hws, + ) + + # run the vision tower on the processed pixel_values + @torch.inference_mode() + def _process_image_pixels(self, + inputs: KimiVLImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["pixel_values"] + image_grid_hws = inputs["image_grid_hws"] + return self.vision_tower(pixel_values, image_grid_hws) + + def _process_image_input(self, + image_input: KimiVLImageInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + image_features = self._process_image_pixels(image_input) + assert isinstance(image_features, list) + lengths = [x.shape[0] for x in image_features] + return self.multi_modal_projector( + torch.cat(image_features)).split(lengths) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def
get_multimodal_embeddings(self, + **kwargs: object) -> Optional[NestedTensors]: + # Validate the multimodal input keyword arguments + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + # Run multimodal inputs through encoder and projector + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + + # `get_input_embeddings` should already be implemented for the language + # model as one of the requirements of basic vLLM model implementation. + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=self.config.media_placeholder_token_id) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + inputs_embeds = None + else: + inputs_embeds = self.get_input_embeddings(input_ids) + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config. 
+ media_placeholder_token_id, + ) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + **kwargs) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, **kwargs) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + config = self.config.text_config + _KEYS_TO_MODIFY_MAPPING = { + "language_model.lm_head": "lm_head", + "language_model.model": "language_model", + } + # only doing this for language model part for now. + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + if not config.use_mla: + stacked_params_mapping += [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + if getattr(config, "n_routed_experts", None): + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=config.n_routed_experts) + else: + expert_params_mapping = [] + + params_dict = dict(self.named_parameters()) + for args in weights: + name, loaded_weight = args[:2] + kwargs = args[2] if len(args) > 2 else {} + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ continue + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) + use_default_weight_loading = False + if "vision" in name: + if self.vision_tower is not None: + # We only do sharding for language model and + # not vision model for now. + use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id, **kwargs) + break + else: + for idx, (param_name, weight_name, expert_id, + shard_id) in enumerate(expert_params_mapping): + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + expert_id=expert_id, + shard_id=shard_id, + **kwargs) + break + else: + use_default_weight_loading = True + if use_default_weight_loading: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, **kwargs) + + +def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c8aeb38763085dd15ce08343525084d89b98f059..a30e2027a443118de33bf4b32e09662d0449bb04 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -44,7 +44,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -139,8 +138,8 @@ class LlamaAttention(nn.Module): self.head_dim = getattr(config, "head_dim", self.hidden_size // self.total_num_heads) # Phi models introduced a partial_rotary_factor parameter in the config - partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) - self.rotary_dim = int(partial_rotary_factor * self.head_dim) + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", + 1) self.q_size = self.num_heads * self.head_dim 
self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -172,11 +171,12 @@ class LlamaAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.rotary_dim, + rotary_dim=self.head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, ) if hasattr(config, "interleaved_sliding_window"): @@ -346,6 +346,8 @@ class LlamaModel(nn.Module): else: self.norm = PPMissingLayer() + self.aux_hidden_state_layers: tuple[int] = tuple() + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -372,7 +374,8 @@ class LlamaModel(nn.Module): positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, + list[torch.Tensor]]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -384,7 +387,11 @@ class LlamaModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + aux_hidden_states = [] + for idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: @@ -394,6 +401,9 @@ class LlamaModel(nn.Module): }) hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states def load_weights(self, weights: Iterable[Tuple[str, @@ -679,11 +689,16 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): else: 
self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def _init_model(self, vllm_config: VllmConfig, prefix: str = "", @@ -715,11 +730,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 51efbfe202f0b44f65c827d88bd97b15aede6846..e5d1a671f5d6f88eb1360b6d73cd32ac5dc4dd03 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -51,8 +51,8 @@ class Llama4MoE(nn.Module): renormalize: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: router_scores, router_indices = fast_topk(gating_output, topk, dim=-1) - router_scores = torch.sigmoid(router_scores.float()).to( - hidden_states.dtype) + # pseudo-standard is that the router scores are floats + router_scores = torch.sigmoid(router_scores.float()) return (router_scores, router_indices.to(torch.int32)) def __init__(self, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 28ad6128c4f19f98b9c167670ac009de862c13f4..56e53ac2b8158ee854004d632063ecf91a4bd48f 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -70,7 +70,7 @@ class LlamaModel(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: input_embeds = self.embed_tokens(input_ids) hidden_states = self.fc( torch.cat((input_embeds, hidden_states), dim=-1)) @@ -82,7 +82,8 @@ class LlamaModel(nn.Module): hidden_states, residual, ) - return hidden_states + residual + hidden_states = hidden_states + residual + return hidden_states, hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -132,7 +133,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py new file mode 100644 index 0000000000000000000000000000000000000000..0b18e4a8fe2f328123b25ca34a0bb83780e1007e --- /dev/null +++ b/vllm/model_executor/models/llama_eagle3.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Optional, Set, Tuple + +import torch +import torch.nn as nn +from transformers import LlamaConfig + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import (LlamaDecoderLayer, + LlamaForCausalLM) +from vllm.v1.sample.metadata import SamplingMetadata + 
+from .utils import AutoWeightsLoader, maybe_prefix + +logger = init_logger(__name__) + + +class LlamaDecoderLayer(LlamaDecoderLayer): + + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config=quant_config, prefix=prefix) + + # override qkv + self.self_attn.qkv_proj = QKVParallelLinear( + 2 * self.hidden_size, + self.self_attn.head_dim, + self.self_attn.total_num_heads, + self.self_attn.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "qkv_proj"), + ) + + self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + embeds: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + residual = hidden_states + embeds = self.input_layernorm(embeds) + hidden_states = self.hidden_norm(hidden_states) + + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class LlamaModel(nn.Module): + + def __init__( + self, + *, + model_config: ModelConfig, + start_layer_id: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + self.config = model_config.hf_config + self.vocab_size = self.config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + self.layers = nn.ModuleList([ + LlamaDecoderLayer( + self.config, + prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), + ) + ]) + if hasattr(self.config, "target_hidden_size"): + self.fc = 
torch.nn.Linear(self.config.target_hidden_size * 3, + self.config.hidden_size, + bias=False) + else: + self.fc = torch.nn.Linear(self.config.hidden_size * 3, + self.config.hidden_size, + bias=False) + self.norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + input_embeds = self.embed_tokens(input_ids) + if (hidden_states.shape[-1] != input_embeds.shape[-1]): + hidden_states = self.fc(hidden_states) + + residual = None + hidden_states, residual = self.layers[0]( + positions, + input_embeds, + hidden_states, + residual, + ) + + hidden_states, hidden_prenorm = self.norm(hidden_states, residual) + return hidden_states, hidden_prenorm + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if 'midlayer.' 
in name: + name = name.replace('midlayer.', 'layers.0.') + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Eagle3LlamaForCausalLM(LlamaForCausalLM): + + def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): + nn.Module.__init__(self) + self.config = model_config.hf_config + self.model = LlamaModel(model_config=model_config, + start_layer_id=start_layer_id, + prefix="model") + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.lm_head = ParallelLMHead( + self.config.draft_vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.draft_vocab_size, + padding_size=(DEFAULT_VOCAB_PADDING_SIZE), + prefix="") + self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, + scale=logit_scale) + self.draft_id_to_target_id = nn.Parameter( + torch.zeros((self.config.draft_vocab_size), + dtype=torch.long).type(torch.LongTensor), + requires_grad=False, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.model(input_ids, positions, hidden_states) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + base = torch.arange(self.config.draft_vocab_size, device=logits.device) + targets = base + self.draft_id_to_target_id + logits_new = logits.new_full(( + logits.shape[0], + self.config.vocab_size, + ), float('-inf')) + 
logits_new[:, targets] = logits + return logits_new + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=None, + ) + + model_weights = {} + for name, loaded_weight in weights: + if "t2d" in name: + continue + if "d2t" in name: + name = name.replace("d2t", "draft_id_to_target_id") + elif "lm_head" not in name: + name = "model." + name + model_weights[name] = loaded_weight + + return loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index fbd212d1700447afefcb0d9f287054a759c6653b..8862b2679f934219d0f8b61ea64751175e8d359c 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -2,7 +2,6 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, TypeVar, Union, cast) @@ -23,7 +22,6 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -546,13 +544,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, 
h, w) @@ -763,13 +754,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 9c4d0e1fc275e2751d3d36cb1967d59bbc868c7b..c646c0f03d1ebfb2152dcb7da4f0a2adddb4c8f3 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod -from functools import cached_property from typing import (Final, Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, TypeVar, Union) @@ -13,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import ( from typing_extensions import NotRequired from vllm.config import VllmConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig @@ -250,13 +248,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -585,13 +576,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, 
sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 0221c6b237cbb3712c29255ea6dc9886b5455147..a5ff189cfdb50be8577dd066acd0076622bee141 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -2,7 +2,6 @@ import math from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -12,7 +11,6 @@ from transformers import (BatchFeature, LlavaNextVideoConfig, from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -301,13 +299,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_video_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -469,13 +460,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - 
) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 60d32c924694caf26eaf436caf4556e33fcabf09..5c2b388e403dfa548ea05ca0e4f18a42f6f32304 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -2,7 +2,6 @@ import math from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, TypedDict, Union) @@ -16,7 +15,6 @@ from typing_extensions import NotRequired from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -455,13 +453,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -583,21 +574,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = {} + mm_input_by_modality = {} # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_videos", - "video_embeds") and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) - return modalities + return mm_input_by_modality def _select_image_features(self, image_features: torch.Tensor, *, strategy: str) -> torch.Tensor: @@ -848,8 +839,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if not modalities: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: return None # The result multimodal_embeddings is tuple of tensors, with each @@ -858,14 +850,13 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. 
- for modality in modalities: - if modality == "images": - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) multimodal_embeddings += tuple(vision_embeddings) - if modality == "videos": - video_input = modalities["videos"] - video_embeddings = self._process_video_pixels(video_input) + if modality == "video": + video_embeddings = self._process_video_pixels(multimodal_input) multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings @@ -957,13 +948,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 7a525ad8e494f5cfc83b96cf89d1ed8f14d3c4ca..af78ece66bbed8d910d52cde962a98f32909e280 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -27,7 +26,7 @@ from 
vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -154,6 +153,26 @@ class MambaModel(nn.Module): return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "A_log" in name: + name = name.replace("A_log", "A") + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, SupportsV0Only): @@ -193,7 +212,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors) @@ -247,30 +265,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "A_log" in name: - name = 
name.replace("A_log", "A") - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 526dec46ff29a049806fb37aad66b44fad7cec15..78303733f6bb54f22f8429c47eb012a965d4fa30 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import ( MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -208,7 +207,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors) @@ -282,14 +280,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) diff --git 
a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index cf03396a9ca9970cebc31f744812ab7638ebbffe..866dc3f466e7967e7dbb57904e5efcca491db21b 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -553,7 +552,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -584,14 +582,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 1a91cf9bab478a560fd082c7548a8ca022b1f917..65a26eadd5c81348461a1fb7eb6ec576850ab047 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -25,7 +25,7 @@ import math from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property, 
partial +from functools import partial from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, Union) @@ -40,7 +40,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, get_2d_sincos_pos_embed) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM @@ -758,13 +757,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): self.make_empty_intermediate_tensors = ( self.llm.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.llm, "sampler"): - return self.llm.sampler - - return get_sampler() - def _parse_and_validate_vision_input( self, modality: str, @@ -946,14 +938,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ) -> Optional[torch.Tensor]: return self.llm.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 7562aa678d5abd1aaf2afc4adf5c239e1884de2c..74be08159cd8fffedcedb41636357e576d43383d 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -33,7 +33,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from 
vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -994,7 +993,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.config.vocab_size) - self.sampler = Sampler() else: self.lm_head = PPMissingLayer() @@ -1030,16 +1028,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ): - - next_tokens = self.sampler(logits, sampling_metadata) - - return next_tokens - def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device) -> IntermediateTensors: diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 8b1a1d68fc3fa9b135c67e86caea1441269ce840..f8e9e318136734a82a01683ce0eef911cd3580c8 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -2,7 +2,6 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, TypeVar, Union) @@ -19,7 +18,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, 
MultiModalFieldConfig, @@ -274,6 +272,9 @@ class Mistral3MultiModalProcessor( vision_config = hf_config.vision_config assert isinstance(vision_config, PixtralVisionConfig) + # Need to sneak in spatial_merge_size for Mistral3 + vision_config.spatial_merge_size = getattr(hf_config, + "spatial_merge_size", 1) encoder_info = PixtralHFEncoderInfo(vision_config) def get_replacement(item_idx: int): @@ -435,13 +436,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -598,13 +592,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 1bc5acf9b3d2d5e7de5301cbbd3268556b38cc47..5e1543c7c285ec9de73e99e47b080a38f115a3a9 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from 
vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -489,7 +488,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -516,14 +514,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"]) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 96eb925cf894d103c74a4ad8d5e28d7e809e0a8b..7c022a5b8f689aa8d5278300cfa69bde1fe5d90a 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -372,7 +371,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight 
self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -399,14 +397,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 7bfb3ada6bb46606d9a2707b0b19b46b8cd4dae9..0c1d61c01f910d0504949bfbbb214fdb16fb854b 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -1211,7 +1210,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ) self.logits_processor = LogitsProcessor(config.output_hidden_states, config.text_config.vocab_size) - self.sampler = get_sampler() def compute_logits( self, @@ -1222,14 +1220,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, hidden_states, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def unpack_data(self, image_data: 
Union[List[torch.Tensor], torch.Tensor], padding_value=0) -> torch.Tensor: diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 817035a0ec19ec24ba91279adc73248f716eface..acfaacf743a0aad54d1118ddb6ef9b869bca1861 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -17,7 +17,6 @@ # limitations under the License. import math from collections.abc import Iterable, Mapping -from functools import cached_property from itertools import tee from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union @@ -38,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import _initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -672,9 +670,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")) - self.language_model = _initialize_model( - vllm_config=vllm_config.with_hf_config(config.text_config), + vllm_config=vllm_config.with_hf_config(config.text_config, + ["LlamaForCausalLM"]), prefix=maybe_prefix(prefix, "language_model"), model_class=Llama4ForCausalLM, ) @@ -682,13 +680,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: # 
num_images, 1, num_chunks, channel, image_size, image_size @@ -785,10 +776,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def separate_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], @@ -824,7 +811,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, # language_model is an Llama4ForCausalLM instance. We load it's # using llama4's load_weights routine. language_model_weights, other_weights = self.separate_weights( - weights, prefix="language_model.model.") + weights, prefix="language_model.") loader = AutoWeightsLoader(self) loaded_language_model_params = loader.load_weights( language_model_weights) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py new file mode 100644 index 0000000000000000000000000000000000000000..2190241f0ba3ca8b8a447bc91d9c7ef2200d17dd --- /dev/null +++ b/vllm/model_executor/models/modernbert.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import ModernBertConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.pooler import CrossEncodingPooler +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from 
class ModernBertEmbeddings(nn.Module):
    """Token embedding followed by a LayerNorm.

    ``forward`` either looks the tokens up in ``tok_embeddings`` or, when the
    caller already supplies embeddings, uses those directly; in both cases
    the result is passed through ``self.norm``.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size)
        # NOTE(review): the other LayerNorms in this file read
        # `config.norm_eps`; confirm `layer_norm_eps` is the intended field.
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.layer_norm_eps,
                                 bias=config.norm_bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return normalized embeddings.

        Args:
            input_ids: token ids; only read when ``inputs_embeds`` is None.
            inputs_embeds: optional pre-computed embeddings that bypass the
                embedding table.

        Returns:
            ``self.norm`` applied to the chosen embeddings.
        """
        # Test identity against None. Taking the truth value of a tensor
        # raises for tensors with more than one element and would silently
        # be False for a single zero-valued element.
        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)
        return self.norm(inputs_embeds)
class ModernBertMLP(nn.Module):
    """Gated feed-forward block.

    ``Wi`` projects to twice ``intermediate_size``; the projection is split
    into two equal halves along the last dimension, GELU is applied to the
    first half, the result is multiplied by the second half, and ``Wo``
    projects back to ``hidden_size``.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        doubled_width = int(config.intermediate_size) * 2
        self.Wi = nn.Linear(config.hidden_size,
                            doubled_width,
                            bias=config.mlp_bias)
        self.act = nn.GELU()
        self.Wo = RowParallelLinear(config.intermediate_size,
                                    config.hidden_size,
                                    bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection and return the down-projected tensor."""
        projected = self.Wi(hidden_states)
        value_half, gate_half = projected.chunk(2, dim=-1)
        gated = self.act(value_half) * gate_half
        # `Wo` returns a sequence; only its first element is the output
        # (the attention block in this file unpacks it as `out, _`).
        down = self.Wo(gated)
        return down[0]
class ModernBertEncoderLayer(nn.Module):
    """Sequential stack of ``ModernBertLayer`` modules.

    One layer is created per ``hf_config.num_hidden_layers`` and each is told
    its own index through ``layer_id``.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        stack = nn.ModuleList()
        for layer_id in range(hf_config.num_hidden_layers):
            stack.append(ModernBertLayer(config=hf_config, layer_id=layer_id))
        self.layers = stack

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run ``hidden_states`` through every layer in order.

        ``position_ids`` is forwarded unchanged to each layer.
        """
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        return hidden_states
class ModernBertPooler(nn.Module):
    """Mean-pool, then dense -> GELU -> LayerNorm."""

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width, config.classifier_bias)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(width,
                                 eps=config.norm_eps,
                                 bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Average over dim 0 and transform the averaged vector.

        Dim 0 is removed by the mean (``keepdim=False``).
        NOTE(review): presumably dim 0 is the token axis of one sequence;
        confirm against the caller.
        """
        averaged = hidden_states.mean(dim=0, keepdim=False)
        projected = self.dense(averaged)
        activated = self.act(projected)
        return self.norm(activated)
self.model.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if name.startswith("head"): + param = params_dict["_pooler.pooler." + name[len("head") + 1:]] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: Optional[torch.LongTensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=positions, + ) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index d75845b45e73339a68077c49d552a0bf142b9b64..46147a333b06e8ecbf88d9a974fb63ff01dc0c13 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -1394,7 +1393,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, self.logits_processor = 
LogitsProcessor(config.embedding_size or config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -1506,7 +1504,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> torch.Tensor: if intermediate_tensors is not None: inputs_embeds = None @@ -1532,14 +1530,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py new file mode 100644 index 0000000000000000000000000000000000000000..c367d90f847b6209c32ac7510de95f75dbdc4e3e --- /dev/null +++ b/vllm/model_executor/models/moonvit.py @@ -0,0 +1,628 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: E501 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py +# This file is meant to be used in kimi_vl.py only +# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved. +# +# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL. +# +# Licensing Information: +# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0. +# - Other parts of the code are licensed under the MIT License. 
+# +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+import math +from copy import deepcopy +from functools import cached_property +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.activations import ACT2FN, PytorchGELUTanh +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import is_flash_attn_2_available + +from vllm.transformers_utils.configs.moonvit import MoonViTConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func +else: + flash_attn_varlen_func = None + + +def multihead_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_cu_seqlens: Optional[torch.Tensor] = None, + k_cu_seqlens: Optional[torch.Tensor] = None, +): + """Multi-head attention using flash attention 2. + + Args: + q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q. + The first element should be 0 and the last element should be q.shape[0]. + k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k. + The first element should be 0 and the last element should be k.shape[0]. 
def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Block-diagonal attention over packed sequences via torch SDPA.

    Args:
        q, k, v: packed tensors indexed as (tot_seqlens, num_heads,
            head_dim); dim 0 is swapped with dim 1 before the SDPA call.
        q_cu_seqlens: cumulative sequence boundaries. Each consecutive pair
            (start, end) marks one sequence whose tokens may attend to each
            other and to nothing else.
        k_cu_seqlens: accepted for signature parity with the flash-attention
            variant; not read here.

    Returns:
        Tensor of shape (tot_seqlens, num_heads * head_dim).
    """
    total_tokens = q.shape[0]
    # Boolean mask: True marks a (query, key) pair that may attend.
    block_mask = torch.zeros([1, total_tokens, total_tokens],
                             device=q.device,
                             dtype=torch.bool)
    for begin, end in zip(q_cu_seqlens[:-1], q_cu_seqlens[1:]):
        block_mask[..., begin:end, begin:end] = True

    # SDPA wants the head axis in front of the token axis.
    heads_first = [t.transpose(0, 1) for t in (q, k, v)]
    attended = F.scaled_dot_product_attention(*heads_first,
                                              block_mask,
                                              dropout_p=0.0)
    return attended.transpose(0, 1).reshape(total_tokens, -1)
class Learnable2DInterpPosEmb(nn.Module):
    """Learnable 2D position embedding that is resized per input grid.

    A single ``(height, width, dim)`` parameter is stored. For each grid in
    ``grid_hws`` the table is used as-is when the grid matches the stored
    size, and interpolated to the grid size otherwise.
    """

    def __init__(self,
                 height: int,
                 width: int,
                 dim: int,
                 interpolation_mode: str = "bicubic") -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.interpolation_mode = interpolation_mode
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the embedding table from a standard normal."""
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
        """Add position embeddings to packed patch features.

        Args:
            x: packed features; its leading dimension must equal the sum of
                ``h * w`` over ``grid_hws`` for the addition to succeed.
            grid_hws: one ``(h, w)`` row per image.

        Returns:
            ``x`` plus the concatenated per-image position embeddings.
        """
        native_hw = tuple(self.weight.shape[:-1])
        pos_embs = []
        for shape in grid_hws.tolist():
            # `tolist()` yields Python lists, and a list never compares equal
            # to a tuple/torch.Size; compare as tuples so the cheap path is
            # actually reachable.
            if tuple(shape) == native_hw:
                pos_embs.append(self.weight.flatten(end_dim=1))
            else:
                # (H, W, C) -> (1, C, H, W) for interpolate, then back to
                # (h * w, C).
                resized = F.interpolate(
                    self.weight.permute((2, 0, 1)).unsqueeze(0),
                    size=shape,
                    mode=self.interpolation_mode,
                )
                pos_embs.append(
                    resized.squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
        return x + torch.cat(pos_embs)
class Rope2DPosEmb(nn.Module):
    """2D rotary position embedding table with per-image cropping.

    The module holds no parameters. On first access it builds a complex
    table of unit-modulus rotation factors for every cell of a
    ``max_height x max_width`` grid, then serves slices of it:

    * ``get_freqs_cis_by_seqlens`` returns, per image, the top-left
      ``h x w`` crop flattened row by row, all crops concatenated;
    * ``get_freqs_cis_by_idx`` gathers arbitrary ``(h, w)`` cells under a
      boolean mask.

    The same factors are meant to be passed to every attention layer.

    Refs:
        - RoFormer: https://arxiv.org/abs/2104.09864
        - VisionLLaMA: https://arxiv.org/abs/2403.00522
        - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): per-head attention dimension; must be divisible by 4.
        max_height (int): maximum grid height.
        max_width (int): maximum grid width.
        theta_base (float): base of the geometric frequency progression.
        device (str): device on which the table is built.
    """

    def __init__(self,
                 dim: int,
                 max_height: int,
                 max_width: int,
                 theta_base=10000,
                 device="cuda"):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        self.device = device

    def extra_repr(self):
        return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

    @cached_property
    def precomputed_freqs_cis(self) -> torch.Tensor:
        """Build the table of cis(angle) for every grid cell.

        Returns a complex tensor of shape (max_height, max_width, dim // 2).
        For i in [0, dim // 4):

            ret[h, w, 2 * i]     = cis(w * theta_base ** (-4 * i / dim))
            ret[h, w, 2 * i + 1] = cis(h * theta_base ** (-4 * i / dim))

        where cis(x) = cos(x) + i * sin(x). Even channels therefore follow
        the column (width) index and odd channels the row (height) index.
        """
        cell_count = self.max_height * self.max_width
        cell_index = torch.arange(0, cell_count).float().to(self.device)
        col = cell_index % self.max_width
        row = cell_index // self.max_width

        exponents = torch.arange(0, self.dim, 4)[:(self.dim // 4)]
        exponents = exponents.float().to(self.device)  # C/4
        inv_freq = 1.0 / (self.theta_base**(exponents / self.dim))

        col_angle = torch.outer(col, inv_freq).float()  # N, C/4
        row_angle = torch.outer(row, inv_freq).float()  # N, C/4
        col_cis = torch.polar(torch.ones_like(col_angle), col_angle)
        row_cis = torch.polar(torch.ones_like(row_angle), row_angle)

        # Pair column/row factors on a new trailing axis -> (N, C/4, 2);
        # the reshape below interleaves them into C/2 channels.
        paired = torch.stack((col_cis, row_cis), dim=-1)
        return paired.reshape(self.max_height, self.max_width, -1)

    def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
        """Return the concatenated per-image crops of the table.

        Args:
            grid_hws (torch.Tensor): rows of (height, width); each row is
                unpacked into exactly two values.
        Returns:
            freqs_cis: tensor of shape (sum(height * width), dim // 2)
        """
        shapes = grid_hws.tolist()
        within_bounds = all(
            1 <= h <= self.max_height and 1 <= w <= self.max_width
            for h, w in shapes)
        assert within_bounds, (
            shapes,
            self.max_height,
            self.max_width,
        )
        half_dim = self.dim // 2
        crops = [
            self.precomputed_freqs_cis[:h, :w].reshape(-1, half_dim)
            for h, w in shapes
        ]
        return torch.cat(crops, dim=0)

    def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor,
                             pos_idx_mask: torch.Tensor) -> torch.Tensor:
        """Gather table entries for explicit (h, w) indices.

        Args:
            pos_idx: tensor of shape (..., 2) holding the (h, w) index of
                each token.
            pos_idx_mask: bool tensor of shape (...). Entries that are False
                keep a factor of one, i.e. no rotation.
        Return:
            freqs_cis: tensor of shape (..., dim // 2)
        """
        shapes_agree = (pos_idx.shape[:-1] == pos_idx_mask.shape
                        and pos_idx.shape[-1] == 2
                        and pos_idx.ndim == pos_idx_mask.ndim + 1)
        assert shapes_agree, (pos_idx.shape, pos_idx_mask.shape)
        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype

        out_shape = pos_idx_mask.shape + (self.dim // 2, )
        freqs_cis = torch.ones(out_shape,
                               dtype=torch.complex64,
                               device=self.device)
        rows = pos_idx[..., 0][pos_idx_mask]
        cols = pos_idx[..., 1][pos_idx_mask]
        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[rows, cols]
        return freqs_cis
nn.LayerNorm(hidden_dim) + self.norm1 = nn.LayerNorm(hidden_dim) + self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation) + self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias) + self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias) + + def attention_qkvpacked( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rope_freqs_cis: Optional[torch.Tensor] = None, + ): + """ + Args: + x (torch.Tensor): (batch_size, seqlen, hidden_dim) + cu_seqlens (torch.Tensor): + """ + xqkv = self.wqkv(x) + + qkv_shape = xqkv.size()[:-1] + ( + 3, + self.num_heads, + self.hidden_size_per_attention_head, + ) + # xqkv: (batch_size, seqlen, 3, nheads, headdim) + xqkv = xqkv.view(*qkv_shape) + xq, xk, xv = torch.unbind(xqkv, dim=-3) + + xq, xk = apply_rope(xq, xk, rope_freqs_cis) + + attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation] + attn_out = attn_func(xq, + xk, + xv, + q_cu_seqlens=cu_seqlens, + k_cu_seqlens=cu_seqlens) + + attn_out = self.wo(attn_out) + return attn_out + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rope_freqs_cis: Union[torch.Tensor, None] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states: non-packed (B, N, D) or packed (L, D). 
class MoonVitEncoder(nn.Module):
    """Stack of ``MoonVitEncoderLayer`` blocks sharing one 2D RoPE table.

    Args:
        hidden_dim: width of the final LayerNorm.
        num_layers: number of encoder blocks to build.
        block_cfg: keyword arguments forwarded to every
            ``MoonVitEncoderLayer``; ``hidden_dim`` and ``num_heads`` are
            also read here to size the rotary embedding.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
    ) -> None:
        super().__init__()

        head_dim = block_cfg["hidden_dim"] // block_cfg["num_heads"]
        # The rotary table is built for grids of up to 512 x 512 cells.
        self.rope_2d = Rope2DPosEmb(head_dim, 512, 512)

        layers = []
        for _ in range(num_layers):
            layers.append(MoonVitEncoderLayer(**block_cfg))
        self.blocks = nn.ModuleList(layers)

        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(self, hidden_states: torch.Tensor,
                grid_hw: torch.Tensor) -> torch.Tensor:
        """Encode packed patch tokens.

        Args:
            hidden_states: packed token features for all images.
            grid_hw: one (height, width) row per image; ``height * width``
                is taken as that image's token count.

        Returns:
            The layer-normed output of the last block.
        """
        rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
            grid_hws=grid_hw)

        # Cumulative token offsets with a leading zero, as int32.
        tokens_per_image = grid_hw[:, 0] * grid_hw[:, 1]
        leading_zero = torch.zeros(1,
                                   device=hidden_states.device,
                                   dtype=grid_hw.dtype)
        lengths = torch.cat((leading_zero, tokens_per_image))
        cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)

        for block in self.blocks:
            hidden_states = block(hidden_states,
                                  cu_seqlens,
                                  rope_freqs_cis=rope_freqs_cis)

        return self.final_layernorm(hidden_states)
class MoonVitVLProjector(nn.Module):
    """Two-layer MLP projecting merged vision patches to ``out_dim``.

    Features are layer-normed per patch (width ``in_channels``), then viewed
    so that the ``merge_kernel_size[0] * merge_kernel_size[1]`` patches of
    one merge window form a single vector of width ``hidden_size``, which is
    passed through linear -> activation -> linear.

    Args:
        in_channels: feature width of a single patch.
        merge_kernel_size: two-element (kernel_height, kernel_width).
        hidden_act: key into ``ACT2FN`` selecting the activation.
        ln_eps: epsilon of the pre-normalization LayerNorm.
        out_dim: output feature width.
    """

    def __init__(
        self,
        in_channels: int,
        merge_kernel_size: list[int, int],
        hidden_act: str = "gelu",
        ln_eps: float = 1e-5,
        out_dim: int = 4096,
    ):
        super().__init__()
        self.hidden_size = in_channels * merge_kernel_size[
            0] * merge_kernel_size[1]

        # `nn` is already torch.nn; the previous `nn.nn.LayerNorm` raised
        # AttributeError as soon as the module was constructed.
        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
        self.linear_1 = nn.Linear(self.hidden_size,
                                  self.hidden_size,
                                  bias=True)
        self.act = ACT2FN[hidden_act]
        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize, regroup to ``(-1, hidden_size)`` and project."""
        hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
"num_heads": config.num_attention_heads, + "hidden_dim": config.hidden_size, + "mlp_dim": config.intermediate_size, + "activation": PytorchGELUTanh(), + "attn_bias": True, + "attn_implementation": config._attn_implementation, + }, + ) + + def forward(self, pixel_values: torch.Tensor, + grid_hw: torch.Tensor) -> torch.Tensor: + """ + Args: + pixel_values (torch.Tensor): The input pixel values. + grid_hw (torch.Tensor): The grid height and width. + + Returns: + torch.Tensor: The output tokens. + """ + hidden_states = self.patch_embed(pixel_values, grid_hw) + hidden_states = self.encoder(hidden_states, grid_hw) + hidden_states = patch_merger(hidden_states, + grid_hw, + merge_kernel_size=self.merge_kernel_size) + return hidden_states diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index b30f3ee37997f3df656edbc755b37759a9cc1a8e..77bd794058cdad76898882c5a4004716a9ec05db 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -298,7 +297,6 @@ class MPTForCausalLM(nn.Module, SupportsPP): prefix=maybe_prefix(prefix, "transformer")) self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -325,14 +323,6 @@ class MPTForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: 
SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 0ea296b2f93d1c642ec981ff6d0de75735271691..5208c0796c8d2760518c7e39ca6a29b7187a6da7 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -416,8 +415,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -444,14 +441,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index 5c9b04cab180ac1c6e36038f54fc607b51e0c1c1..2649994968765ce819134247727e2a666d821c2b 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ 
b/vllm/model_executor/models/nemotron_nas.py @@ -34,7 +34,6 @@ from vllm.distributed import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -408,8 +407,6 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -439,11 +436,6 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): sampling_metadata) return logits - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 4a341c97d6cdf3df5dc7cefdd9f1c0febeca7261..0781ca168f84085401ab90090f6b4a4bc379d1bf 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -39,7 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -309,7 +308,6 @@ class OlmoForCausalLM(nn.Module, SupportsPP): quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -340,14 +338,6 @@ class OlmoForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index f9427cdadf7a281989b76fe0d9a3f340031ef507..44beae5726dc09971a02a259b1a35cb6aac47c9f 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -28,6 +28,7 @@ from typing import Iterable, Optional, Tuple, Union import torch from torch import nn +from transformers import Olmo2Config from vllm.attention import Attention from vllm.config import VllmConfig @@ -42,7 +43,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -52,7 +52,6 @@ from vllm.model_executor.models.utils import ( make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.olmo2 import 
Olmo2Config class Olmo2Attention(nn.Module): @@ -339,7 +338,6 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP): prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -367,14 +365,6 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 6cf3f1f82645cc02e3e817956aa11bbe6223f0fb..9bed29d0132f2ccc716c9cf7b7532d1e36a88479 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -39,7 +38,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -255,7 +254,7 @@ class OlmoeModel(nn.Module): quant_config = 
vllm_config.quant_config self.vocab_size = config.vocab_size - + self.config = config self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -308,56 +307,6 @@ class OlmoeModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - -class OlmoeForCausalLM(nn.Module, SupportsPP): - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = OlmoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: 
stacked_params_mapping = [ @@ -380,8 +329,6 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: @@ -453,3 +400,50 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +class OlmoeForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = OlmoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = 
AutoWeightsLoader( + self, + skip_prefixes=["rotary_emb.inv_freq"], + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 4a12f36d90e84f59f4473e5007da586dde6ae262..d258eddae25d4de0e7258b20df0f8a8e9de694f9 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -43,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -313,6 +312,43 @@ class OPTModel(nn.Module): intermediate_tensors, inputs_embeds=inputs_embeds) + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class OPTForCausalLM(nn.Module, SupportsPP): packed_modules_mapping = { @@ -320,6 +356,10 @@ class OPTForCausalLM(nn.Module, SupportsPP): "gate_up_proj": ["gate_proj", "up_proj"] } + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "decoder.": "model.decoder.", + }) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -334,7 +374,6 @@ class OPTForCausalLM(nn.Module, SupportsPP): self.lm_head = ParallelLMHead(config.vocab_size, config.word_embed_proj_dim) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -361,52 +400,11 @@ class OPTForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, 
loaded_weight in weights: - if "lm_head.weight" in name and self.config.tie_word_embeddings: - continue - if name.startswith("decoder."): - name = "model." + name - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 0b42666e02d61f9b6dd44dc8e92c3afe67777f48..8d9c000750d78dd55b8639279a3e44dcbeed33f1 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -22,7 +22,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader 
@@ -30,7 +29,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -260,6 +259,45 @@ class OrionModel(nn.Module): hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class OrionForCausalLM(nn.Module, SupportsPP): @@ -277,7 +315,6 @@ class OrionForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -304,56 +341,16 @@ class OrionForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + loader = AutoWeightsLoader( + self, + skip_prefixes=([ + "rotary_emb.inv_freq", # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + "rotary_emb.cos_cached", + "rotary_emb.sin_cached" + ]), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 6c1bd499f63989e865cd556fba7568d488d854af..8699ae52622d57c78511186786880757e11df3c6 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -8,7 +8,6 @@ from transformers import BatchFeature, PaliGemmaConfig from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -260,10 +259,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @property - def sampler(self): - return self.language_model.sampler - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -369,7 +364,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: 
Optional[torch.Tensor] = None, - **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + **kwargs: object) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -396,13 +391,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index db8d170a8c91b12ac89bd456b4966e9018c1cb0e..eacf02433b573b71c2674765c1c3c38364281136 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -46,7 +45,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -116,9 +115,10 @@ class PersimmonAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=int(self.partial_rotary_factor * 
self.head_dim), + rotary_dim=self.head_dim, max_position=self.max_position_embeddings, base=self.rope_theta, + partial_rotary_factor=self.partial_rotary_factor, ) self.scaling = self.head_dim**-0.5 self.attn = Attention(self.num_heads, @@ -221,7 +221,7 @@ class PersimmonModel(nn.Module): quant_config = vllm_config.quant_config self.vocab_size = config.vocab_size - + self.config = config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( @@ -260,6 +260,38 @@ class PersimmonModel(nn.Module): hidden_states = self.final_layernorm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # copy from vllm/model_executor/models/bloom.py + # NOTE: Persimmon's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. 
+ output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class PersimmonForCausalLM(nn.Module, SupportsPP): @@ -274,7 +306,6 @@ class PersimmonForCausalLM(nn.Module, SupportsPP): config.hidden_size, bias=False) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -305,49 +336,7 @@ class PersimmonForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. 
- continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - - if "query_key_value" in name: - # copy from vllm/model_executor/models/bloom.py - # NOTE: Persimmon's fused QKV's output_dim has the shape of - # (num_heads * 3 * head_size), while the - # required shape is (3 * num_heads * head_size). - # Thus, we need weight conversion. - output_dim = getattr(param, "output_dim", None) - num_heads = self.config.num_attention_heads - if output_dim is not None: - loaded_weight_shape = loaded_weight.shape - loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) - loaded_weight = loaded_weight.reshape(loaded_weight_shape) - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index fdf7734595a5415007780eb14938a0accc92b1e1..fc2b108bad97be0c5ec4fdc11b8b7299d21b3359 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -53,7 +53,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -322,7 +321,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): bias=True, quant_config=quant_config) self.logits_processor 
= LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -350,14 +348,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata, self.lm_head.bias) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 33984f54ae27143c5d780b5b3c51d79f9a54b8a7..338e87b4285fbeb126a49dde3e075df88c5fedca 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -17,7 +17,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -26,7 +25,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -353,10 +352,29 @@ class Phi3SmallModel(nn.Module): hidden_states = self.final_layernorm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + 
torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class Phi3SmallForCausalLM(nn.Module, SupportsPP): _tied_weights_keys = ["lm_head.weight"] + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_suffix={"rotary_emb.inv_freq": None}) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -377,7 +395,6 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -418,6 +435,7 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP): sampling_metadata) if self.dummy_token_indices is not None and logits is not None: logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) + logits = logits / self.mup_width_multiplier return logits def forward( @@ -436,33 +454,10 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP): output_hidden_states = output_hidden_states return output_hidden_states - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - - next_tokens = self.sampler(logits / self.mup_width_multiplier, - sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in 
name: - continue - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - if "lm_head.weight" in name and self.config.tie_word_embeddings: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None)) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 7f41ad2359df631fe002bbfb709d0c432baf0f10..a1442251b99284b2f7d252fbbe5fbc65e17ea6b9 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -16,7 +16,6 @@ # limitations under the License. import re from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -27,7 +26,6 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -327,7 +325,7 @@ class Phi3VProcessingInfo(BaseProcessingInfo): *, image_width: int, image_height: int, - processor: Optional[ProcessorMixin], + processor: Optional[ProcessorMixin] = None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -555,13 +553,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, self.make_empty_intermediate_tensors = ( 
self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -716,13 +707,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index ec19797f887540e304954e7a1b9a095bb7adc311..6035994f433646ed816a96bc5f44b2c3c62cf9b6 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1,41 +1,41 @@ # SPDX-License-Identifier: Apache-2.0 import math -import re -from functools import lru_cache -from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union import numpy as np -import scipy.signal import torch import torch.nn as nn -import torchvision.transforms as T -from PIL import Image -from transformers import PretrainedConfig, SiglipVisionConfig -from transformers.utils import logging +from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, + SequenceFeatureExtractor, SiglipVisionConfig) from vllm.config import VllmConfig from vllm.distributed import get_pp_group -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext) -from vllm.inputs.data import TokenInputs, token_inputs from vllm.model_executor.layers.logits_processor import LogitsProcessor from 
vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.sequence import IntermediateTensors, SequenceData -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, + ImageProcessorItems, ImageSize, + MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsV0Only +from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .phi4mm_audio import AudioEmbedding -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) # <|endoftext10|> (see vocab.json in hf model) _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 @@ -43,115 +43,19 @@ _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 _AUDIO_PLACEHOLDER_TOKEN_ID = 200011 _AUDIO_MAX_SOUNDFILE_SIZE = 241_000 -DUMMY_SAMPLING_FREQUENCY = 16_000 # kHz - -DYNAMIC_HD = 16 -AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>" 
-IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>" SIGLIP_NAME = "siglip-so400m-patch14-448" VISION_ENCODER_TO_PROCESSING_CONFIG = { 'siglip-so400m-patch14-448': { - 'dynamic_hd': 16, 'vit_image_size': 448, 'vit_patch_size': 14, 'token_compression_factor': 2, }, } -logger = logging.get_logger(__name__) -# This is a workaround to prevent text (user input) + audio + image -# from being used in the same prompt. -# It includes token ids for "/n" and tokens in added_tokens_decoder -# from the tokenizer_confg.json file. -NON_USER_INPUT_TOKENS = { - 198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, - 200023, 200024, 200025, 200026, 200027, 200028 -} -def get_max_dummy_image(ctx: InputContext): - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - - max_side = vit_image_size * dynamic_hd_size - dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side) - return dummy_image - - -# image token length -def get_max_phi4mm_image_tokens(ctx: InputContext): - dummy_image = get_max_dummy_image(ctx) - - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - token_compression_factor = prepro_config['token_compression_factor'] - - image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size, - vit_image_size, - vit_patch_size, - token_compression_factor) - return image_num_tokens - - -def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, - 
image_size): - best_ratio_diff = float('inf') - best_ratio = (1, 1) - area = width * height - for ratio in target_ratios: - target_aspect_ratio = ratio[0] / ratio[1] - ratio_diff = abs(aspect_ratio - target_aspect_ratio) - if ratio_diff < best_ratio_diff: - best_ratio_diff = ratio_diff - best_ratio = ratio - elif ratio_diff == best_ratio_diff: - if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: - best_ratio = ratio - return best_ratio - - -def _find_target_aspect_ratio(image, image_size, max_num, min_num): - orig_width, orig_height = image.size - - w_crop_num = math.ceil(orig_width / float(image_size)) - h_crop_num = math.ceil(orig_height / float(image_size)) - if w_crop_num * h_crop_num > max_num: - aspect_ratio = orig_width / orig_height - - # calculate the existing image aspect ratio - target_ratios = set((i, j) for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num) - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) - - # find the closest aspect ratio to the target - target_aspect_ratio = find_closest_aspect_ratio( - aspect_ratio, target_ratios, orig_width, orig_height, image_size) - - # calculate the target width and height - target_width = image_size * target_aspect_ratio[0] - target_height = image_size * target_aspect_ratio[1] - logger.debug("target_aspect_ratio: %s", target_aspect_ratio) - else: - target_width = image_size * w_crop_num - target_height = image_size * h_crop_num - target_aspect_ratio = (w_crop_num, h_crop_num) - return target_aspect_ratio, target_height, target_width - - -def _get_padding_size(image, target_height, target_width): - orig_width, orig_height = image.size +def _get_padding_size(orig_width: int, orig_height: int, target_height: int, + target_width: int): ratio_width = target_width / orig_width ratio_height = target_height / orig_height @@ -164,181 +68,6 @@ def _get_padding_size(image, target_height, target_width): return padding_height, padding_width 
-def dynamic_preprocess(image, - min_num=1, - max_num=12, - image_size=384, - mask_size=27): - target_aspect_ratio, target_height, target_width =\ - _find_target_aspect_ratio( - image, image_size, max_num, min_num) - padding_height, padding_width = _get_padding_size(image, target_height, - target_width) - - # Calculate the ratio - orig_width, orig_height = image.size - ratio_width = target_width / orig_width - ratio_height = target_height / orig_height - if ratio_width < ratio_height: - new_size = (target_width, int(orig_height * ratio_width)) - else: - new_size = (int(orig_width * ratio_height), target_height) - - attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), - int(mask_size * target_aspect_ratio[0]))) - if padding_width >= 14: - attention_mask[:, -math.floor(padding_width / 14):] = 0 - if padding_height >= 14: - attention_mask[-math.floor(padding_height / 14):, :] = 0 - assert attention_mask.sum( - ) > 0, f'attention mask is empty {attention_mask}' - - if min(new_size[1], target_height) < 10 or min(new_size[0], - target_width) < 10: - raise ValueError(f'the aspect ratio is very extreme {new_size}') - - image = T.functional.resize( - image, - [new_size[1], new_size[0]], - ) - - resized_img = T.functional.pad(image, - [0, 0, padding_width, padding_height], - fill=[255, 255, 255]) - - return resized_img, attention_mask - - -def pad_to_max_num_crops(images, max_crops=5): - """ - images: B x 3 x H x W, B<=max_crops - """ - B, _, H, W = images.shape - if max_crops > B: - pad = torch.zeros(max_crops - B, - 3, - H, - W, - dtype=images.dtype, - device=images.device) - images = torch.cat([images, pad], dim=0) - return images - - -def pad_mask_to_max_num_crops(masks, max_crops=5): - B, H, W = masks.shape - if max_crops > B: - pad = torch.ones(max_crops - B, - H, - W, - dtype=masks.dtype, - device=masks.device) - masks = torch.cat([masks, pad], dim=0) - return masks - - -def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): - - # 
Basic settings. - img_processor = T.Compose([ - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ]) - # Dynamic HD - base_resolution = vit_resolution - images = [image.convert('RGB') for image in images] - # cover 384 and 448 resolution - mask_resolution = base_resolution // vit_patch_size - elems, image_attention_masks = [], [] - for im in images: - elem, attention_mask = dynamic_preprocess(im, - max_num=dynamic_hd_size, - image_size=base_resolution, - mask_size=mask_resolution) - elems.append(elem) - image_attention_masks.append(attention_mask) - hd_images = [img_processor(im) for im in elems] - global_image = [ - torch.nn.functional.interpolate( - im.unsqueeze(0).float(), - size=(base_resolution, base_resolution), - mode='bicubic', - ).to(im.dtype) for im in hd_images - ] - shapes = [[im.size(1), im.size(2)] for im in hd_images] - mask_shapes = [[mask.size(0), mask.size(1)] - for mask in image_attention_masks] - global_attention_mask = [ - torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images - ] - hd_images_reshape = [ - im.reshape(1, 3, h // base_resolution, base_resolution, - w // base_resolution, base_resolution).permute( - 0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution, - base_resolution).contiguous() - for im, (h, w) in zip(hd_images, shapes) - ] - attention_masks_reshape = [ - mask.reshape(1, h // mask_resolution, mask_resolution, - w // mask_resolution, mask_resolution).permute( - 0, 1, 3, 2, 4).reshape(-1, mask_resolution, - mask_resolution).contiguous() - for mask, (h, w) in zip(image_attention_masks, mask_shapes) - ] - # NOTE token compression is hard coded here, and odd numbers seems to fail - downsample_attention_masks = [ - mask[:, 0::2, - 0::2].reshape(1, h // mask_resolution, w // mask_resolution, - mask_resolution // 2 + mask_resolution % 2, - mask_resolution // 2 + mask_resolution % 2).permute( - 0, 1, 3, 2, 4) - for mask, (h, w) in zip(attention_masks_reshape, mask_shapes) - ] - downsample_attention_masks = [ - 
mask.reshape(mask.size(1) * mask.size(2), - mask.size(3) * mask.size(4)) - for mask in downsample_attention_masks - ] - # NOTE hard coded number of tokens - num_img_tokens = [ - 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 - for mask in downsample_attention_masks - ] - - hd_images_reshape = [ - torch.cat([_global_image] + [_im], dim=0) - for _global_image, _im in zip(global_image, hd_images_reshape) - ] - hd_masks_reshape = [ - torch.cat([_global_mask] + [_mask], - dim=0) for _global_mask, _mask in zip( - global_attention_mask, attention_masks_reshape) - ] - max_crops = max([img.size(0) for img in hd_images_reshape]) - image_transformed = [ - pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape - ] - image_transformed = torch.stack(image_transformed, dim=0) - mask_transformed = [ - pad_mask_to_max_num_crops(mask, max_crops) \ - for mask in hd_masks_reshape - ] - mask_transformed = torch.stack(mask_transformed, dim=0) - - returned_input_image_embeds = image_transformed - returned_image_sizes = torch.tensor(shapes, dtype=torch.long) - returned_image_attention_mask = mask_transformed - returned_num_img_tokens = num_img_tokens - - data = { - "pixel_values": returned_input_image_embeds, - "image_sizes": returned_image_sizes, - "image_attention_mask": returned_image_attention_mask, - "num_img_tokens": returned_num_img_tokens, - } - return data - - def get_navit_vision_model(layer_idx: int = -1, **kwargs): vision_config = { "hidden_size": 1152, @@ -492,7 +221,7 @@ class Phi4MMImageEncoder(nn.Module): def forward(self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - image_attention_mask: torch.Tensor) -> torch.FloatTensor: + image_attention_mask: torch.Tensor) -> list[torch.FloatTensor]: """ process image and return vision embeddings. 
@@ -656,785 +385,505 @@ class Phi4MMImageEncoder(nn.Module): for _output_img in output_imgs: img_feature_proj = self.img_projection( _output_img.to(target_device).to(target_dtype)) - img_set_tensor.append(img_feature_proj) + img_set_tensor.append(img_feature_proj.squeeze(0)) return img_set_tensor -class Phi4MMAudioFeatureInputs(TypedDict): - type: Literal["audio_features"] - data: Tuple[NestedTensors] - """Shape: `((batch_size, num_audios, 80, M), )""" - - -class Phi4MMAudioEmbeddingInputs(TypedDict): - type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" - - -Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] - - -def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): - """Create a Mel filter-bank the same as SpeechLib FbankFC. - - Args: - sample_rate (int): Sample rate in Hz. number > 0 [scalar] - n_fft (int): FFT size. int > 0 [scalar] - n_mel (int): Mel filter size. int > 0 [scalar] - fmin (float): lowest frequency (in Hz). If None use 0.0. - float >= 0 [scalar] - fmax: highest frequency (in Hz). If None use sample_rate / 2. 
- float >= 0 [scalar] - - Returns - out (numpy.ndarray): Mel transform matrix - [shape=(n_mels, 1 + n_fft/2)] +class Phi4MMImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, List[torch.Tensor]] """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - bank_width = int(n_fft // 2 + 1) - if fmax is None: - fmax = sample_rate / 2 - if fmin is None: - fmin = 0 - assert fmin >= 0, "fmin cannot be negative" - assert (fmin < fmax <= - sample_rate / 2), "fmax must be between (fmin, samplerate / 2]" - - def mel(f): - return 1127.0 * np.log(1.0 + f / 700.0) - - def bin2mel(fft_bin): - return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) - - def f2bin(f): - return int((f * n_fft / sample_rate) + 0.5) - - # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] - klo = f2bin(fmin) + 1 - khi = f2bin(fmax) - - khi = max(khi, klo) - - # Spec 2: SpeechLib uses triangles in Mel space - mlo = mel(fmin) - mhi = mel(fmax) - m_centers = np.linspace(mlo, mhi, n_mels + 2) - ms = (mhi - mlo) / (n_mels + 1) - - matrix = np.zeros((n_mels, bank_width), dtype=np.float32) - for m in range(0, n_mels): - left = m_centers[m] - center = m_centers[m + 1] - right = m_centers[m + 2] - for fft_bin in range(klo, khi): - mbin = bin2mel(fft_bin) - if left < mbin < right: - matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms - - return matrix - - -class LogFbankProcessor: - - def __init__(self): - - self._eightk_method = "fillzero" - self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T - - self._hamming400 = np.hamming(400) # for 16k audio - self._hamming200 = np.hamming(200) # for 8k audio + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ - def extract_spectrogram(self, wav, fs): - """Extract spectrogram features from waveform. 
- Args: - wav (1D array): waveform of the input - fs (int): sampling rate of the waveform, 16000 or 8000. - If fs=8000, the waveform will be resampled to 16000Hz. - Output: - log_fbank (2D array): a TxD matrix of log Mel filterbank features. - D=80, and T is the number of frames. - """ - if wav.ndim > 1: - wav = np.squeeze(wav) + image_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` - # by default, we extract the mean if stereo - if len(wav.shape) == 2: - wav = wav.mean(1) + This should be in `(height, width)` format. + """ - # Resample to 16000 or 8000 if needed - if fs > 16000: - wav = scipy.signal.resample_poly(wav, 1, fs // 16000) - fs = 16000 - elif 8000 < fs < 16000: - wav = scipy.signal.resample_poly(wav, 1, fs // 8000) - fs = 8000 - elif fs < 8000: - raise RuntimeError(f"Unsupported sample rate {fs}") - - if fs == 8000: - if self._eightk_method == "resample": - # Input audio is 8 kHz. Convert to 16 kHz before feature - # extraction - wav = scipy.signal.resample_poly(wav, 2, 1) - fs = 16000 - # Do nothing here for fillzero method - elif fs != 16000: - # Input audio is not a supported sample rate. 
- raise RuntimeError( - f"Input data using an unsupported sample rate: {fs}") - - preemphasis = 0.97 - - if fs == 8000: - n_fft = 256 - win_length = 200 - hop_length = 80 - fft_window = self._hamming200 - elif fs == 16000: - n_fft = 512 - win_length = 400 - hop_length = 160 - fft_window = self._hamming400 - - # Spec 1: SpeechLib cut remaining sample insufficient for a hop - n_batch = (wav.shape[0] - win_length) // hop_length + 1 - # Here we don't use stride_tricks since the input array may not satisfy - # memory layout requirement and we need writeable output - # Here we only use list of views before copy to destination - # so it is more efficient than broadcasting - y_frames = np.array( - [ - wav[_stride:_stride + win_length] - for _stride in range(0, hop_length * n_batch, hop_length) - ], - dtype=np.float32, - ) + num_img_tokens: list[int] + """Shape: `(batch_size * num_images)`""" - # Spec 2: SpeechLib applies preemphasis within each batch - y_frames_prev = np.roll(y_frames, 1, axis=1) - y_frames_prev[:, 0] = y_frames_prev[:, 1] - y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 + image_attention_mask: torch.Tensor + """Shape: `(batch_size * num_images, H_mask, W_mask)`""" - S = np.fft.rfft(fft_window * y_frames, n=n_fft, - axis=1).astype(np.complex64) - if fs == 8000: - # Need to pad the output to look like 16 kHz data but with zeros in - # the 4 to 8 kHz bins. - frames, bins = S.shape - padarray = np.zeros((frames, bins)) - S = np.concatenate((S[:, 0:-1], padarray), - axis=1) # Nyquist bin gets set to zero +class Phi4MMImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - spec = np.abs(S).astype(np.float32) - return spec + `hidden_size` must match the hidden size of language model backbone. + """ - def extract_features(self, wav, fs): - """Extract log filterbank features from waveform. 
- Args: - wav (1D array): waveform of the input - fs (int): sampling rate of the waveform, 16000 or 8000. - If fs=8000, the waveform will be resampled to 16000Hz. - Output: - log_fbank (2D array): a TxD matrix of log Mel filterbank features. - D=80, and T is the number of frames. - """ - spec = self.extract_spectrogram(wav, fs) - spec_power = spec**2 - fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) - log_fbank = np.log(fbank_power).astype(np.float32) +class Phi4MMAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size * num_audios, 80, M)""" - return log_fbank +class Phi4MMAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" -@lru_cache -def audio_feature_extractor() -> LogFbankProcessor: - # Creates an instance of the audio processor, needed to extract the - # the audio features from the sound file - # LRU cache ensures that we only make one copy - return LogFbankProcessor() +Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] +Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] -def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, - vit_patch_size, token_compression_factor): - """ - compute the number of tokens an image is expected to take up considering - the image encoder architecture and exclude output features containing - only padding pixels - for siglip, vit_image_size=448, vit_patch_size=14, so output will be - 32x32 feature map - NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 - """ - assert vit_image_size % vit_patch_size == 0, \ - "vit_image_size must be divisible by vit_patch_size" - assert vit_image_size // vit_patch_size % token_compression_factor == 0, \ - "vit_image_size // vit_patch_size must be divisible by "\ - "token_compression_factor" - - 
target_aspect_ratio, target_height, target_width = ( - _find_target_aspect_ratio(image, - vit_image_size, - dynamic_hd_size, - min_num=1)) - assert target_aspect_ratio[ - 0] * vit_image_size == target_width, \ - f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" - assert target_aspect_ratio[ - 1] * vit_image_size == target_height, \ - f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" - assert (target_height % vit_image_size == 0 - and target_width % vit_image_size == 0) - - padding_height, padding_width = _get_padding_size(image, target_height, - target_width) - assert padding_width == 0 or padding_height == 0, \ - "padding_width or padding_height must be 0" - - target_feat_width = target_width // vit_patch_size - target_feat_height = target_height // vit_patch_size - if padding_width >= vit_patch_size: - assert padding_height == 0, "padding_height not 0" - non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size) - non_pad_feat_height = target_feat_height - elif padding_height >= vit_patch_size: - assert padding_width == 0, "padding_width not 0" - non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size) - non_pad_feat_width = target_feat_width - else: - # small padding shorter than a vit patch - non_pad_feat_width = target_feat_width - non_pad_feat_height = target_feat_height - - feat_width = non_pad_feat_width // token_compression_factor - feat_height = non_pad_feat_height // token_compression_factor - # NOTE it's possible that the non-padding feature is not divisible - if non_pad_feat_width % token_compression_factor != 0: - feat_width += 1 - if non_pad_feat_height % token_compression_factor != 0: - feat_height += 1 - num_hd_patch_tokens = feat_width * feat_height - num_hd_newline_tokens = feat_height - vit_feature_size = vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // token_compression_factor)**2 - num_sep_tokens = 1 - 
num_global_image_newline_tokens = \ - vit_feature_size // token_compression_factor - - return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens + - num_hd_newline_tokens + num_global_image_newline_tokens) - - -def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]: +def cat_with_pad(tensors, dim, padding_value=0): """ - Compute the output size of the `extract_features` method. - - Args: - wav_length (int): Length of the input waveform in samples. - fs (int): Sampling rate of the waveform, either 16000 or 8000. - - Returns: - tuple (int, int): Output size as (T, D), where: - T: Number of time frames. - D: Number of Mel filterbank bins (80). + cat along dim, while pad to max for all other dims """ + ndim = tensors[0].dim() + assert all( + t.dim() == ndim for t in + tensors[1:]), "All tensors must have the same number of dimensions" - # Resample to 16000 or 8000 if needed - if fs > 16000: - wav_length //= fs // 16000 - fs = 16000 - elif 8000 <= fs < 16000: - # We'll resample to 16K from 8K - wav_length *= 2 - fs = 16000 - elif fs < 8000: - raise RuntimeError(f"Unsupported sample rate {fs}") - - # Spectrogram parameters for 16 kHz - win_length = 400 # Frame length in samples - hop_length = 160 # Frame shift in samples - mel_bins = 80 # Number of mel filterbank bins - - # Calculate number of frames (T) - T = (wav_length - win_length) // hop_length + 1 - if T < 1: - raise ValueError("Waveform too short for given parameters.") - - # Return time frames (T) and mel bins (D) - return T, mel_bins - + out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] + out_size[dim] = sum(t.shape[dim] for t in tensors) + output = tensors[0].new_full(out_size, padding_value) -def _get_audio_embed_sizes(audios, ctx: InputContext): - """ - Get the audio embedding sizes for each audio file. 
+ index = 0 + for t in tensors: + # Create a slice list where every dimension except dim is full slice + slices = [slice(0, t.shape[d]) for d in range(ndim)] + # Update only the concat dimension slice + slices[dim] = slice(index, index + t.shape[dim]) - Args: - audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of - waveform and sample rate. - ctx (InputContext): Input context. + output[slices] = t + index += t.shape[dim] - Returns: - List[int]: List of audio embedding sizes. - """ - audio_embed_sizes = [] - for audio in audios: - audio_data, sf = audio - audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf) - audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(), - audio_frames) - audio_embed_sizes.append(audio_embed_size) - return audio_embed_sizes + return output -def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): - """ - The following will search for `<|audio_{idx}|>` tokens and - return a mapping of audio placeholder tokens to audio placeholder token ids - based on the size of the audio embeddings. +class Phi4MMProcessingInfo(BaseProcessingInfo): - Args: - audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of - waveform and sample rate. - ctx (InputContext): Input context. - prompt_str (str): The prompt string. + def get_hf_processor( + self, + *, + dynamic_hd: Optional[int] = None, + **kwargs: object, + ) -> ProcessorMixin: + if dynamic_hd is not None: + kwargs["dynamic_hd"] = dynamic_hd - Returns: - Dict[str, List[int]]: Mapping of audio placeholder tokens to audio - placeholder token ids. 
+ return self.ctx.get_hf_processor(**kwargs) - """ - if len(audios) == 0: - return {} - - audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) - audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str) - audio_ids = [int(audio_id) for audio_id in audio_ids] - assert len(audio_ids) == len( - audio_embed_sizes - ), "Number of audio tokens and audio features do not match" - assert tuple(audio_ids) == tuple(range(1, - len(audio_ids) + - 1)), "Audio ids are not in order!" - audio_id_to_input_ids = { - f"<|audio_{audio_id}|>": - [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size - for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes) - } + @property + def image_tokens(self) -> list[str]: + return [f"<|image_{i+1}|>" for i in range(100)] - return audio_id_to_input_ids - - -def _count_image_tokens(images, ctx: InputContext): - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - token_compression_factor = prepro_config['token_compression_factor'] - - image_token_counts = [ - _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, - vit_patch_size, token_compression_factor) - for image in images - ] - return image_token_counts - - -def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): - if len(images) == 0: - return {} - - image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt) - image_ids = [int(image_id) for image_id in image_ids] - assert len(image_ids) == len( - set(image_ids)), "Duplicate image tokens in prompt" - assert len(images) == len( - image_ids), "Number of images and image tokens in prompt do not match" - - # NOTE the following assertion is not strictly necessary - assert tuple(image_ids) == tuple(range(1, - 
len(image_ids) + - 1)), "Image ids are not in order" - - image_token_counts = _count_image_tokens(images, ctx) - image_id_to_input_ids = { - f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens - for image_id, num_tokens in zip(image_ids, image_token_counts) - } - return image_id_to_input_ids + @property + def audio_tokens(self) -> list[str]: + return [f"<|audio_{i+1}|>" for i in range(100)] + def get_dynamic_hd( + self, + processor: Optional[ProcessorMixin] = None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + image_processor = processor.image_processor + return image_processor.dynamic_hd -def input_processor_for_phi4mm(ctx: InputContext, - inputs: DecoderOnlyInputs) -> TokenInputs: - """ - Implements the input processor, which transforms the input prompt ids - to include the audio placeholder token. This will become the `input_ids` - in `forward` for the model. + def get_feature_extractor(self) -> SequenceFeatureExtractor: + return self.get_hf_processor().audio_processor - Args: - ctx (InputContext): Input context. - inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) - to process. 
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None, "image": None} - Returns: - TokenInputs: Processed inputs - """ - multi_modal_data = inputs.get("multi_modal_data") - if (multi_modal_data is None or - ("audio" not in multi_modal_data and "image" not in multi_modal_data)): - # pure text input, so no need to do pre-processing - return inputs - - prompt_str = inputs.get("prompt") - prompt_token_ids = inputs.get("prompt_token_ids") - # for offline_inference, we will get str input and we parse MM special - # tokens from it - # (ignore prompt_token_ids) - # for OAI server, we will get prompt_token_ids, where MM special tokens - # are already parsed - - if 'audio' in multi_modal_data: - audios = multi_modal_data["audio"] - - if not isinstance(audios, list): - audios = [audios] - if prompt_str is not None: - audio_id_to_input_ids = _get_audio_id_to_input_ids( - audios, ctx, prompt_str=prompt_str) - audio_embed_sizes = [] - elif prompt_token_ids is not None: - audio_id_to_input_ids = {} - audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) - else: - audio_id_to_input_ids = {} - audio_embed_sizes = [] - - if 'image' in multi_modal_data: - # PIL Image or list of PIL Images - images = multi_modal_data["image"] - if not isinstance(images, list): - images = [images] - if prompt_str is not None: - image_id_to_input_ids = _get_image_id_to_input_ids( - images, prompt_str, ctx) - image_token_counts = [] - elif prompt_token_ids is not None: - image_id_to_input_ids = {} - image_token_counts = _count_image_tokens(images, ctx) - else: - image_id_to_input_ids = {} - image_token_counts = [] - - # Handle the case where the prompt is a string and we need to manually - # tokenize it. - # In this case, the `audio_id_to_input_ids` dict will be mapping from - # an audio placeholder - # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the - # given audio length. 
- if prompt_str: - pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)" - prompt_chunk_strings = re.split(pattern, prompt_str) - prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""] - - # Create the new input_ids with the placeholder image and audio - # tokens inserted - tokenizer = cached_tokenizer_from_config(ctx.model_config) - input_ids = [] - has_imag, has_audio, has_user_text_input = False, False, False - for prompt_chunk_string in prompt_chunk_strings: - if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string): - input_ids.extend(image_id_to_input_ids[prompt_chunk_string]) - has_imag = True - elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string): - input_ids.extend(audio_id_to_input_ids[prompt_chunk_string]) - has_audio = True - else: - curr_token_ids = tokenizer(prompt_chunk_string).input_ids - if not has_user_text_input: - for token_id in curr_token_ids: - if token_id not in NON_USER_INPUT_TOKENS: - has_user_text_input = True - break - input_ids.extend(curr_token_ids) - if has_audio and has_imag and has_user_text_input: - raise ValueError( - "Phi4MMForCausalLM does not support text + audio + image" + - " inputs in the same prompt") - # Handle the case where the prompt is already tokenized - else: - assert prompt_token_ids is not None, \ - "If string prompt isn't provided, prompt_token_ids must be" - - i = 0 - input_ids = prompt_token_ids - # only needed for later assertion - img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0 - image_token_count_iter = iter(image_token_counts) - audio_embed_size_iter = iter(audio_embed_sizes) - while i < len(input_ids): - token_id = input_ids[i] - if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID: - token_count = next(audio_embed_size_iter) - audio_cnt += 1 - elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID: - token_count = next(image_token_count_iter) - img_cnt += 1 - else: - user_text_input_cnt += 1 if token_id not in \ - NON_USER_INPUT_TOKENS else 0 - i += 1 - continue - tokens = [token_id] * token_count - input_ids = 
input_ids[:i] + tokens + input_ids[i + 1:] - i += token_count - - if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0: - raise ValueError( - "Phi4MMForCausalLM does not support text + audio + image" + - " inputs in the same prompt") - # If the below assertion fails, it might be that input pure-text - # messages contain image/audio special tokens literally - # (<|endoftext10|>, <|endoftext11|>). - assert (img_cnt == len(image_token_counts)), ( - f"Number of image tokens in prompt_token_ids ({img_cnt}) " - f"does not match number of images ({len(image_token_counts)})") - assert (audio_cnt == len(audio_embed_sizes)), ( - f"Number of audio tokens in prompt_token_ids ({audio_cnt}) " - f"does not match number of audios ({len(audio_embed_sizes)})") - - # NOTE: Create a defensive copy of the original inputs - return token_inputs( - prompt_token_ids=input_ids, - prompt=prompt_str, - multi_modal_data=multi_modal_data, - ) + def _find_target_aspect_ratio( + self, + orig_width: int, + orig_height: int, + image_size: int, + max_num: int, + min_num: int, + ): + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + image_processor = self.get_hf_processor().image_processor + target_aspect_ratio = image_processor.find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + orig_width, + orig_height, + image_size, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + else: + target_width = image_size * w_crop_num + target_height = 
image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + return target_aspect_ratio, target_height, target_width + def _compute_num_image_tokens( + self, + orig_width: int, + orig_height: int, + dynamic_hd_size: int, + vit_image_size: int, + vit_patch_size: int, + token_compression_factor: int = 2, + ): + """ + compute the number of tokens an image is expected to take up considering + the image encoder architecture and exclude output features containing + only padding pixels -def _compute_audio_embed_size(hf_config, audio_frames): - """ - Compute the audio embedding size based on the audio frames and - compression rate. - """ - compression_rate = hf_config.embd_layer['audio_embd_layer'][ - 'compression_rate'] - # NOTE: this is a hard-coded value but might be configurable in the future - qformer_compression_rate = 1 - integer = audio_frames // compression_rate - remainder = audio_frames % compression_rate + for siglip, vit_image_size=448, vit_patch_size=14, so output will be + 32x32 feature map + NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 + """ + assert vit_image_size % vit_patch_size == 0, ( + "vit_image_size must be divisible by vit_patch_size") + assert (vit_image_size // vit_patch_size % + token_compression_factor == 0), ( + "vit_image_size // vit_patch_size must be divisible by " + "token_compression_factor") + + target_aspect_ratio, target_height, target_width = ( + self._find_target_aspect_ratio(orig_width, + orig_height, + vit_image_size, + dynamic_hd_size, + min_num=1)) + assert target_aspect_ratio[0] * vit_image_size == target_width, ( + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}") + assert target_aspect_ratio[1] * vit_image_size == target_height, ( + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}") + assert (target_height % vit_image_size == 0 + and target_width % vit_image_size == 0) + + padding_height, padding_width = _get_padding_size( + orig_width, orig_height, 
target_height, target_width) + assert padding_width == 0 or padding_height == 0, \ + "padding_width or padding_height must be 0" + + target_feat_width = target_width // vit_patch_size + target_feat_height = target_height // vit_patch_size + if padding_width >= vit_patch_size: + assert padding_height == 0, "padding_height not 0" + non_pad_feat_width = target_feat_width - math.floor( + padding_width / vit_patch_size) + non_pad_feat_height = target_feat_height + elif padding_height >= vit_patch_size: + assert padding_width == 0, "padding_width not 0" + non_pad_feat_height = target_feat_height - math.floor( + padding_height / vit_patch_size) + non_pad_feat_width = target_feat_width + else: + # small padding shorter than a vit patch + non_pad_feat_width = target_feat_width + non_pad_feat_height = target_feat_height + + feat_width = non_pad_feat_width // token_compression_factor + feat_height = non_pad_feat_height // token_compression_factor + # NOTE it's possible that the non-padding feature is not divisible + if non_pad_feat_width % token_compression_factor != 0: + feat_width += 1 + if non_pad_feat_height % token_compression_factor != 0: + feat_height += 1 + num_hd_patch_tokens = feat_width * feat_height + num_hd_newline_tokens = feat_height + vit_feature_size = vit_image_size // vit_patch_size + num_global_image_tokens = (vit_feature_size // + token_compression_factor)**2 + num_sep_tokens = 1 + num_global_image_newline_tokens = \ + vit_feature_size // token_compression_factor + + return (num_global_image_tokens + num_sep_tokens + + num_hd_patch_tokens + num_hd_newline_tokens + + num_global_image_newline_tokens) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[ProcessorMixin] = None, + ) -> int: + hf_config = self.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ + 
vision_encoder_name] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + token_compression_factor = prepro_config['token_compression_factor'] + + dynamic_hd_size = self.get_dynamic_hd(processor=processor) + + image_num_tokens = self._compute_num_image_tokens( + image_width, + image_height, + dynamic_hd_size=dynamic_hd_size, + vit_image_size=vit_image_size, + vit_patch_size=vit_patch_size, + token_compression_factor=token_compression_factor, + ) - result = integer if remainder == 0 else integer + 1 + return image_num_tokens - integer = result // qformer_compression_rate - remainder = result % qformer_compression_rate - result = integer if remainder == 0 else integer + 1 # qformer compression + def get_image_size_with_most_features( + self, + processor: Optional[ProcessorMixin] = None, + ) -> ImageSize: + hf_config = self.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ + vision_encoder_name] + vit_image_size = prepro_config['vit_image_size'] + + max_side = vit_image_size * self.get_dynamic_hd(processor=processor) + return ImageSize(height=max_side, width=vit_image_size) + + def get_audio_num_frames(self, audio_len: int, sr: float) -> int: + """ + Compute the output size of the `extract_features` method. - return result + Args: + audio_len (int): Length of the input waveform in samples. + sr (float): Sampling rate of the waveform, either 16000 or 8000. + Returns: + tuple (int, int): Output size as (T, D), where: + T: Number of time frames. + D: Number of Mel filterbank bins (80). 
+ """ -def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: - return 10000 + # Resample to 16000 or 8000 if needed + if sr > 16000: + audio_len //= sr // 16000 + elif 8000 <= sr < 16000: + # We'll resample to 16K from 8K + audio_len *= 2 + elif sr < 8000: + raise RuntimeError(f"Unsupported sample rate {sr}") + + # Spectrogram parameters for 16 kHz + win_length = 400 # Frame length in samples + hop_length = 160 # Frame shift in samples + + # Calculate number of frames (T) + num_frames = (audio_len - win_length) // hop_length + 1 + if num_frames < 1: + raise ValueError("Waveform too short for given parameters.") + + # Return time frames (T) + return num_frames + + def _compute_audio_embed_size(self, audio_frames: int) -> int: + """ + Compute the audio embedding size based on the audio frames and + compression rate. + """ + hf_config = self.get_hf_config() + compression_rate = hf_config.embd_layer['audio_embd_layer'][ + 'compression_rate'] + # NOTE: this is a hard-coded value but might be configurable + # in the future + qformer_compression_rate = 1 + integer = audio_frames // compression_rate + remainder = audio_frames % compression_rate + result = integer if remainder == 0 else integer + 1 -def dummy_audio_for_phi4mm(audio_count: int) -> dict: - """ - Create dummy audio data for the Phi4MM model, which is used for profiling. + integer = result // qformer_compression_rate + remainder = result % qformer_compression_rate + # qformer compression + result = integer if remainder == 0 else integer + 1 - Args: - audio_count (int): Number of audio samples. + return result - Returns: - dict: Dummy audio data. 
- """ - dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0) - return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count +class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): -def dummy_image_for_phi4mm(width: int, height: int): - image = Image.new('RGB', (width, height), color='black') - return image + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) + image_tokens: list[str] = self.info.image_tokens[:num_images] + audio_tokens: list[str] = self.info.audio_tokens[:num_audios] -def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]) -> DummyData: - """ - Create dummy sequence (input_ids) and audio data for the Phi4MM model, - which is used for profiling. + return "".join(image_tokens + audio_tokens) - In this case, the sequence data is a bunch of 0s with a number of audio - tokens that correspond to the audio embed size of the - _AUDIO_MAX_SOUNDFILE_SIZE. + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) - Args: - ctx (InputContext): Input context. - seq_len (int): Length of the sequence. - mm_counts (Mapping[str, int]): Multi-modal counts. + target_width, target_height = \ + self.info.get_image_size_with_most_features() - Returns: - Tuple: Dummy sequence data and dummy audio data. 
- """ - audio_count = mm_counts["audio"] - audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE, - DUMMY_SAMPLING_FREQUENCY) - audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(), - audio_frames) - - image_count = mm_counts["image"] - dummy_image = get_max_dummy_image(ctx) - max_image_tokens = get_max_phi4mm_image_tokens(ctx) - total_image_tokens = image_count * max_image_tokens - - if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: - raise RuntimeError( - f"Phi4MM cannot process {audio_count} audios and {image_count}" - f"images in a prompt, please increase max_model_len to be at" - f" larger than " - f"{audio_feature_size * audio_count + total_image_tokens}" - " or reduce audio/image limit by --limit-mm-per-prompt.") - - if audio_feature_size * audio_count > total_image_tokens: - seq_data = SequenceData.from_prompt_token_counts( - (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count), - (0, seq_len - audio_feature_size * audio_count), - ) mm_data = { - "audio": dummy_audio_for_phi4mm(audio_count), + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "audio": + self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, + num_audios=num_audios), } - else: - seq_data = SequenceData.from_prompt_token_counts( - (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens), - (0, seq_len - total_image_tokens), - ) - mm_data = { - "image": [dummy_image] * image_count, - } - return DummyData(seq_data, mm_data) + return mm_data -def input_mapper_for_phi4mm_audio(ctx: InputContext, - data: object) -> MultiModalKwargs: - """ - This function is used to create the MultiModalKwargs for the Phi4MM - (audio) model. - Specifically, for audio, we extract the audio features from the sound - file and create pairs of audio features and audio embed lengths (the - latter of which is used to repeat the audio placeholder token in the - input prompt IDs). 
- These pairs are used, downstream, in `_audio_features_to_embeddings` - (via `_process_audio_input`). - - Note that the incoming audio data (each entry in `data`) is a tuple of - the audio data and the sampling frequency (e.g. from soundfile.read). - - Args: - ctx (InputContext): Input context. - data (object): Audio data. - - Returns: - MultiModalKwargs: Multi-modal inputs. - """ - if not isinstance(data, list): - data = [data] - - if len(data) == 0: - return MultiModalKwargs() - - audio_features = [] - for audio_input in data: - if not isinstance(audio_input, tuple): - raise NotImplementedError( - f"Unsupported data type: {type(audio_input)}") - - audio, sf = audio_input - feature_extractor = audio_feature_extractor() - single_audio_features = feature_extractor.extract_features(audio, sf) - feat_stride = (1 if not hasattr(feature_extractor, "stride") else - feature_extractor.stride) - audio_frames = len(single_audio_features) * feat_stride - single_audio_embed_size = _compute_audio_embed_size( - ctx.get_hf_config(), audio_frames) - single_audio_feature_audio_len_pair = ( - single_audio_features, - [single_audio_embed_size], - ) - audio_features.append(single_audio_feature_audio_len_pair) - return MultiModalKwargs({"audio_features": audio_features}) - - -def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): - if not isinstance(data, list): - data = [data] - # data: list of PIL images - if len(data) == 0: - return MultiModalKwargs() - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - - image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size, - vit_patch_size) - return MultiModalKwargs({ - "pixel_values": - 
image_input_dict["pixel_values"], - "image_sizes": - image_input_dict["image_sizes"], - "image_attention_mask": - image_input_dict["image_attention_mask"], - "num_img_tokens": - image_input_dict["num_img_tokens"], - }) +class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): -def cat_with_pad(tensors, dim, padding_value=0): - """ - cat along dim, while pad to max for all other dims - """ - ndim = tensors[0].dim() - assert all( - t.dim() == ndim for t in - tensors[1:]), "All tensors must have the same number of dimensions" + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate, + audio_resample_method="scipy") - out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] - out_size[dim] = sum(t.shape[dim] for t in tensors) - output = tensors[0].new_full(out_size, padding_value) + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + sr = self.info.get_feature_extractor().sampling_rate + if (audio_data := mm_data.get("audios", [])): + mm_data['audios'] = [(data, sr) for data in audio_data] + + processed_outputs = super()._call_hf_processor(prompt, mm_data, + mm_kwargs) + + num_img_tokens = [ + self.info.get_num_image_tokens(image_width=img_size[0], + image_height=img_size[1]) + for img_size in processed_outputs["image_sizes"] + ] + processed_outputs["num_img_tokens"] = num_img_tokens - index = 0 - for t in tensors: - # Create a slice list where every dimension except dim is full slice - slices = [slice(0, t.shape[d]) for d in range(ndim)] - # Update only the concat dimension slice - slices[dim] = slice(index, index + 
t.shape[dim]) + audio_features = processed_outputs['input_audio_embeds'] + feature_sizes = [ + self.info.get_audio_num_frames(len(audio), sr) + for audio in audio_data + ] + processed_outputs['input_audio_embeds'] = [ + audio_features[idx, :size] + for idx, size in enumerate(feature_sizes) + ] - output[slices] = t - index += t.shape[dim] + return processed_outputs - return output + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_image_embeds=MultiModalFieldConfig.batched("image"), + image_attention_mask=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + num_img_tokens=MultiModalFieldConfig.batched("image"), + input_audio_embeds=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + image_tokens: list[str] = self.info.image_tokens # type: ignore + audio_tokens: list[str] = self.info.audio_tokens # type: ignore + feature_extractor = self.info.get_feature_extractor() + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + def get_image_replacement_phi4mm(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens + + return image_tokens + + def get_audio_replacement_phi4mm(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + # TODO(Isotr0py): support embedding inputs + 
audio_len = audios.get_audio_length(item_idx) + audio_frames = self.info.get_audio_num_frames( + audio_len, feature_extractor.sampling_rate) + audio_embed_size = self.info._compute_audio_embed_size( + audio_frames) + + audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + + return audio_tokens + + num_images = mm_items.get_count("image", strict=False) + num_audios = mm_items.get_count("audio", strict=False) + + image_repl = [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_image_replacement_phi4mm, + ) for image_token in image_tokens[:num_images] + ] + audio_repl = [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_audio_replacement_phi4mm, + ) for audio_token in audio_tokens[:num_audios] + ] + return image_repl + audio_repl -@MULTIMODAL_REGISTRY.register_input_mapper("audio", - input_mapper_for_phi4mm_audio) -@MULTIMODAL_REGISTRY.register_input_mapper("image", - input_mapper_for_phi4mm_image) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_phi4mm_audio_tokens) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "image", get_max_phi4mm_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm) -@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm) -class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, - SupportsV0Only): +@MULTIMODAL_REGISTRY.register_processor( + Phi4MMMultiModalProcessor, + info=Phi4MMProcessingInfo, + dummy_inputs=Phi4MMDummyInputsBuilder, +) +class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ Implements the Phi-4-multimodal-instruct model in vLLM. 
""" @@ -1518,48 +967,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() - - def _audio_features_to_embeddings( - self, - input_ids: torch.Tensor, - input_features: List[torch.Tensor], - audio_input_sizes: torch.Tensor, - audio_projection_mode: str, - ) -> torch.Tensor: - """ - Convert audio features to embeddings, which are used as input to the - model (via `inputs_embeds`). - - Args: - input_ids (torch.Tensor): Input IDs (the prompt in this case). - input_features (list[torch.Tensor]): Input features (the audio - embeddings). - audio_input_sizes (list[torch.Tensor]): Audio input sizes (the - audio embed lengths to use for padding the audio placeholder token - in the input prompt IDs). - """ - # The audio projection can either be a single linear or Sequential, - # so handle both cases - if isinstance(self.embed_tokens_extend.audio_projection, - nn.Sequential): - target_dtype = self.embed_tokens_extend.audio_projection[ - 0].bias.dtype - else: - target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype - - audio_input = [ - input.unsqueeze(0).to(target_dtype) for input in input_features - ] - kwargs = { - "wte": self.model.embed_tokens, - 'audio_projection_mode': audio_projection_mode - } - audio_embeddings = self.embed_tokens_extend(input_ids, audio_input, - audio_input_sizes, - **kwargs) - audio_embeddings = audio_embeddings.to(target_dtype) - return audio_embeddings def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: @@ -1574,7 +981,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, Returns: Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. 
""" - audio_features = kwargs.pop("audio_features", None) + audio_features = kwargs.pop("input_audio_embeds", None) audio_embeds = kwargs.pop("audio_embeds", None) if audio_features is None and audio_embeds is None: @@ -1586,7 +993,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, f"Got type: {type(audio_features)}") return Phi4MMAudioFeatureInputs(type="audio_features", - data=audio_features) + data=flatten_bn(audio_features)) if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): @@ -1598,8 +1005,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, raise AssertionError("This line should be unreachable.") - def _process_audio_input(self, input_ids: torch.Tensor, - audio_input: Phi4MMAudioInputs, + def _process_audio_input(self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input @@ -1607,8 +1013,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, created by `input_mapper_for_phi4mm_audio`. Args: - input_ids (torch.Tensor): Input IDs (the prompt in this case, - before the audio token replication). audio_input (Phi4MMAudioInputs): Audio input. Returns: @@ -1620,21 +1024,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, audio_features = audio_input["data"] # (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. 
multiple audios in the same example) - audio_feature = [i[0] for j in audio_features for i in j] - audio_feature_len = [i[1].item() for j in audio_features for i in j] - # Add the batch dim via `squeeze` - return self._audio_features_to_embeddings( - input_ids.unsqueeze(0), - audio_feature, - audio_feature_len, - audio_projection_mode, - ).squeeze(0) + dtype = next(self.embed_tokens_extend.parameters()).dtype + audio_embeds = [ + self.embed_tokens_extend( + features.to(dtype), + audio_projection_mode=audio_projection_mode, + ) for features in audio_features + ] + return audio_embeds def _parse_and_validate_image_input(self, **kwargs: object) -> Optional[Dict]: - pixel_values: Optional[Dict] = kwargs.get("pixel_values") - if pixel_values is None: + input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") + if input_image_embeds is None: return None image_sizes = kwargs.get("image_sizes") @@ -1643,23 +1046,24 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, assert image_sizes is not None and image_attention_mask is not None\ and num_img_tokens is not None, "Missing image inputs" - if isinstance(pixel_values, list): - assert pixel_values[0].dim() == 5, "Incorrect image inputs" + if is_list_of(input_image_embeds, torch.Tensor): + assert all(p.dim() == 5 + for p in input_image_embeds), "Incorrect image inputs" # list len is batch_size. # each tensor has dimension: num_img_per_example, num_hd_patches, # channels, height, width. # need to pad along num_hd_patches. # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. - pixel_values = cat_with_pad(pixel_values, dim=0) - elif isinstance(pixel_values, torch.Tensor): + input_image_embeds = cat_with_pad(input_image_embeds, dim=0) + elif isinstance(input_image_embeds, torch.Tensor): # dimension: batch_size, num_img_per_example, num_hd_patches, # channels, height, width. # we flatten first 2 dims to make it a single large batch for # SigLIP Encoder. 
- assert pixel_values.dim() == 6, "Incorrect image inputs" - pixel_values = pixel_values.flatten(0, 1) + assert input_image_embeds.dim() == 6, "Incorrect image inputs" + input_image_embeds = input_image_embeds.flatten(0, 1) else: - raise ValueError("Incorrect pixel_values inputs") + raise ValueError("Incorrect input_image_embeds inputs") if isinstance(image_attention_mask, list): image_attention_mask = cat_with_pad(image_attention_mask, dim=0) @@ -1685,80 +1089,140 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, else: raise ValueError("Incorrect image_attention_mask inputs") - return { - 'pixel_values': pixel_values, - 'image_sizes': image_sizes, - 'image_attention_mask': image_attention_mask, - 'num_img_tokens': num_img_tokens, - } + return Phi4MMImagePixelInputs( + type="pixel_values", + data=input_image_embeds, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + num_img_tokens=num_img_tokens, + ) - def merge_image_features_to_inputs_embeds( + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if input_key in ("input_image_embeds", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("input_audio_embeds", + "audio_embeds") and "audios" not in modalities: + modalities["audios"] = self._parse_and_validate_audio_input( + **kwargs) + + return modalities + + def _process_image_input( + self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]: + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + dtype = next(self.vision_encoder.parameters()).dtype + pixel_values = image_input['data'].to(dtype) + image_sizes = image_input['image_sizes'] + image_attention_mask = image_input['image_attention_mask'] + image_embeds = self.vision_encoder(pixel_values, image_sizes, + image_attention_mask) + return image_embeds + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ audio_projection_mode = 'speech' + for modality in modalities: + # make sure process images first + if modality == "images": + audio_projection_mode = "vision" + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(vision_embeddings) + if modality == "audios": + audio_input = modalities["audios"] + audio_embeddings = self._process_audio_input( + audio_input, audio_projection_mode=audio_projection_mode) + multimodal_embeddings += tuple(audio_embeddings) + + return multimodal_embeddings + + def get_input_embeddings( self, input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - image_set_tensors: List[torch.Tensor], - ): - position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero( - as_tuple=True) - - assert all([t.shape[0] == 1 for t in image_set_tensors - ]), 'img_set_tensor should have shape (1, N_tokens, C)' - # Shape: (merged_N_tokens, C) - image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0) - image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to( - inputs_embeds.device) - merged_embeds = inputs_embeds.index_put( - indices=position_tuple, - values=image_set_tensor, - accumulate=False, - ) - return merged_embeds + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.embed_tokens(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Phi4MMImagePixelInputs] = None, + audio_input: Optional[Phi4MMAudioFeatureInputs] = None, + ) -> torch.Tensor: + audio_projection_mode = 'speech' + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds 
= merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID, + ) + audio_projection_mode = 'vision' + + if audio_input is not None: + audio_embeds = self._process_audio_input( + audio_input, audio_projection_mode=audio_projection_mode) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + audio_embeds, + placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID, + ) + return inputs_embeds def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> torch.Tensor: if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - # Each entry in this is a pair of audio_features and audio_embed - # lengths + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. 
+ elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs) - image_inputs = self._parse_and_validate_image_input(**kwargs) - - has_audio = audio_input is not None - has_image = image_inputs is not None - - if has_audio: - audio_projection_mode = 'vision' if has_image else 'speech' - inputs_embeds = self._process_audio_input( - input_ids, audio_input, audio_projection_mode) - - if has_image: - dtype = self.vision_encoder.img_processor.embeddings.\ - patch_embedding.weight.dtype - pixel_values = image_inputs['pixel_values'].to(dtype) - image_sizes = image_inputs['image_sizes'] - image_attention_mask = image_inputs['image_attention_mask'] - image_set_tensors = self.vision_encoder( - pixel_values, image_sizes, image_attention_mask) - if not has_audio: - inputs_embeds = self.model.embed_tokens(input_ids) - - inputs_embeds = self.merge_image_features_to_inputs_embeds( - input_ids, inputs_embeds, image_set_tensors) - - if has_image or has_audio: - # multi-modal input, we have set inputs_embeds properly in - # previous steps - input_ids = None - else: - # text-only, we keep using original input_ids + + if image_input is None and audio_input is None: inputs_embeds = None + else: + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + audio_input=audio_input) + input_ids = None hidden_states = self.model( input_ids, @@ -1778,14 +1242,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: weights = ((name, data) for name, data in weights diff --git a/vllm/model_executor/models/phi4mm_audio.py 
b/vllm/model_executor/models/phi4mm_audio.py index db90848f98099dfc7a3cf5f566368984cba3f14a..34a7a73d057aee1297ed41226feecdd68cd7f523 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -1159,8 +1159,11 @@ class AudioEmbedding(nn.Module): input_embeds: torch.FloatTensor, audio_attention_mask: torch.Tensor = None, audio_projection_mode: str = "speech", - ): - + ) -> torch.FloatTensor: + """ + arguments: + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ if self.freeze_audio_processor: with torch.no_grad(): audio_features, masks = self.encoder(input_embeds, @@ -1210,62 +1213,20 @@ class AudioEmbedding(nn.Module): def forward( self, - input_ids: torch.LongTensor, - input_embeds: torch.FloatTensor, - audio_embed_sizes, - **kwargs, + audio_features: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", ) -> torch.FloatTensor: """ arguments: - input_ids: input text ids (B, U) - input_embeds: audio features (B, T, D) B: num audios in a sequence + audio_features: audio features (T, D) + + returns: + audio_embeds: audio embeddings (num_audio_tokens, hidden_dim) """ - assert input_embeds is not None and len(input_embeds) == len( - audio_embed_sizes) - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - with torch.no_grad(): - positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero( - as_tuple=False) - - if not isinstance(input_embeds, list): - input_embeds = [input_embeds] - - audio_projection_mode = kwargs.get("audio_projection_mode", "speech") - audio_set_tensor = [ - self.get_audio_features( - input_embed, audio_projection_mode=audio_projection_mode) - for input_embed in input_embeds - ] - - with torch.no_grad(): - input_ids.clamp_min_(0).clamp_max_(self.vocab_size) - - if "wte" in kwargs: - # we use the token embedding layer from the huggingface model, this - # is REQUIRED to make sure we are using 
the loaded weights. - hidden_states = kwargs["wte"](input_ids) - else: - # otherwise, we use token embedding in pretrained mixformer from - # phi team - hidden_states = self.wte(input_ids) - - if len(positions.tolist()) > 0: - assert sum(audio_embed_sizes) == len( - positions - ), "please ensure the encoder outputs have the same length as"\ - " defined in input_ids!" - idx = 0 - for i in range(len(audio_embed_sizes)): - cnt = audio_embed_sizes[i] - assert audio_set_tensor[i].shape[0] == 1 - hidden_states[ - positions[idx, 0], - positions[idx, 1]:positions[idx, 1] + cnt, - ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to( - hidden_states.dtype).to(hidden_states.device)) - idx += cnt - - return hidden_states + audio_embeds = self.get_audio_features( + audio_features.unsqueeze(0), + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + return audio_embeds.squeeze(0) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 381a33d98b9cb39a530932536e51d5acb93068a8..2dc55e4c352e3b5e16f9c97596fa0aadc6de14c0 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -40,7 +40,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -634,7 +633,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ 
-659,14 +657,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 38e140a91ecf5f2e7dee681de631ec1082a98b6e..73fd80146955ee136a962f5ece9d62d5cf019047 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -28,7 +28,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs @@ -331,13 +330,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[PixtralImagePixelInputs]: images = kwargs.pop("images", None) @@ -441,13 +433,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, 
sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]): @@ -926,9 +911,8 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): return self.vision_config.image_size def get_patch_size(self) -> int: - spatial_merge_size = getattr(self.vision_config, "spatial_merge_size", - 1) - return (self.vision_config.patch_size * spatial_merge_size) + return (self.vision_config.patch_size * + self.vision_config.spatial_merge_size) def get_patch_grid_length(self) -> int: image_size, patch_size = self.get_image_size(), self.get_patch_size() diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py new file mode 100644 index 0000000000000000000000000000000000000000..790c48ccd216656a14d4849799cf32f29fbf3170 --- /dev/null +++ b/vllm/model_executor/models/plamo2.py @@ -0,0 +1,736 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only PLaMo2 model.""" +import math +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from 
vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + SupportsV0Only) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + + +# Only used for type hinting. +class Plamo2Config(PretrainedConfig): # type: ignore + model_type: str = "plamo2" + + hidden_size: int + num_hidden_layers: int + rms_norm_eps: float + # Attention + num_attention_heads: int + hidden_size_per_head: int + num_key_value_heads: int + # Mamba + mamba_d_state: int + mamba_d_conv: int + mamba_num_heads: int + mamba_step: int + # MLP + intermediate_size: int + # Tokenizer + vocab_size: int + + +class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore + + def _init_weights(self, module: torch.nn.Module) -> None: + std = 0.02 + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +def get_initial_dt_bias(num_heads: int) -> torch.Tensor: + dt_min = 0.001 + dt_max = 0.1 + dt = torch.exp( + torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min)) + dt = torch.clamp(dt, 1e-4) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + return inv_dt + + +def is_mamba(config: 
Plamo2Config, i: int) -> bool: + assert config.mamba_step > 1 + + if config.num_hidden_layers <= (config.mamba_step // 2): + # use attention in last layer + return i != config.num_hidden_layers - 1 + return (i % config.mamba_step) != (config.mamba_step // 2) + + +# TODO(Shinichi): Replace this with RMSNorm. +def _rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, + eps: float) -> torch.Tensor: + input_shape = hidden_states.shape + hidden_states = hidden_states.reshape(input_shape[:-1] + weight.shape) + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) + hidden_states = hidden_states.to(input_dtype) + hidden_states = weight * hidden_states + return hidden_states.reshape(input_shape) + + +def _swiglu(h: torch.Tensor) -> torch.Tensor: + h0, h1 = h.chunk(2, dim=-1) + return torch.nn.functional.silu(h0) * h1 + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class Plamo2MambaMixer(nn.Module): + # TODO(Shinichi): Rebase on Mamba2 implementation. + + def __init__(self, + config: Plamo2Config, + cache_config: CacheConfig, + quant_config: QuantizationConfig, + max_model_len: int, + prefix: str = "", + **kwargs) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = (config.mamba_num_heads * + config.hidden_size_per_head) + self.hidden_size_per_head = config.hidden_size_per_head + self.num_heads = config.mamba_num_heads + self.time_step_rank = max(64, self.hidden_size // 16) + self.use_conv_bias = False + self.use_bias = False + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias, + prefix=f"{prefix}.in_proj", + ) + # selective projection used to make dt, B and C input dependent + self.bcdt_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + prefix=f"{prefix}.bcdt_proj", + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear( + self.time_step_rank, + self.num_heads, + bias=False, + prefix=f"{prefix}.dt_proj", + ) + self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads)) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + prefix=f"{prefix}.out_proj", + ) + # The activation function is fixed to SiLU. 
+ self.activation = "silu" + + self.dt_norm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) + self.B_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.C_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams, + **kwargs, + ) -> torch.Tensor: + + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0] + # Reshaping the projected states as in modeling_plamo.py. + length = len(hidden_states) + projected_states = projected_states.reshape(length, self.num_heads, -1) + gate, hidden_states = torch.split( + projected_states, + [self.hidden_size_per_head, self.hidden_size_per_head], + dim=-1) + hidden_states = hidden_states.reshape(length, -1).transpose(0, 1) + gate = gate.reshape(length, -1).transpose(0, 1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + 
conv_state_indices=mamba_cache_params.state_indices_tensor) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1))[0] + + # Splitting the ssm_parameters as in modeling_plamo.py. + B, C, time_step = torch.split( + ssm_parameters, + [self.ssm_state_size, self.ssm_state_size, self.time_step_rank], + dim=-1, + ) + time_step = self.dt_norm(time_step.contiguous()) + B = self.B_norm(B.contiguous()) + C = self.C_norm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_bias.float() if hasattr( + self.dt_proj, "bias") else None) + + # Broadcasting as in modeling_plamo.py. + discrete_time_step = discrete_time_step.transpose( + 0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head) + discrete_time_step = discrete_time_step.reshape( + -1, self.intermediate_size).transpose(0, 1) + time_proj_bias = time_proj_bias[..., + None].expand(-1, + self.hidden_size_per_head) + time_proj_bias = time_proj_bias.reshape(self.intermediate_size) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + mamba_cache_params.ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + 
state_batch_indices=mamba_cache_params.state_indices_tensor) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + return contextualized_states + + +class DenseMLP(nn.Module): + + def __init__( + self, + config: Plamo2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, [self.intermediate_size] * 2, + bias=False, + prefix=f"{prefix}.gate_up_proj", + quant_config=quant_config) + self.down_proj = RowParallelLinear(self.intermediate_size, + self.hidden_size, + bias=False, + prefix=f"{prefix}.down_proj", + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + h = self.gate_up_proj(hidden_states)[0] + h = _swiglu(h) + output, _ = self.down_proj(h) + return output # type: ignore + + +class Plamo2AttentionMixer(nn.Module): + + def __init__(self, + config: Plamo2Config, + cache_config: CacheConfig, + quant_config: QuantizationConfig, + max_model_len: int | None = None, + prefix: str = "", + **kwargs) -> None: + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size_per_head + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.rope_theta = config.rope_theta if hasattr(config, + "rope_theta") else 10000 + self.rope_scaling = config.rope_scaling if hasattr( + config, "rope_scaling") else None + + assert max_model_len is not None, "max_model_len must be provided" + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_model_len, + base=self.rope_theta, + rope_scaling=self.rope_scaling, + ) + self.q_weight = torch.nn.Parameter( + torch.ones((self.num_heads, config.hidden_size_per_head))) + self.k_weight = torch.nn.Parameter( + torch.ones((self.num_kv_heads, config.hidden_size_per_head))) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = _rms_norm(q, self.q_weight, 1e-6) + k = _rms_norm(k, self.k_weight, 1e-6) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Plamo2DecoderLayer(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + layer_idx: int, + max_model_len: 
int | None = None, + prefix: str = "", + **kwargs) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + max_model_len = vllm_config.scheduler_config.max_model_len + + self.is_mamba = is_mamba(config, layer_idx) + if self.is_mamba: + self.mixer = Plamo2MambaMixer(config=config, + cache_config=cache_config, + quant_config=quant_config, + max_model_len=max_model_len, + prefix=f"{prefix}.mixer") + else: + self.mixer = Plamo2AttentionMixer(config=config, + cache_config=cache_config, + quant_config=quant_config, + max_model_len=max_model_len, + prefix=f"{prefix}.mixer") + + self.mlp = DenseMLP(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.pre_mixer_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mixer_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.pre_mixer_norm(hidden_states) + else: + hidden_states, residual = self.pre_mixer_norm( + hidden_states, residual) + + hidden_states = self.mixer(positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=mamba_cache_params) + hidden_states = self.post_mixer_norm(hidden_states) + # Fully Connected + hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_norm(hidden_states) + return hidden_states, residual + + +class Plamo2Decoder(torch.nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + 
super().__init__() + num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers + + self.layers = nn.ModuleList([ + Plamo2DecoderLayer(vllm_config=vllm_config, + layer_idx=i, + prefix=f"{prefix}.layers.{i}") + for i in range(num_hidden_layers) + ]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + ) -> torch.Tensor: + mamba_cache_index = 0 + for layer in self.layers: + layer_mamba_cache_params = None + if layer.is_mamba: + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + mamba_cache_index) + mamba_cache_index += 1 + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=layer_mamba_cache_params) + return hidden_states, residual + + +class Plamo2Model(Plamo2PreTrainedModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config.model_config.hf_config) + + config = vllm_config.model_config.hf_config + + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + prefix=f"{prefix}.embed_tokens", + ) + self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_init() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # TODO(Shinichi): Implement pipeline parallelism. 
+ hidden_states = self.embed_tokens(input_ids) + residual = None + + hidden_states, residual = self.layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=mamba_cache_params) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid, + SupportsV0Only): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + scheduler_config = vllm_config.scheduler_config + assert not vllm_config.cache_config.enable_prefix_caching, \ + "PLaMo2 currently does not support prefix caching" + + super().__init__(config) + self.config = config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.scheduler_config = scheduler_config + + # ModelConfig.get_head_size assumes head_dim is set or calculated as + # hidden_size // num_attention_heads. However, this is not always + # the case for PLaMo2, as indicated by the FIXME comment. + self.config.head_dim = self.config.hidden_size_per_head + + self.model = Plamo2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.vocab_size = self.config.vocab_size + self.unpadded_vocab_size = self.config.vocab_size + num_embeddings = ((self.vocab_size + 15) // 16) * 16 + self.lm_head = ParallelLMHead( + num_embeddings, + self.config.hidden_size, + org_num_embeddings=self.config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=f"{prefix}.lm_head", + ) + if self.config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + + # Used to track and store the Mamba cache between steps. 
+ self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.model(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = (self.config.mamba_num_heads * + self.config.hidden_size_per_head) + conv_state_shape = ( + hidden_size // world_size, + self.config.mamba_d_conv - 1, + ) + temporal_state_shape = ( + hidden_size // world_size, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = 
dict(self.named_parameters()) + for name, loaded_weight in weights: + + # When tie_word_embeddings=True and the safetensor also contains + # lm_head.weight, the params_dict lookup below raises KeyError. + if name == "lm_head.weight" and self.config.tie_word_embeddings: + assert "lm_head.weight" not in params_dict + continue + + # Update the weight names to be compatible with the vllm version + # of the model. + # Do not change the order of the replacements. + replacements = { + # Rename incompatible weight names. + ".A_log": ".A", + ".B_norm_weight": ".B_norm.weight", + ".C_norm_weight": ".C_norm.weight", + ".dt_norm_weight": ".dt_norm.weight", + } + # Apply replacements based on the defined mappings + for old, new in replacements.items(): + if old in name: + name = name.replace(old, new) + + # Broadcast the loaded weight to match the model's parameter shape. + if ".A" in name: + loaded_weight = loaded_weight[:, None, None].expand( + -1, self.config.hidden_size_per_head, + self.config.mamba_d_state) + loaded_weight = loaded_weight.reshape( + -1, self.config.mamba_d_state) + elif ".D" in name: + loaded_weight = loaded_weight[:, None].expand( + -1, self.config.hidden_size_per_head) + loaded_weight = loaded_weight.reshape(-1) + # vLLM's RMSNorm does not support an offset parameter yet. 
+ if ".pre_mixer_norm" in name: + loaded_weight += 1.0 + elif ".post_mixer_norm" in name: + loaded_weight += 1.0 / 5 + elif ".pre_mlp_norm" in name: + loaded_weight += 1.0 + elif ".post_mlp_norm" in name: + loaded_weight += 1.0 / (5**1.5) + elif "model.norm.weight" in name: + loaded_weight += 1.0 + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index f0d21bca7a4a2f591b4b56b89519ca9eee7c3dda..26eecc337233d2f2b1ace7b7c2187f02dd9b1ea2 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -285,7 +284,6 @@ class QWenBaseModel(nn.Module): if self.config.tie_word_embeddings: self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -310,14 +308,6 @@ class QWenBaseModel(nn.Module): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git 
a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 3ea56c1d0f9aec2914991d63c13601d01a801529..17a871310693e1039ea6ff9868c1ec8d264958d7 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -47,7 +47,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -615,7 +614,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.quant_method = None if quant_config is not None: @@ -653,14 +651,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py new file mode 100644 index 0000000000000000000000000000000000000000..039f528db13bb1eac4ea3a84afdf7b4935216383 --- /dev/null +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -0,0 +1,901 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2.5-Omni model (thinker part).""" + +from copy import copy +from functools import partial +from typing import (Any, Dict, Iterable, List, Mapping, Optional, Sequence, + Set, Tuple, Union) + +import torch +import torch.nn as nn +from transformers.feature_extraction_utils import BatchFeature +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniConfig, Qwen2_5OmniThinkerConfig) +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoder) +from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import ( + Qwen2_5OmniProcessor) +from transformers.models.whisper import WhisperFeatureExtractor + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, + Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) +from 
vllm.model_executor.models.qwen2_audio import ( + Qwen2AudioInputs, Qwen2AudioProcessingInfo, + _get_feat_extract_output_lengths) +from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, + ModalityDataItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + PlaceholderFeaturesInfo, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +try: + import flash_attn +except (ImportError, ModuleNotFoundError): + flash_attn = None + +logger = init_logger(__name__) + + +def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): + audio_feature_lengths = hf_inputs.get("audio_feature_lengths", + torch.empty((0, ))) + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + input_audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_feature_lengths, dim=1), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + audio_feature_lengths=MultiModalFieldConfig.batched("audio"), + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", 
image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + ) + + +class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): + + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={ + "input_audio_features", "audio_feature_lengths" + }, + fields_factory=_qwen2_5_omni_thinker_field_config, + ) + + return super()._parse_audio_data(data) + + +class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo, + Qwen2_5_VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config + + def get_hf_processor( + self, + *, + sampling_rate: Optional[int] = None, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + fps: Optional[Union[float, List[float]]] = None, + **kwargs: object, + ) -> Qwen2_5OmniProcessor: + if fps is not None: + kwargs["fps"] = fps + processor = self.ctx.get_hf_processor( + Qwen2_5OmniProcessor, + image_processor=self.get_image_processor(min_pixels=min_pixels, + max_pixels=max_pixels, + size=size), + **kwargs, + ) + if not hasattr(processor, "audio_token"): + processor.audio_token = "<|AUDIO|>" + if not hasattr(processor, "image_token"): + processor.image_token = "<|IMAGE|>" + if not hasattr(processor, "video_token"): + processor.video_token = "<|VIDEO|>" + return processor + + def get_feature_extractor( + self, + *, + sampling_rate: Optional[int] = None, + 
**kwargs: object, + ): + hf_processor = self.get_hf_processor(sampling_rate=sampling_rate) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None, "image": None, "video": None} + + +class Qwen2_5OmniThinkerDummyInputsBuilder( + BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + hf_processor = self.info.get_hf_processor() + + audio_token: str = hf_processor.audio_token + image_token: str = hf_processor.image_token + video_token: str = hf_processor.video_token + + return (audio_token * num_audios + image_token * num_images + + video_token * num_videos) + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + feature_extractor = self.info.get_feature_extractor() + + target_audio_length = min( + feature_extractor.chunk_length, + 30, + ) * feature_extractor.sampling_rate + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + mm_data = { + "audio": + self._get_dummy_audios(length=target_audio_length, + num_audios=num_audios), + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos(width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos), + } + + return mm_data + + +class Qwen2_5OmniThinkerMultiModalProcessor( + 
BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return Qwen2_5OmniThinkerMultiModalDataParser( + target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + # NOTE: WhisperFeatureExtractor cannot handle an empty list of audios + if audios: + # NOTE: Qwen2.5-Omni processor accepts "audio" + mm_data["audio"] = audios + mm_kwargs = dict(**mm_kwargs, ) + + hf_inputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + input_features = hf_inputs.pop('input_features', None) + feature_attention_mask = hf_inputs.get('feature_attention_mask', None) + if ('input_audio_features' not in hf_inputs + and input_features is not None): + if feature_attention_mask is not None: + input_features = input_features.permute( + 0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + hf_inputs['input_audio_features'] = input_features + if ('audio_feature_lengths' not in hf_inputs + and feature_attention_mask is not None): + hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1) + return hf_inputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _qwen2_5_omni_thinker_field_config(hf_inputs) + + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + prompt_ids: list[int], + mm_kwargs: MultiModalKwargs, + is_update_applied: bool, + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + """ + Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. 
+ """ + unbound_prompt_updates = self._get_prompt_updates( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates) + + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + + if is_update_applied: + mm_placeholders = self._find_mm_placeholders( + mm_prompt_updates, + prompt_ids, + mm_item_counts, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + use_audio_in_video=use_audio_in_video) + + tokenizer = self.info.get_tokenizer() + prompt = decode_tokens(tokenizer, prompt_ids) + else: + ( + prompt_ids, + prompt, + mm_placeholders, + ) = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + mm_item_counts, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + use_audio_in_video=use_audio_in_video) + + tokenizer = self.info.get_tokenizer() + prompt = decode_tokens(tokenizer, prompt_ids) + + if use_audio_in_video: + mm_kwargs["use_audio_in_video"] = True + + return prompt_ids, prompt, mm_placeholders + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + vocab = tokenizer.get_vocab() + + audio_token = processor.audio_token + image_token = processor.image_token + video_token = processor.video_token + audio_token_id = vocab[audio_token] + image_token_id = vocab[image_token] + video_token_id = vocab[video_token] + + audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths") + feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + if audio_feature_lengths is None and 
feature_attention_mask is None: + audio_output_lengths = [] + elif audio_feature_lengths is not None: + _, audio_output_lens = _get_feat_extract_output_lengths( + audio_feature_lengths) + audio_output_lengths = audio_output_lens.tolist() + elif feature_attention_mask is not None: + assert isinstance(feature_attention_mask, torch.Tensor) + _, audio_output_lens = _get_feat_extract_output_lengths( + feature_attention_mask.sum(-1)) + audio_output_lengths = audio_output_lens.tolist() + + # number of audios read from video. + audio_in_video_item_idx = 0 + + def get_replacement_qwen2_audio(item_idx: int): + item_idx += audio_in_video_item_idx + + num_features = audio_output_lengths[item_idx] + if num_features == 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + raise ValueError( + f"The audio {audio} (len={len(audio)}) is too short " + "to be represented inside the model") + + return [audio_token_id] * num_features + + def get_replacement_qwen2_vision(item_idx: int, modality: str): + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + merge_length = image_processor.merge_size**2 + + token_id = image_token_id if modality == "image" else video_token_id + return [token_id] * (int(grid_thw.prod()) // merge_length) + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + thinker_config = self.info.get_hf_config() + + def get_replacement_qwen2_use_audio_in_video(item_idx: int): + nonlocal audio_in_video_item_idx + + audio_num_features = audio_output_lengths[audio_in_video_item_idx + + item_idx] + video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + + audio_in_video_item_idx += 1 + + second_per_grid_ts = hf_processor_mm_kwargs.get( + "second_per_grid_ts", None) + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[item_idx] + else: + video_second_per_grid_t = 1.0 + + return MRotaryEmbedding.omni_get_updates_use_audio_in_video( 
+ thinker_config=thinker_config, + audio_len=audio_num_features, + video_grid_thw=video_grid_thw, + video_second_per_grid_t=video_second_per_grid_t, + ) + + video_replacement_fn = ( + get_replacement_qwen2_use_audio_in_video if use_audio_in_video else + partial(get_replacement_qwen2_vision, modality="video")) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_qwen2_audio, + ), + PromptReplacement( + modality="image", + target=image_token, + replacement=partial(get_replacement_qwen2_vision, + modality="image"), + ), + PromptReplacement( + modality="video", + target=video_token, + replacement=video_replacement_fn, + ), + ] + + def _apply_hf_processor_main( + self, + prompt: Union[str, list[int]], + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + *, + enable_hf_prompt_update: bool, + ) -> tuple[list[int], MultiModalKwargs, bool]: + """ + Qwen2.5-Omni reimplements this function to handle text only. + """ + if isinstance(prompt, str): + if enable_hf_prompt_update: + return self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + tokenizer = self.info.get_tokenizer() + prompt_ids = encode_tokens(tokenizer, prompt) + else: + prompt_ids = self._apply_hf_processor_tokens_only(prompt) + + mm_kwargs = self._apply_hf_processor_mm_only( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return prompt_ids, mm_kwargs, False + + def _apply_hf_processor_mm_only( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalKwargs: + """ + Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. 
+ """ + mm_counts = mm_items.get_all_counts() + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + if use_audio_in_video and "video" in mm_counts: + assert "audio" in mm_counts + mm_counts["audio"] -= mm_counts["video"] + + _, mm_kwargs, _ = self._apply_hf_processor_text_mm( + prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return mm_kwargs + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_item_counts: Mapping[str, int], + use_audio_in_video: bool = False, + ) -> None: + if use_audio_in_video: + mm_item_counts = copy(mm_item_counts) + if "video" in mm_item_counts: + assert "audio" in mm_item_counts + mm_item_counts["audio"] -= mm_item_counts["video"] + super()._validate_mm_placeholders(mm_placeholders, mm_item_counts) + + +class Qwen2_5OmniConditionalGenerationMixin: + + def _validate_and_reshape_mm_tensor(self, + mm_input: object, + name: str, + dim: int = 0) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. 
" + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + return torch.concat(list(mm_input), dim=dim) + else: + return torch.concat(mm_input, dim=dim) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[Qwen2AudioInputs]: + input_audio_features = kwargs.pop('input_audio_features', None) + audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) + feature_attention_mask = kwargs.pop('feature_attention_mask', None) + if input_audio_features is None: + return None + input_audio_features = self._validate_and_reshape_mm_tensor( + input_audio_features, 'input_audio_features', dim=1) + if feature_attention_mask is not None: + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, 'feature_attention_mask') + if not isinstance(input_audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. " + f"Got type: {type(input_audio_features)}") + return Qwen2AudioInputs(input_features=input_audio_features, + audio_feature_lengths=audio_feature_lengths, + feature_attention_mask=feature_attention_mask) + + def _parse_and_validate_image_input( + self, + **kwargs: Dict[str, Any], + ) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. 
" + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, + **kwargs: Dict[str, Any], + ) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. 
" + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + def _process_audio_input( + self, + audio_input: Qwen2AudioInputs, + audio_hashes: List[str] = None, + cached_audio_features: torch.Tensor = None, + ) -> torch.Tensor: + + input_features = audio_input["input_features"] + audio_feature_lengths = audio_input["audio_feature_lengths"] + if input_features.ndim == 3: + assert input_features.shape[0] == 1 + input_features = input_features.squeeze(0) + if audio_feature_lengths.ndim == 2: + assert audio_feature_lengths.shape[ + 0] == 1 or audio_feature_lengths.shape[1] == 1 + if audio_feature_lengths.shape[0] == 1: + audio_feature_lengths = audio_feature_lengths.squeeze(0) + else: + audio_feature_lengths = audio_feature_lengths.squeeze(1) + + audio_feat_lengths, audio_output_lengths = ( + self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths)) + + audio_outputs = self.audio_tower( + input_features.to(self.audio_tower.dtype), + feature_lens=audio_feature_lengths, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + return audio_features.split(audio_output_lengths.tolist()) + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"].type(self.visual.dtype) + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + # Split concatenated embeddings for each image item. 
+ merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs, + video_hashes: List[str] = None, + cached_video_embeds: torch.Tensor = None) -> torch.Tensor: + if video_input["type"] == "video_embeds": + return video_input["video_embeds"].type(self.visual.dtype) + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5OmniThinkerMultiModalProcessor, + info=Qwen2_5OmniThinkerProcessingInfo, + dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, +) +class Qwen2_5OmniThinkerForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, + Qwen2_5OmniConditionalGenerationMixin): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "thinker.lm_head.": "language_model.lm_head.", + "thinker.model.": "language_model.model.", + "thinker.": "", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + thinker_config: Qwen2_5OmniThinkerConfig = ( + vllm_config.model_config.hf_config.thinker_config) + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = thinker_config + self.multimodal_config = multimodal_config + + # force "use_flash_attention_2=True" to audio tower to align + # the results. 
+ if flash_attn is not None: + audio_config = thinker_config.audio_config + audio_config._attn_implementation_autoset = True + audio_config._attn_implementation = "flash_attention_2" + else: + logger.warning( + "flash_attn is not available, the model may not yield the " + "exactly same result as the transformers implementation " + "in the audio tower part.") + + self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) + self.visual = Qwen2_5_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + self.quant_config = quant_config + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + hf_config=thinker_config.text_config, + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) + if input_key in ("input_audio_features" + ) and "audio" not in mm_input_by_modality: + mm_input_by_modality[ + "audio"] = self._parse_and_validate_audio_input(**kwargs) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += audio_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + + # TODO (ywang96): support overlapping modality embeddings so that + # `use_audio_in_video` will work on V1. + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, [ + self.config.image_token_index, + self.config.video_token_index, + self.config.audio_token_index + ]) + return inputs_embeds + + def get_multimodal_embeddings_v0( + self, **kwargs: object) -> Optional[NestedTensors]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if audio_input is None and image_input is None and video_input is None: + return None + + multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] + + if audio_input is not None: + audio_embeds = self._process_audio_input(audio_input) + multimodal_embeddings.append((audio_embeds, "audio")) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + multimodal_embeddings.append((image_embeds, "image")) + if video_input is not None: + video_embeds = self._process_video_input(video_input) + multimodal_embeddings.append((video_embeds, "video")) + 
return multimodal_embeddings + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is None: + return inputs_embeds + + for embeddings, modality in multimodal_embeddings: + if modality == "audio": + placeholder_token_id = self.config.audio_token_index + if modality == "image": + placeholder_token_id = self.config.image_token_index + if modality == "video": + placeholder_token_id = self.config.video_token_index + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, embeddings, placeholder_token_id) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings_v0(**kwargs) + inputs_embeds = self.get_input_embeddings_v0( + input_ids, multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["talker.", "token2wav."], + ) + loaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + + return loaded_weights diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 7f62dccfd92e016aed39a9ea6fc04a659d677fcd..52961f23a7bb3d1684c4e2dc0b025e596fa1c257 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -24,7 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -from functools import cached_property, partial +from functools import partial from typing import (Callable, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -38,19 +38,19 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) from vllm.config import VllmConfig -from vllm.distributed import parallel_state, tensor_model_parallel_all_gather +from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY @@ -200,6 +200,25 @@ class Qwen2_5_VisionMLP(nn.Module): return x_down +def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): + """All-gather the input tensor interleavely across model parallel group.""" + import torch.distributed as dist + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] + dist.all_gather(gathered_tensors, + local_tensor, + group=parallel_state.get_tp_group().device_group) + + gathered_tensors_split = [ + torch.split(tensor, hidden_size // tp_size, -1) + for tensor in gathered_tensors + ] + ordered_tensors = [ + 
tensor for pair in zip(*gathered_tensors_split) for tensor in pair + ] + result_tensor = torch.cat(ordered_tensors, dim=-1) + return result_tensor + + class Qwen2_5_VisionAttention(nn.Module): def __init__( @@ -219,10 +238,14 @@ class Qwen2_5_VisionAttention(nn.Module): self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - self.qkv = ColumnParallelLinear(input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - prefix=f"{prefix}.qkv") + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") self.proj = RowParallelLinear(input_size=projection_size, output_size=embed_dim, quant_config=quant_config, @@ -241,7 +264,8 @@ class Qwen2_5_VisionAttention(nn.Module): # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape if self.tp_size > 1: - qkv = tensor_model_parallel_all_gather(qkv) + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, + self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) @@ -710,9 +734,9 @@ class Qwen2_5_VisionTransformer(nn.Module): torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: Set[str] = set() @@ -867,13 +891,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def 
_maybe_ignore_quant_config(self, quant_config: QuantizationConfig): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ # seems to avoid vision encoder sections for some models. @@ -1011,20 +1028,20 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, return video_embeds.split(sizes.tolist()) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = {} + mm_input_by_modality = {} # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_videos", - "video_embeds") and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) - return modalities + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) + return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -1032,8 +1049,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if not modalities: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: return None # The result multimodal_embeddings is tuple of tensors, with each @@ -1042,14 +1060,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, # NOTE: It is important to iterate over the keys 
in this dictionary # to preserve the order of the modalities. - for modality in modalities: - if modality == "images": - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) multimodal_embeddings += vision_embeddings - if modality == "videos": - video_input = modalities["videos"] - video_embeddings = self._process_video_input(video_input) + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) multimodal_embeddings += video_embeddings return multimodal_embeddings @@ -1161,13 +1178,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -1180,5 +1190,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, """ return MultiModelKeys.from_string_field( language_model="language_model", - connector="visual.", - tower_model="visual.merger.") \ No newline at end of file + connector="visual.merger.", + tower_model="visual.", + ) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 362dbbb2f7cf7e99c72e63fec8ebe8f45c51c28c..afefe7b9e71fed365ee596bfbccbe40150454964 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -22,7 +22,6 @@ # limitations under the License. 
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Any, Optional, Set, Tuple, TypedDict, Union import torch @@ -34,7 +33,6 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig, from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -267,13 +265,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): @@ -405,13 +396,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index b46fcac9f9c9b5ca0faa5a2f1bfc8b0586c92784..5ef7a7b5a6907f6887642d79b2be81e25596a7bb 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import 
(MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -553,7 +552,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -580,14 +578,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index b7652a9f4cdd7f5c8b5f4f715448758387af3af6..bc0e1c2741a5fe1aa822728f131756e7f43c5357 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -24,7 +24,7 @@ # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property, partial +from functools import partial from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, Union) @@ -51,7 +51,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY @@ -1173,12 +1172,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ @@ -1462,13 +1455,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -1481,5 +1467,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, """ return MultiModelKeys.from_string_field( language_model="language_model", - connector="visual.", - tower_model="visual.merger.") \ No newline at end of file + connector="visual.merger.", + 
tower_model="visual.", + ) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 9c14038e611333bbf1f443000833f8ccfd7342d7..73d2838f461ead1bdb1735e50046ad8b0f2dcf30 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -283,7 +282,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -311,14 +309,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index f0ef79dfdfe28fa3e6f235f99f0c00e0802c1d44..70f9956e3efc77f7605b5974200703314ff8ddc3 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -44,7 +44,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -494,7 +493,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -521,14 +519,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 02ee3a8574435962e526fd20c2bf1e516c9e0276..79be5b0e65292c1c09457c931b01450b9c82fc05 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -100,6 +100,7 @@ _TEXT_GENERATION_MODELS = { "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), @@ -122,13 +123,11 @@ _TEXT_GENERATION_MODELS = { _EMBEDDING_MODELS = { # [Text-only] "BertModel": ("bert", 
"BertEmbeddingModel"), - "RobertaModel": ("roberta", "RobertaEmbeddingModel"), - "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), - "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GritLM": ("gritlm", "GritLM"), + "GteModel": ("bert", "GteEmbeddingModel"), "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "LlamaModel": ("llama", "LlamaForCausalLM"), @@ -138,12 +137,16 @@ _EMBEDDING_MODELS = { if arch == "LlamaForCausalLM" }, "MistralModel": ("llama", "LlamaForCausalLM"), + "NomicBertModel": ("bert", "NomicBertEmbeddingModel"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"), + "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), + "RobertaModel": ("roberta", "RobertaEmbeddingModel"), "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), + "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), @@ -162,6 +165,8 @@ _CROSS_ENCODER_MODELS = { "RobertaForSequenceClassification"), "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), + "ModernBertForSequenceClassification": ("modernbert", + "ModernBertForSequenceClassification"), } _MULTIMODAL_MODELS = { @@ -174,10 +179,12 @@ _MULTIMODAL_MODELS = { "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), + "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501 "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 + "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 @@ -195,6 +202,7 @@ _MULTIMODAL_MODELS = { "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 + "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), # [Encoder-decoder] @@ -208,6 +216,7 @@ _MULTIMODAL_MODELS = { _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), + "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 
19a23162aa840b12cab20cca0a29d8b427d336e8..e78c37b65f874879ef641c455dc42fb55c9609c3 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -8,7 +8,6 @@ # -------------------------------------------------------- from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union import torch @@ -21,7 +20,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -699,13 +697,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): (llm_quant_config is not None): quant_config.modules_to_not_convert.append("vision_model") - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _init_vision_model( self, config: PretrainedConfig, @@ -908,7 +899,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ) -> Union[SamplerOutput, IntermediateTensors]: + ) -> IntermediateTensors: if intermediate_tensors is not None: input_ids = None @@ -946,13 +937,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return 
self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: skip_prefixes = [ diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 1cae0a7fe0dcd60ec3ecfcc1dfc30be5ccff2996..f86aff7ba7ef0fbeda6752dee0bd1a39507d3340 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -418,8 +417,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -440,14 +437,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 53f520304abc4fa2338a684ec81a8265858e75b5..1cbda7267e4c621e7885a9d148682be46c8dc069 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import 
(MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -105,9 +104,8 @@ class StablelmAttention(nn.Module): 1, self.total_num_key_value_heads // tp_size) self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings - rope_pct = getattr(config, "rope_pct", - getattr(config, "partial_rotary_factor", 1)) - self.rotary_ndims = int(self.head_dim * rope_pct) + self.partial_rotary_factor = getattr( + config, "rope_pct", getattr(config, "partial_rotary_factor", 1)) self.scaling = self.head_dim**-0.5 self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim @@ -131,9 +129,10 @@ class StablelmAttention(nn.Module): prefix=f"{prefix}.o_proj") self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.rotary_ndims, + rotary_dim=self.head_dim, max_position=self.config.max_position_embeddings, base=self.config.rope_theta, + partial_rotary_factor=self.partial_rotary_factor, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -310,7 +309,6 @@ class StablelmForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -337,14 +335,6 @@ class StablelmForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - 
) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 8b9fb7cb7bc6e568b5bc8cf84c81b6977fc9619b..6eebe4c4d61451d9df2cd3009c1dee7e8971acdf 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -317,7 +316,6 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -344,14 +342,6 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index a1f233e04892e7c39c5e4d2b31188195bb9abb11..a37e88a387fdae1904b2e9df95c423ea2ae349b6 100644 --- 
a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -84,7 +83,7 @@ def replace_linear_class( ) -> Union[ColumnParallelLinear, RowParallelLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. - + Args: linear (nn.Linear): `nn.Linear` to be replaced. style (str): Tensor parallel style of the new linear, e.g. "colwise". @@ -396,8 +395,6 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -435,12 +432,6 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, sampling_metadata) return logits - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index cb5ff4ed6365bd4223e0d460b1380a3cdea379a9..bfa48099b74164c82c9c17f3299a2a28d595e8aa 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,7 +3,6 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py 
"""PyTorch Ultravox model.""" from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -18,7 +17,6 @@ from vllm.config import VllmConfig from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -438,13 +436,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models @@ -628,13 +619,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 63e71f26880576ef8f237f51865b5c0b87c296f0..908cd7885aa8324ddfe73c57a41934f2edd12d72 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from 
vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -669,7 +668,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() def forward( self, @@ -724,14 +722,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index ea21fffaede560a62060d9c785d3fd4e203b98d8..d34033e3ac90009a60ff3f174f1c2836c29c520f 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import ( MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import 
default_weight_loader @@ -870,7 +869,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): # Initialize logits processing and sampling self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings. @@ -1004,23 +1002,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - """Sample next tokens from computed logits. - - Args: - logits: Computed logits for next token prediction - sampling_metadata: Metadata for sampling process - - Returns: - Sampled tokens and related sampling information - """ - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 2b1294bf7baa3f72010fdf552173ef198a7736a0..34a0b527b585ed55c34e0954f3231e780fee27c8 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -282,10 +282,12 @@ class PackedColumnParameter(_ColumnvLLMParameter): packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, **kwargs): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size + self._bitblas_tile_size = bitblas_tile_size super().__init__(**kwargs) @property @@ -300,12 +302,17 @@ class PackedColumnParameter(_ColumnvLLMParameter): def marlin_tile_size(self): return self._marlin_tile_size + @property + def bitblas_tile_size(self): + return self._bitblas_tile_size + def adjust_shard_indexes_for_packing(self, 
shard_size, shard_offset): return _adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset, packed_factor=self.packed_factor, - marlin_tile_size=self.marlin_tile_size) + marlin_tile_size=self.marlin_tile_size, + bitblas_tile_size=self.bitblas_tile_size) class PackedvLLMParameter(ModelWeightParameter): @@ -323,10 +330,12 @@ class PackedvLLMParameter(ModelWeightParameter): packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, **kwargs): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size + self._bitblas_tile_size = bitblas_tile_size super().__init__(**kwargs) @property @@ -341,12 +350,17 @@ class PackedvLLMParameter(ModelWeightParameter): def marlin_tile_size(self): return self._marlin_tile_size + @property + def bitblas_tile_size(self): + return self._bitblas_tile_size + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): return _adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset, packed_factor=self.packed_factor, - marlin_tile_size=self.marlin_tile_size) + marlin_tile_size=self.marlin_tile_size, + bitblas_tile_size=self.bitblas_tile_size) class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): @@ -421,8 +435,13 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, return shard_size * marlin_tile_size, shard_offset * marlin_tile_size +def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset, + bitblas_tile_size): + return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size + + def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, - marlin_tile_size): + marlin_tile_size, bitblas_tile_size): shard_size = shard_size // packed_factor shard_offset = shard_offset // packed_factor if marlin_tile_size is not None: @@ -430,4 +449,10 @@ def _adjust_shard_indexes_for_packing(shard_size, 
shard_offset, packed_factor, shard_size=shard_size, shard_offset=shard_offset, marlin_tile_size=marlin_tile_size) - return shard_size, shard_offset + elif bitblas_tile_size is not None: + return _adjust_shard_indexes_for_bitblas( + shard_size=shard_size, + shard_offset=shard_offset, + bitblas_tile_size=bitblas_tile_size) + + return shard_size, shard_offset \ No newline at end of file diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index ca48dae3756bcd8faa01a771240e6475a9401358..3d555df036ccc2e8e3a1fe5944c4392f19c9f621 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - -from .base import MultiModalPlaceholderMap, MultiModalPlugin +from .base import MultiModalPlaceholderMap from .hasher import MultiModalHashDict, MultiModalHasher from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, MultiModalDataDict, MultiModalKwargs, @@ -26,7 +25,6 @@ __all__ = [ "MultiModalKwargs", "MultiModalPlaceholderDict", "MultiModalPlaceholderMap", - "MultiModalPlugin", "NestedTensors", "MULTIMODAL_REGISTRY", "MultiModalRegistry", diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index f379ec1682a3c99eeecbda7a08b6f9097882c920..1fd2ab7f87d1f41aa80e2ef68353d578a06fbc73 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,17 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 - import base64 from io import BytesIO from pathlib import Path +from typing import Literal, Optional import numpy as np import numpy.typing as npt -from vllm.inputs.registry import InputContext from vllm.utils import PlaceholderModule -from .base import MediaIO, MultiModalPlugin -from .inputs import AudioItem, ModalityData, MultiModalKwargs +from .base import MediaIO try: import librosa @@ -24,26 +22,7 @@ except ImportError: soundfile = PlaceholderModule("soundfile") # type: ignore[assignment] -class AudioPlugin(MultiModalPlugin): - """Plugin for audio data.""" - 
- def get_data_key(self) -> str: - return "audio" - - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[AudioItem], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - raise NotImplementedError("There is no default audio input mapper") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - raise NotImplementedError( - "There is no default maximum multimodal tokens") - - -def resample_audio( +def resample_audio_librosa( audio: npt.NDArray[np.floating], *, orig_sr: float, @@ -52,6 +31,55 @@ def resample_audio( return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) +def resample_audio_scipy( + audio: npt.NDArray[np.floating], + *, + orig_sr: float, + target_sr: float, +): + # lazy import scipy.signal, otherwise it will crash doc build. + import scipy.signal + + if orig_sr > target_sr: + return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr) + elif orig_sr < target_sr: + return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1) + return audio + + +class AudioResampler: + """Resample audio data to a target sample rate.""" + + def __init__( + self, + target_sr: Optional[float] = None, + method: Literal["librosa", "scipy"] = "librosa", + ): + self.target_sr = target_sr + self.method = method + + def resample( + self, + audio: npt.NDArray[np.floating], + *, + orig_sr: float, + ) -> npt.NDArray[np.floating]: + if self.target_sr is None: + raise RuntimeError("Audio resampling is not supported when " + "`target_sr` is not provided") + if self.method == "librosa": + return resample_audio_librosa(audio, + orig_sr=orig_sr, + target_sr=self.target_sr) + elif self.method == "scipy": + return resample_audio_scipy(audio, + orig_sr=orig_sr, + target_sr=self.target_sr) + else: + raise ValueError(f"Invalid resampling method: {self.method}. 
" + "Supported methods are 'librosa' and 'scipy'.") + + class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index ad95b982499c94cdfd940079eaeaf75733f36165..2f93922fcedb97efb66b3acbf7c17df9354e5c91 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,247 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from collections import defaultdict from collections.abc import Sequence from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple, - Optional, TypeVar, Union) - -from torch import nn - -from vllm.inputs import InputContext -from vllm.logger import init_logger -from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, - resolve_mm_processor_kwargs) +from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar if TYPE_CHECKING: - from vllm.config import ModelConfig from vllm.sequence import SequenceGroupMetadata -from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs, - PlaceholderRange) - -logger = init_logger(__name__) - -MultiModalInputMapper = Callable[[InputContext, ModalityData[object]], - MultiModalKwargs] -""" -Return a dictionary to be passed as keyword arguments to -:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers -and processors in HuggingFace Transformers. - -If the data is not supported, throw :exc:`TypeError`. -""" - -MultiModalTokensCalc = Union[int, Callable[[InputContext], int]] -""" -Calculate the maximum number of multimodal tokens input to the language -model. This does not include tokens that correspond to the input text. -""" +from .inputs import MultiModalKwargs, PlaceholderRange _T = TypeVar("_T") -N = TypeVar("N", bound=type[nn.Module]) - - -class MultiModalPlugin(ABC): - """ - Base class that defines data processing logic for a specific modality. 
- - In particular, we adopt a registry pattern to dispatch data processing - according to the model being used (considering that different models may - process the same data differently). This registry is in turn used by - :class:`~MultiModalRegistry` which acts at a higher level - (i.e., the modality of the data). - """ - - def __init__(self) -> None: - self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]() - self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]() - - @abstractmethod - def get_data_key(self) -> str: - """ - Get the data key corresponding to the modality. - """ - raise NotImplementedError - - @abstractmethod - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[Any], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - """ - Return a dictionary to be passed as keyword arguments to - :meth:`~torch.nn.Module.forward`. This is similar in concept to - tokenizers and processors in HuggingFace Transformers. - - If the data is not supported, throw :exc:`TypeError`. - """ - raise NotImplementedError - - def register_input_mapper( - self, - mapper: Optional[MultiModalInputMapper] = None, - ): - """ - Register an input mapper to a model class. - - When the model receives input data that matches the modality served by - this plugin (see :meth:`get_data_key`), the provided function is - invoked to transform the data into a dictionary of model inputs. - - If `None` is provided, then the default input mapper is used instead. - """ - - def wrapper(model_cls: N) -> N: - if self._input_mappers.contains(model_cls, strict=True): - logger.warning( - "Model class %s already has an input mapper " - "registered to %s. 
It is overwritten by the new one.", - model_cls, - self, - ) - - self._input_mappers[model_cls] = (mapper - or self._default_input_mapper) - - return model_cls - - return wrapper - - def map_input( - self, - model_config: "ModelConfig", - data: ModalityData[Any], - mm_processor_kwargs: Optional[dict[str, Any]], - ) -> MultiModalKwargs: - """ - Transform the data into a dictionary of model inputs using the - input mapper registered for that model. - - The model is identified by ``model_config``. - - Raises: - TypeError: If the data type is not supported. - """ - - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - - mapper = self._input_mappers.get(model_cls) - - if mapper is None: - raise KeyError(f"No input mapper in {self} is registered for " - f"model class {model_cls.__name__}.") - - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - - # In the case of the default mapper, we have to get resource - # processor through its HuggingFace autoclass; since this goes - # through **kwargs, we can't inspect it the same way, so we allow - # drop mm_processor_kwargs based on signature inspection - # if we're using the default mapper. - # - # This should be safe in general due to the sanitation, since the - # transformers resource should filter unused kwargs anyway. - uses_default_mapper = mapper == self._default_input_mapper - mm_processor_kwargs = resolve_mm_processor_kwargs( - model_config.mm_processor_kwargs, - mm_processor_kwargs, - callable=mapper, - allow_var_kwargs=uses_default_mapper, - ) - return mapper(InputContext(model_config), data, **mm_processor_kwargs) - - @abstractmethod - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - """ - Calculate the maximum number of tokens, corresponding to a single - instance of multimodal data, that are passed to the language model. 
- """ - raise NotImplementedError - - def _validate_max_multimodal_tokens(self, max_mm_tokens: int): - if max_mm_tokens < 1: - raise ValueError("You should set the number of tokens to a " - f"positive integer. Found: {max_mm_tokens}") - - def register_max_multimodal_tokens( - self, - max_mm_tokens: Optional[MultiModalTokensCalc] = None, - ): - """ - Register the maximum number of tokens, corresponding to a single - instance of multimodal data, that are passed to the language model - for a model class. - - If `None` is provided, then the default calculation is used instead. - """ - - def wrapper(model_cls: N) -> N: - if self._max_mm_tokens.contains(model_cls, strict=True): - logger.warning( - "Model class %s already calculates maximum number of " - "tokens in %s. It is overwritten by the new one.", - model_cls, - self, - ) - - if isinstance(max_mm_tokens, int): - self._validate_max_multimodal_tokens(max_mm_tokens) - - self._max_mm_tokens[model_cls] = ( - max_mm_tokens or self._default_max_multimodal_tokens) - - return model_cls - - return wrapper - - def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: - """ - Get the maximum number of multi-modal tokens - for profiling the memory usage of a model. - - If this registry is not applicable to the model, `0` is returned. - - The model is identified by ``model_config``. 
- """ - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - from vllm.model_executor.models import supports_multimodal - - model_cls, _ = get_model_architecture(model_config) - - if not supports_multimodal(model_cls): - return 0 - - max_mm_tokens = self._max_mm_tokens.get(model_cls) - if max_mm_tokens is None: - return 0 - - if callable(max_mm_tokens): - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - max_mm_tokens, - overrides=model_config.mm_processor_kwargs, - requires_kw_only=False, - allow_var_kwargs=True, - ) - max_mm_tokens = max_mm_tokens(InputContext(model_config), - **mm_processor_kwargs) - - self._validate_max_multimodal_tokens(max_mm_tokens) - - return max_mm_tokens class MultiModalPlaceholderMap: """ Relates multi-modal embeddings to their corresponding placeholders. + + Note: This is only used in V0. """ class IndexMap(NamedTuple): @@ -279,8 +55,7 @@ class MultiModalPlaceholderMap: @classmethod def from_seq_group( cls, seq_group: "SequenceGroupMetadata", positions: range - ) -> tuple[Optional[MultiModalDataDict], dict[str, - "MultiModalPlaceholderMap"]]: + ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]: """ Returns the multi-modal items that intersect with the portion of a prompt (``seq_group``) represented by ``positions``, as well as a @@ -323,48 +98,24 @@ class MultiModalPlaceholderMap: seq_mm_placeholders = seq_group.multi_modal_placeholders if not seq_mm_data or not seq_mm_placeholders: - return seq_mm_data, {} - - # For merged processor, we directly use mm_kwargs as mm_data - if isinstance(seq_mm_data, MultiModalKwargs): - placeholder_maps = dict[str, MultiModalPlaceholderMap]() - - for modality, placeholders in seq_mm_placeholders.items(): - placeholder_map = MultiModalPlaceholderMap() + return MultiModalKwargs({}), {} - if positions: - placeholder_map.append_items_from_seq_group( - positions, - # Dummy, since we don't care about intersecting items - [None] * 
len(placeholders), - placeholders, - ) - - placeholder_maps[modality] = placeholder_map - - return seq_mm_data, placeholder_maps - - mm_data = {**seq_mm_data} - placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( - MultiModalPlaceholderMap) + placeholder_maps = dict[str, MultiModalPlaceholderMap]() for modality, placeholders in seq_mm_placeholders.items(): - mm_items = mm_data.pop(modality) - if not isinstance(mm_items, list): - mm_items = [mm_items] + placeholder_map = MultiModalPlaceholderMap() if positions: - intersecting_items = placeholder_maps[modality] \ - .append_items_from_seq_group( - positions, - mm_items, - placeholders, - ) + placeholder_map.append_items_from_seq_group( + positions, + # Dummy, since we don't care about intersecting items + [None] * len(placeholders), + placeholders, + ) - if intersecting_items: - mm_data[modality] = intersecting_items + placeholder_maps[modality] = placeholder_map - return mm_data, placeholder_maps + return seq_mm_data, placeholder_maps def append_items_from_seq_group( self, @@ -445,8 +196,7 @@ class MultiModalPlaceholderMap: f"The number of source ({len(src_indices)}) and destination " f"indices ({len(dest_indices)}) must be the same.") - return MultiModalPlaceholderMap.IndexMap(src=src_indices, - dest=dest_indices) + return self.IndexMap(src=src_indices, dest=dest_indices) class MediaIO(ABC, Generic[_T]): diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 0c5a84c6508a1c0fc9ad04364a861daa26e3f950..939928bbf108b35b9e39df67287fc6b7a8995537 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -3,89 +3,11 @@ import base64 from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional import torch from PIL import Image -from vllm.inputs.registry import InputContext -from vllm.logger import init_logger -from vllm.transformers_utils.processor import cached_get_image_processor -from vllm.utils import is_list_of - -from .base import MediaIO, 
MultiModalPlugin -from .inputs import ImageItem, ModalityData, MultiModalKwargs - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -logger = init_logger(__name__) - - -class ImagePlugin(MultiModalPlugin): - """Plugin for image data.""" - - def get_data_key(self) -> str: - return "image" - - def _get_hf_image_processor( - self, - model_config: "ModelConfig", - mm_processor_kwargs: Optional[dict[str, Any]] = None, - ): - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - return cached_get_image_processor( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs) - - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[ImageItem], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - model_config = ctx.model_config - - # PIL image - if isinstance(data, Image.Image) or is_list_of(data, Image.Image): - image_processor = self._get_hf_image_processor( - model_config, - mm_processor_kwargs, - ) - - if image_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") - try: - # NOTE: It may make sense to forward the mm_processor_kwargs - # here too. For now, to keep it simple, we only allow it be - # used for the initialization call though, just in case the - # signatures of the preprocessor initializer don't match - # preprocess() - batch_data = image_processor \ - .preprocess(data, return_tensors="pt") \ - .data - except Exception: - logger.error( - "Failed to process image (%s) with the default mapper. 
" - "This is most likely an edge-case with this model's image " - "processor in transformers (type: %s), and not vLLM.", - data, - type(image_processor).__name__) - raise - - return MultiModalKwargs(batch_data) - - # Image embedding - elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): - return MultiModalKwargs({"image_embeds": data}) - - raise TypeError(f"Invalid image type: {type(data)}") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - return 3000 +from .base import MediaIO def rescale_image_size(image: Image.Image, diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 53729799b629cf1e85a3771ea373fadf778a29be..6855808e8e44a33b1786538abea699f50f5943f6 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -320,7 +320,8 @@ class MultiModalFlatField(BaseMultiModalField): :func:`MultiModalFieldConfig.flat` :func:`MultiModalFieldConfig.flat_from_sizes` """ - slices: Sequence[slice] + slices: Union[Sequence[slice], Sequence[Sequence[slice]]] + dim: int = 0 def build_elems( self, @@ -329,7 +330,10 @@ class MultiModalFlatField(BaseMultiModalField): data: NestedTensors, ) -> Sequence[MultiModalFieldElem]: field_factory = self._field_factory(modality=modality, key=key) - return [field_factory(data[s]) for s in self.slices] + if not is_list_of(self.slices, slice, check="all"): + assert isinstance(data, torch.Tensor), \ + "torch.Tensor is required for multiple slices" + return [field_factory(data[cast(slice, s)]) for s in self.slices] def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): @@ -338,10 +342,16 @@ class MultiModalFlatField(BaseMultiModalField): # - produce exactly same result as `torch.concat(batch)` # - will achieve zero-copy if the tensor is contiguous return batch[0].contiguous() - first_shape = batch[0].shape - if all(elem.shape[1:] == first_shape[1:] for elem in batch): - return 
torch.concat(batch) + def _expect_same_shape(tensor: torch.Tensor): + return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:] + + first_shape = _expect_same_shape(batch[0]) + + if all(_expect_same_shape(elem) == first_shape for elem in batch): + return torch.concat(batch, dim=self.dim) + + assert self.dim == 0, "dim == 0 is required for nested list" return [e for elem in batch for e in elem] @@ -398,7 +408,9 @@ class MultiModalFieldConfig: ) @staticmethod - def flat(modality: str, slices: Sequence[slice]): + def flat(modality: str, + slices: Union[Sequence[slice], Sequence[Sequence[slice]]], + dim: int = 0): """ Defines a field where an element in the batch is obtained by slicing along the first dimension of the underlying data. @@ -406,8 +418,10 @@ class MultiModalFieldConfig: Args: modality: The modality of the multi-modal item that uses this keyword argument. - slices: For each multi-modal item, a slice that is used to extract - the data corresponding to it. + slices: For each multi-modal item, a slice (dim=0) or a tuple of + slices (dim>0) that is used to extract the data corresponding + to it. + dim: The dimension to extract data, default to 0. Example: @@ -423,14 +437,33 @@ class MultiModalFieldConfig: Element 1: [AAA] Element 2: [BBBB] Element 3: [CC] + + .. 
code-block:: + + Given: + slices: [ + (slice(None), slice(0, 3)), + (slice(None), slice(3, 7)), + (slice(None), slice(7, 9))] + dim: 1 + + Input: + Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]] + + Output: + Element 1: [[A],[A],[A]] + Element 2: [[B],[B],[B],[B]] + Element 3: [[C],[C]] """ return MultiModalFieldConfig( - field=MultiModalFlatField(slices=slices), + field=MultiModalFlatField(slices=slices, dim=dim), modality=modality, ) @staticmethod - def flat_from_sizes(modality: str, size_per_item: torch.Tensor): + def flat_from_sizes(modality: str, + size_per_item: torch.Tensor, + dim: int = 0): """ Defines a field where an element in the batch is obtained by slicing along the first dimension of the underlying data. @@ -440,6 +473,7 @@ class MultiModalFieldConfig: keyword argument. slices: For each multi-modal item, the size of the slice that is used to extract the data corresponding to it. + dim: The dimension to slice, default to 0. Example: @@ -455,6 +489,21 @@ class MultiModalFieldConfig: Element 1: [AAA] Element 2: [BBBB] Element 3: [CC] + + + .. 
code-block:: + + Given: + slices: [3, 4, 2] + dim: 1 + + Input: + Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]] + + Output: + Element 1: [[A],[A],[A]] + Element 2: [[B],[B],[B],[B]] + Element 3: [[C],[C]] See also: :func:`MultiModalFieldConfig.flat` @@ -465,12 +514,11 @@ class MultiModalFieldConfig: f"but found shape: {size_per_item.shape}") slice_idxs = [0, *accumulate(size_per_item)] - slices = [ - slice(slice_idxs[i], slice_idxs[i + 1]) - for i in range(len(size_per_item)) - ] + slices = [(slice(None, None, None), ) * dim + + (slice(slice_idxs[i], slice_idxs[i + 1]), ) + for i in range(len(size_per_item))] - return MultiModalFieldConfig.flat(modality, slices) + return MultiModalFieldConfig.flat(modality, slices, dim=dim) @staticmethod def shared(modality: str, batch_size: int): diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 5720aa1eace0a0b61a89209fd73c7a403be14640..b32e129ddd1448f4ed479842dee1a1abc592fc52 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -3,8 +3,8 @@ from abc import ABC, abstractmethod from collections import UserDict from collections.abc import Callable, Iterator, Mapping, Sequence -from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar, - Union) +from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional, + TypeVar, Union) import numpy as np import torch @@ -14,7 +14,7 @@ from typing_extensions import TypeAlias, TypeGuard, assert_never from vllm.utils import is_list_of -from .audio import resample_audio +from .audio import AudioResampler from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, VideoItem) @@ -308,10 +308,18 @@ class MultiModalDataParser: items to the model's expected sampling rate. 
""" - def __init__(self, *, target_sr: Optional[float] = None) -> None: + def __init__( + self, + *, + target_sr: Optional[float] = None, + audio_resample_method: Literal["librosa", "scipy"] = "librosa", + ) -> None: super().__init__() - self.target_sr = target_sr + self.audio_resampler = AudioResampler( + target_sr=target_sr, + method=audio_resample_method, + ) def _is_embeddings( self, data: object @@ -374,15 +382,8 @@ class MultiModalDataParser: if orig_sr is None: new_audio = audio else: - target_sr = self.target_sr - if target_sr is None: - raise RuntimeError( - "Audio resampling is not supported when " - "`target_sr` is not provided") - - new_audio = resample_audio(audio, - orig_sr=orig_sr, - target_sr=target_sr) + new_audio = self.audio_resampler.resample(audio, + orig_sr=orig_sr) new_audios.append(new_audio) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 7f289426d349e35a38dd3ac2c4f793c66dde035a..87131122e6f2c03efad55ae5104c2c54a9c805c0 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import json import re import sys from abc import ABC, abstractmethod @@ -1117,8 +1118,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): if num_items > allowed_limit: raise ValueError( - f"You set or defaulted to {modality}={allowed_limit} " - f"in --limit-mm-per-prompt`, but passed {num_items} " + "You set or defaulted to " + f"'{json.dumps({modality: allowed_limit})}' in " + f"`--limit-mm-per-prompt`, but passed {num_items} " f"{modality} items in the same prompt.") return mm_items @@ -1567,56 +1569,35 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): "model (usually arising from an inconsistency between " "`_call_hf_processor` and `_get_prompt_updates`).") - def apply( + def _hash_mm_items( self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - 
return_mm_hashes: bool = False, - ) -> MultiModalInputs: - """ - Process multi-modal inputs to be used in vLLM. + ) -> dict[str, list[str]]: + """Create MM hashes to be returned (only used in V1).""" - The main steps are: - - 1. Apply HF Processor on prompt text and multi-modal data together, - outputting token IDs and processed tensors. - 2. Find and update sequences in the token IDs with placeholder tokens. - The number of placeholder tokens equals the feature size of the - multi-modal data outputted by the multi-modal encoder. - 3. Extract information about the placeholder tokens from the - processed token IDs. - """ - mm_items = self._to_mm_items(mm_data) - - # Create MM hashes to be returned (only used in V1) # TODO: Use these hash keys for caching operations in apply_hf_processor # instead of rehashing. + model_id = self.info.model_id - if return_mm_hashes: - model_id = self.info.model_id - mm_hashes = { - modality: [ - MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: item}, - **hf_processor_mm_kwargs) - for item in items - ] - for modality, items in mm_items.items() - } - else: - mm_hashes = None - - ( - prompt_ids, - mm_kwargs, - is_update_applied, - ) = self._cached_apply_hf_processor( - prompt, - mm_items, - hf_processor_mm_kwargs, - ) + return { + modality: [ + MultiModalHasher.hash_kwargs(model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_items.items() + } + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + prompt_ids: list[int], + mm_kwargs: MultiModalKwargs, + is_update_applied: bool, + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: unbound_prompt_updates = self._get_prompt_updates( mm_items, hf_processor_mm_kwargs, @@ -1650,6 +1631,51 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + return 
prompt_ids, prompt, mm_placeholders + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + """ + Process multi-modal inputs to be used in vLLM. + + The main steps are: + + 1. Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + 2. Find and update sequences in the token IDs with placeholder tokens. + The number of placeholder tokens equals the feature size of the + multi-modal data outputted by the multi-modal encoder. + 3. Extract information about the placeholder tokens from the + processed token IDs. + """ + mm_items = self._to_mm_items(mm_data) + + mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs) + if return_mm_hashes else None) + + ( + prompt_ids, + mm_kwargs, + is_update_applied, + ) = self._cached_apply_hf_processor( + prompt, + mm_items, + hf_processor_mm_kwargs, + ) + + prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + prompt_ids=prompt_ids, + mm_kwargs=mm_kwargs, + is_update_applied=is_update_applied, + ) + mm_placeholder_ranges = { modality: [item.to_range() for item in placeholders] for modality, placeholders in mm_placeholders.items() diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index def0595013b8b21312b60f8b883d320b387405ac..ec4f1568101966a452688f1e0b19d12eaead3951 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 - -import functools -from collections import UserDict -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar +from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import 
torch.nn as nn +from typing_extensions import deprecated from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.inputs import InputProcessingContext @@ -15,15 +13,10 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, cached_tokenizer_from_config) from vllm.utils import ClassRegistry -from .audio import AudioPlugin -from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc -from .image import ImagePlugin -from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache) from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) -from .video import VideoPlugin if TYPE_CHECKING: from vllm.config import ModelConfig @@ -84,169 +77,23 @@ class _ProcessorFactories(Generic[_I]): return self.processor(info, dummy_inputs_builder, cache=cache) -class _MultiModalLimits(UserDict["ModelConfig", dict[str, int]]): - """ - Wraps `_limits_by_model` for a more informative error message - when attempting to access a model that does not exist. - """ - - def __getitem__(self, key: "ModelConfig") -> dict[str, int]: - try: - return super().__getitem__(key) - except KeyError as exc: - msg = (f"Cannot find `mm_limits` for model={key.model}. Did you " - "forget to call `init_mm_limits_per_prompt`?") - raise KeyError(msg) from exc - - class MultiModalRegistry: """ A registry that dispatches data processing according to the model. 
""" - DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin()) - - def __init__( - self, - *, - plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: - self._plugins = {p.get_data_key(): p for p in plugins} - + def __init__(self) -> None: self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - # This is used for non-multimodal models - self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} - - self._limits_by_model = _MultiModalLimits() - self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) - def register_plugin(self, plugin: MultiModalPlugin) -> None: - """ - Register a multi-modal plugin so it can be recognized by vLLM. - """ - data_type_key = plugin.get_data_key() - - if data_type_key in self._plugins: - logger.warning( - "A plugin is already registered for data type %s, " - "and will be overwritten by the new plugin %s.", data_type_key, - plugin) - - self._plugins[data_type_key] = plugin - - def _get_plugin(self, data_type_key: str): - plugin = self._plugins.get(data_type_key) - if plugin is not None: - return plugin - - msg = f"Unknown multi-modal data type: {data_type_key}" - raise NotImplementedError(msg) - - def register_input_mapper( - self, - data_type_key: str, - mapper: Optional[MultiModalInputMapper] = None, - ): - """ - Register an input mapper for a specific modality to a model class. - - See :meth:`MultiModalPlugin.register_input_mapper` for more details. - """ - return self._get_plugin(data_type_key).register_input_mapper(mapper) - - def register_image_input_mapper( - self, - mapper: Optional[MultiModalInputMapper] = None, - ): - """ - Register an input mapper for image data to a model class. - - See :meth:`MultiModalPlugin.register_input_mapper` for more details. 
- """ - return self.register_input_mapper("image", mapper) - - def map_input( - self, - model_config: "ModelConfig", - data: MultiModalDataDict, - mm_processor_kwargs: Optional[dict[str, Any]] = None, - ) -> MultiModalKwargs: - """ - Apply an input mapper to the data passed to the model. - - The data belonging to each modality is passed to the corresponding - plugin which in turn converts the data into into keyword arguments - via the input mapper registered for that model. - - See :meth:`MultiModalPlugin.map_input` for more details. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. - """ - merged_dict = dict[str, NestedTensors]() - - for data_key, data_value in data.items(): - plugin = self._get_plugin(data_key) - - num_items = len(data_value) if isinstance(data_value, list) else 1 - max_items = self._limits_by_model[model_config][data_key] - if num_items > max_items: - raise ValueError( - f"You set {data_key}={max_items} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but found {num_items} items " - "in the same prompt.") - - input_dict = plugin.map_input(model_config, data_value, - mm_processor_kwargs) - for input_key, input_tensor in input_dict.items(): - if input_key in merged_dict: - raise ValueError(f"The input mappers (keys={set(data)}) " - f"resulted in a conflicting keyword " - f"argument to `forward()`: {input_key}") - - merged_dict[input_key] = input_tensor - - return MultiModalKwargs(merged_dict) - + @deprecated("Legacy input processor/mapper pipeline has been removed. " + "Please update your model runner to use " + "`seq_group_metadata.multi_modal_data` directly without " + "further processing.") def create_input_mapper(self, model_config: "ModelConfig"): - """ - Create an input mapper (see :meth:`map_input`) for a specific model. - """ - # NOTE - we currently make the assumption that if a model has multiple - # supported modalities, they take the same kwargs. 
For the default, - # this could be an issue in the future if it falls back to two HF - # resources and we can't inspect the signature easily since it's - # getting initialized through the autoclass. - # - # If this is a problem in the future, we should revisit it, but since - # it potentially introduces a lot of complexity for a currently - # uncommon case, we do not for simplicity of both use & implementation - return functools.partial(self.map_input, model_config) - - def register_max_multimodal_tokens( - self, - data_type_key: str, - max_mm_tokens: Optional[MultiModalTokensCalc] = None, - ): - """ - Register the maximum number of tokens, corresponding to a single - instance of multimodal data belonging to a specific modality, that are - passed to the language model for a model class. - """ - return self._get_plugin(data_type_key) \ - .register_max_multimodal_tokens(max_mm_tokens) - - def register_max_image_tokens( - self, - max_mm_tokens: Optional[MultiModalTokensCalc] = None, - ): - """ - Register the maximum number of image tokens, corresponding to a single - image, that are passed to the language model for a model class. - """ - return self.register_max_multimodal_tokens("image", max_mm_tokens) + return lambda data, mm_processor_kwargs: data def get_max_tokens_per_item_by_modality( self, @@ -256,25 +103,22 @@ class MultiModalRegistry: Get the maximum number of tokens per data item from each modality based on underlying model configuration. 
""" - if self.has_processor(model_config): - processor = self.create_processor(model_config, disable_cache=True) - profiler = MultiModalProfiler(processor) - - seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config) - - return profiler.get_mm_max_tokens( - seq_len, - { - modality: 1 - for modality, limit in mm_limits.items() if limit > 0 - }, - ) + if not model_config.is_multimodal_model: + return {} - return { - key: plugin.get_max_multimodal_tokens(model_config) - for key, plugin in self._plugins.items() - } + processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + + seq_len = model_config.max_model_len + mm_limits = self.get_mm_limits_per_prompt(model_config) + + return profiler.get_mm_max_tokens( + seq_len, + { + modality: 1 + for modality, limit in mm_limits.items() if limit > 0 + }, + ) def get_max_tokens_per_item_by_nonzero_modality( self, @@ -307,9 +151,6 @@ class MultiModalRegistry: for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. """ mm_limits = self.get_mm_limits_per_prompt(model_config) @@ -325,47 +166,18 @@ class MultiModalRegistry: for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. """ return sum(self.get_max_tokens_by_modality(model_config).values()) + @deprecated("Legacy input processor/mapper pipeline has been removed. " + "Please update your model runner to use " + "`seq_group_metadata.multi_modal_data` directly without " + "further processing.") def init_mm_limits_per_prompt( self, model_config: "ModelConfig", ) -> None: - """ - Initialize the maximum number of multi-modal input instances for each - modality that are allowed per prompt for a model class. 
- """ - if model_config in self._limits_by_model: - logger.warning( - "`mm_limits` has already been set for model=%s, and will " - "be overwritten by the new values.", model_config.model) - - multimodal_config = model_config.multimodal_config - if multimodal_config is None: - limits_per_plugin = self._disabled_limits_per_plugin - else: - config_limits_per_plugin = multimodal_config.limit_per_prompt - - extra_keys = config_limits_per_plugin.keys() - self._plugins.keys() - if extra_keys: - logger.warning( - "Detected extra keys in `--limit-mm-per-prompt` which " - "are not registered as multi-modal plugins: %s. " - "They will be ignored.", extra_keys) - - # NOTE: Currently the default is set to 1 for each plugin - # TODO: Automatically determine the limits based on budget - # once more models support multi-image inputs - limits_per_plugin = { - key: multimodal_config.get_limit_per_prompt(key) - for key in self._plugins - } - - self._limits_by_model[model_config] = limits_per_plugin + pass def get_mm_limits_per_prompt( self, @@ -374,16 +186,13 @@ class MultiModalRegistry: """ Get the maximum number of multi-modal input instances for each modality that are allowed per prompt for a model class. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. """ - if self.has_processor(model_config): - processor = self.create_processor(model_config, disable_cache=True) - profiler = MultiModalProfiler(processor) - return profiler.get_mm_limits() + if not model_config.is_multimodal_model: + return {} - return self._limits_by_model[model_config] + processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + return profiler.get_mm_limits() def register_processor( self, @@ -427,14 +236,12 @@ class MultiModalRegistry: model_cls, _ = get_model_architecture(model_config) return model_cls + @deprecated("Legacy input processor/mapper pipeline has been removed. 
" + "Please update your model runner to use " + "`seq_group_metadata.multi_modal_data` directly without " + "further processing.") def has_processor(self, model_config: "ModelConfig") -> bool: - """ - Test whether a multi-modal processor is defined for a specific model. - - See also: - :ref:`mm-processing` - """ - return self._get_model_cls(model_config) in self._processor_factories + return True def create_processor( self, @@ -449,6 +256,9 @@ class MultiModalRegistry: See also: :ref:`mm-processing` """ + if not model_config.is_multimodal_model: + raise ValueError(f"{model_config.model} is not a multimodal model") + if tokenizer is None: tokenizer = cached_tokenizer_from_config(model_config) if disable_cache is None: diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index f7c3f105295420d80549161a3a92275a25601645..6d875a1c651e27724d497a1c1104d17e3d840fb1 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -4,80 +4,13 @@ import base64 from functools import partial from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional import numpy as np import numpy.typing as npt from PIL import Image -from vllm.inputs.registry import InputContext -from vllm.logger import init_logger -from vllm.transformers_utils.processor import cached_get_video_processor -from vllm.utils import is_list_of - -from .base import MediaIO, ModalityData -from .image import ImageMediaIO, ImagePlugin -from .inputs import MultiModalKwargs, VideoItem - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -logger = init_logger(__name__) - - -class VideoPlugin(ImagePlugin): - """Plugin for video data.""" - - def get_data_key(self) -> str: - return "video" - - def _get_hf_video_processor( - self, - model_config: "ModelConfig", - mm_processor_kwargs: Optional[dict[str, Any]] = None, - ): - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - return cached_get_video_processor( - model_config.model, - 
trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs) - - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[VideoItem], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - model_config = ctx.model_config - - if isinstance(data, list) and len(data) == 1: - data = data[0] # type: ignore - - if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): - video_processor = self._get_hf_video_processor( - model_config, - mm_processor_kwargs, - ) - if video_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the video object") - try: - # NOTE: Similar to image; it may be a good idea to filter and - # pass mm_processor_kwargs here too, but for now we don't to - # avoid extra complexity if the initializer and preprocess - # signatures of the processor don't align - batch_data = video_processor(data, return_tensors="pt").data - except Exception: - logger.error("Failed to process video (%s)", data) - raise - - return MultiModalKwargs(batch_data) - - raise TypeError(f"Invalid video type: {type(data)}") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - return 4096 +from .base import MediaIO +from .image import ImageMediaIO def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray: diff --git a/vllm/outputs.py b/vllm/outputs.py index 014e8d5d8823818495a02bde49a6a0148128d34e..65a6ed01451dd90dab5a80fe9d02296e95aa39f6 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -134,26 +134,32 @@ class RequestOutput: self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens - def add(self, next_output: "RequestOutput") -> None: + def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished for next_completion in next_output.outputs: - for completion in self.outputs: + for i, completion in enumerate(self.outputs): if 
completion.index == next_completion.index: - # Merge outputs with same index - completion.text += next_completion.text - if not isinstance(completion.token_ids, MutableSequence): - completion.token_ids = list(completion.token_ids) - completion.token_ids.extend(next_completion.token_ids) - if next_completion.logprobs: - assert completion.logprobs is not None - completion.logprobs.extend(next_completion.logprobs) - completion.cumulative_logprob = ( - next_completion.cumulative_logprob) - completion.finish_reason = next_completion.finish_reason - completion.stop_reason = next_completion.stop_reason + if aggregate: + # Merge outputs with same index + completion.text += next_completion.text + if not isinstance(completion.token_ids, + MutableSequence): + completion.token_ids = list(completion.token_ids) + completion.token_ids.extend(next_completion.token_ids) + if next_completion.logprobs: + assert completion.logprobs is not None + completion.logprobs.extend( + next_completion.logprobs) + completion.cumulative_logprob = ( + next_completion.cumulative_logprob) + completion.finish_reason = next_completion.finish_reason + completion.stop_reason = next_completion.stop_reason + else: + # Replace the output with the new one + self.outputs[i] = next_completion break else: self.outputs.append(next_completion) @@ -173,6 +179,13 @@ class RequestOutput: group.finish_seq(seq_group) if assembled_seq_group is None: return None + + # clear finished seq in seq_id_to_seq_group + if len(group.to_be_finished) == 0: + for sub_request_id in list(group.seq_id_to_index.keys()): + if sub_request_id in seq_id_to_seq_group: + del seq_id_to_seq_group[sub_request_id] + return cls.from_seq_group(assembled_seq_group, use_cache, seq_id_to_seq_group) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 0576022be448b3f12c33c27843498369331f9424..f82af426b5a8bd4baca3751714a5b235e0c78c57 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -21,9 +21,6 @@ from .interface import 
DeviceCapability, Platform, PlatformEnum, _Backend if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig -else: - ModelConfig = None - VllmConfig = None logger = init_logger(__name__) @@ -109,7 +106,7 @@ class CudaPlatformBase(Platform): pass @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config compilation_config = vllm_config.compilation_config @@ -213,6 +210,9 @@ class CudaPlatformBase(Platform): return ("vllm.attention.backends." "flashmla.FlashMLABackend") if use_v1: + if selected_backend == _Backend.FLASHINFER: + logger.info_once("Using FlashInfer backend on V1 engine.") + return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: logger.info_once("Using Triton backend on V1 engine.") return ("vllm.v1.attention.backends." @@ -305,7 +305,7 @@ class CudaPlatformBase(Platform): return cls.has_device_capability(89) @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: + def supports_v1(cls, model_config: "ModelConfig") -> bool: return True @classmethod diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 31a7ffbd910d19e8fb68c017ebeafae8e1a9ed1d..c5555aba1a3e3a6d1e9088de6993d276f57aba7e 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union import numpy as np import torch -from vllm.inputs import PromptType +from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger if TYPE_CHECKING: @@ -39,6 +39,7 @@ class _Backend(enum.Enum): TRITON_ATTN_VLLM_V1 = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() + ROCM_AITER_MLA = enum.auto() TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 @@ 
-148,6 +149,9 @@ class Platform: """Stateless version of :func:`torch.cuda.is_available`.""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + def is_sleep_mode_available(self) -> bool: + return self._enum == PlatformEnum.CUDA + @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], @@ -397,9 +401,26 @@ class Platform: cls, prompt: PromptType, params: Union[SamplingParams, PoolingParams], + processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" + def __getattr__(self, key: str): + device = getattr(torch, self.device_name, None) + if device is not None and hasattr(device, key): + return getattr(device, key) + else: + logger.warning("Current platform %s does not have '%s'" \ + " attribute.", self.device_name, key) + return None + + @classmethod + def get_cu_count(cls, device_id: int = 0) -> int: + """ + Returns the total number of compute units (CU) on single GPU. 
+ """ + raise NotImplementedError + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index c1f426e5b880135bc9f2c57c1ee13e0eb51b92b6..e37a3a578cf20c93d008eaaca27681d7767b2669 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -50,7 +50,7 @@ class NeuronPlatform(Platform): if cache_config: # neuron needs block_size = max_model_len vllm_config.cache_config.block_size = \ - vllm_config.model_config.max_model_len + vllm_config.model_config.max_model_len # type: ignore @classmethod def is_pin_memory_available(cls) -> bool: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 7cc6365e6b51f0b4aa7f61355884748b822eec7a..5f77310bf65139ad7b25f77761d32b5603490705 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -13,9 +13,6 @@ from .interface import DeviceCapability, Platform, PlatformEnum, _Backend if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig -else: - ModelConfig = None - VllmConfig = None logger = init_logger(__name__) @@ -99,24 +96,29 @@ def device_id_to_physical_device_id(device_id: int) -> int: return device_id +def on_mi250_mi300() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) + + @cache def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, block_size: int, gqa_ratio: int, max_seq_len: int, sliding_window: int) -> bool: - GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName - ON_NAVI = "gfx1" in GPU_ARCH - ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) - - # rocm custom page attention not support on navi (gfx1*) - return (ON_MI250_MI300 and not ON_NAVI - and (sliding_window == 0 or sliding_window == (-1, -1)) + # rocm custom page attention not support on gfx1* + # custom paged attn always supported on V0. 
On V1, requires sliding window + # disabled due to observed numerical discrepancy. + return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 - and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + and envs.VLLM_ROCM_USE_AITER)) class RocmPlatform(Platform): @@ -129,8 +131,8 @@ class RocmPlatform(Platform): device_control_env_var: str = "CUDA_VISIBLE_DEVICES" supported_quantization: list[str] = [ - "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8", "gguf", "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8" + "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", + "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8" ] @classmethod @@ -138,6 +140,7 @@ class RocmPlatform(Platform): kv_cache_dtype, block_size, use_v1, use_mla) -> str: if use_mla: +<<<<<<< HEAD if selected_backend == _Backend.TRITON_MLA or block_size != 64: if use_v1: logger.info_once("Using Triton MLA backend on V1 engine.") @@ -173,6 +176,38 @@ class RocmPlatform(Platform): logger.info("Using Triton MLA backend (block size 64).") return "vllm.attention.backends.triton_mla.TritonMLABackend" +======= + from vllm.attention.backends.rocm_aiter_mla import ( + is_aiter_mla_enabled) + + if selected_backend is None: + selected_backend = (_Backend.ROCM_AITER_MLA if + is_aiter_mla_enabled() or block_size == 1 + else _Backend.TRITON_MLA) + + if selected_backend == _Backend.TRITON_MLA: + if block_size != 1: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"does not support 
block size {block_size}.") + elif selected_backend == _Backend.ROCM_AITER_MLA: + if block_size == 1: + logger.info("Using AITER MLA backend.") + return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"does not support block size {block_size}." + "(currently only supports block size 1)") + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"is not MLA type while requested for MLA backend.") + +>>>>>>> v0.8.5 selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if envs.VLLM_USE_V1: @@ -245,7 +280,7 @@ class RocmPlatform(Platform): return True @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: cache_config.block_size = 16 @@ -335,7 +370,7 @@ class RocmPlatform(Platform): return torch.float8_e4m3fn @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: + def supports_v1(cls, model_config: "ModelConfig") -> bool: # V1 support on AMD gpus is experimental return True @@ -344,4 +379,9 @@ class RocmPlatform(Platform): # We only enable custom allreduce for MI300 series gcn_arch = torch.cuda.get_device_properties(0).gcnArchName supported_archs = ['gfx94'] - return any(gfx in gcn_arch for gfx in supported_archs) \ No newline at end of file + return any(gfx in gcn_arch for gfx in supported_archs) + + @classmethod + def get_cu_count(cls, device_id: int = 0) -> int: + return torch.cuda.get_device_properties( + device_id).multi_processor_count diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d8807a72ba2f3768c4d6f930017519928e98b606..d5923557a21120a065c15019f3510bb7732758fe 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -5,7 +5,7 @@ from typing import 
TYPE_CHECKING, Optional, Union import torch import vllm.envs as envs -from vllm.inputs import PromptType +from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType @@ -30,9 +30,7 @@ class TpuPlatform(Platform): ray_device_key: str = "TPU" device_control_env_var: str = "TPU_VISIBLE_CHIPS" - supported_quantization: list[str] = [ - "tpu_int8", "compressed-tensors", "compressed_tensors" - ] + supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"] additional_env_vars: list[str] = [ "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" @@ -97,6 +95,20 @@ class TpuPlatform(Platform): "Using bfloat16 instead.", vllm_config.model_config.dtype) vllm_config.model_config.dtype = torch.bfloat16 + if envs.VLLM_USE_V1: + from vllm.v1.attention.backends.pallas import ( + PallasAttentionBackend) + min_page_size = PallasAttentionBackend.get_min_page_size( + vllm_config) + if min_page_size > vllm_config.cache_config.block_size: + logger.warning( + "Increase the page size from %s to %s to make sure there's" + "no SMEM OOM", + vllm_config.cache_config.block_size, + min_page_size, + ) + vllm_config.cache_config.block_size = min_page_size + parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": @@ -150,12 +162,13 @@ class TpuPlatform(Platform): cls, prompt: PromptType, params: Union[SamplingParams, PoolingParams], + processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" if isinstance(params, SamplingParams): - if params.guided_decoding is not None: + if params.guided_decoding is not None and not envs.VLLM_USE_V1: raise ValueError("Structured output is not supported on " - f"{cls.device_name}.") + f"{cls.device_name} V0.") if params.sampling_type == SamplingType.RANDOM_SEED: raise ValueError( "Torch XLA does not support per-request seed.") diff --git 
a/vllm/pooling_params.py b/vllm/pooling_params.py index f71daf0c19551b8e0d0fa8bb436de8b225aa360c..9a3b254f9b68c3e603c0f09da54ae5249ff732a2 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -35,7 +35,16 @@ class PoolingParams( f'Model "{model_config.served_model_name}" does not ' f'support matryoshka representation, ' f'changing output dimensions will lead to poor results.') - if self.dimensions < 1: + + mds = model_config.matryoshka_dimensions + if mds is not None: + if self.dimensions not in mds: + raise ValueError( + f'Model "{model_config.served_model_name}" ' + f'only supports {str(mds)} matryoshka dimensions, ' + f'use other output dimensions will ' + f'lead to poor results.') + elif self.dimensions < 1: raise ValueError("Dimensions must be greater than 0") def __repr__(self) -> str: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 68ed9966494718a67a065e339a6a1204c675debc..c430b74a9db9affe063631d1196ba5eb5214cf09 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -38,6 +38,7 @@ class GuidedDecodingParams: """These are other options that can be set""" backend: Optional[str] = None whitespace_pattern: Optional[str] = None + structural_tag: Optional[str] = None @staticmethod def from_optional( @@ -48,9 +49,10 @@ class GuidedDecodingParams: json_object: Optional[bool] = None, backend: Optional[str] = None, whitespace_pattern: Optional[str] = None, + structural_tag: Optional[str] = None, ) -> Optional["GuidedDecodingParams"]: - if all(arg is None - for arg in (json, regex, choice, grammar, json_object)): + if all(arg is None for arg in (json, regex, choice, grammar, + json_object, structural_tag)): return None # Extract json schemas from pydantic models if isinstance(json, (BaseModel, type(BaseModel))): @@ -63,6 +65,7 @@ class GuidedDecodingParams: json_object=json_object, backend=backend, whitespace_pattern=whitespace_pattern, + structural_tag=structural_tag, ) @property @@ -79,6 +82,17 @@ class 
GuidedDecodingParams: return [] return self.backend.split(":")[1].split(",") + def add_option(self, opt_name: str) -> None: + """Adds an option to the backend options.""" + if not self.backend: + self.backend = f":{opt_name}" + elif ":" not in self.backend: + self.backend += f":{opt_name}" + else: + options = set(self.backend_options()) + options.add(opt_name) + self.backend = f"{self.backend_name}:{','.join(sorted(options))}" + def no_fallback(self) -> bool: """Returns True if the "no-fallback" option is supplied for the guided decoding backend""" @@ -423,6 +437,10 @@ class SamplingParams( and self.truncate_prompt_tokens < 1): raise ValueError(f"truncate_prompt_tokens must be >= 1, " f"got {self.truncate_prompt_tokens}") + assert isinstance(self.stop_token_ids, list) + if not all(isinstance(st_id, int) for st_id in self.stop_token_ids): + raise ValueError(f"stop_token_ids must contain only integers, " + f"got {self.stop_token_ids}.") assert isinstance(self.stop, list) if any(not stop_str for stop_str in self.stop): raise ValueError("stop cannot contain an empty string.") diff --git a/vllm/sequence.py b/vllm/sequence.py index c8bee8969f395b1200473c5e2343d174443049ef..f8fee767b3e9cc1ab89a0cbbfdaac581877f0cb8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional, Union import msgspec import torch -from vllm.inputs import SingletonInputs, SingletonInputsAdapter +from vllm.inputs import SingletonInputs from vllm.lora.request import LoRARequest -from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict +from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -447,7 +447,7 @@ class Sequence: prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.seq_id = seq_id - self.inputs = 
SingletonInputsAdapter(inputs) + self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request @@ -476,31 +476,29 @@ class Sequence: @property def prompt(self) -> Optional[str]: - return self.inputs.prompt + return self.inputs.get("prompt") @property def prompt_token_ids(self) -> list[int]: - return self.inputs.prompt_token_ids - - @property - def prompt_embeds(self) -> Optional[torch.Tensor]: - return self.inputs.prompt_embeds + return self.inputs["prompt_token_ids"] @property def token_type_ids(self) -> list[int]: - return self.inputs.token_type_ids + return self.inputs.get("token_type_ids", []) @property - def multi_modal_data(self) -> "MultiModalDataDict": - return self.inputs.multi_modal_data + def multi_modal_data(self) -> MultiModalKwargs: + if self.inputs["type"] == "multimodal": + return self.inputs["mm_kwargs"] + + return MultiModalKwargs({}) @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - return self.inputs.multi_modal_placeholders + if self.inputs["type"] == "multimodal": + return self.inputs["mm_placeholders"] - @property - def mm_processor_kwargs(self) -> dict[str, Any]: - return self.inputs.mm_processor_kwargs + return {} @property def lora_int_id(self) -> int: @@ -751,12 +749,12 @@ class SequenceGroup: return self.first_seq.token_type_ids @property - def multi_modal_data(self) -> MultiModalDataDict: + def multi_modal_data(self) -> MultiModalKwargs: if self.first_seq.multi_modal_data: return self.first_seq.multi_modal_data elif self.encoder_seq is not None: return self.encoder_seq.multi_modal_data - return {} + return MultiModalKwargs({}) @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: @@ -766,14 +764,6 @@ class SequenceGroup: return self.encoder_seq.multi_modal_placeholders return {} - @property - def mm_processor_kwargs(self) -> dict[str, Any]: - if self.first_seq.multi_modal_data: - return self.first_seq.mm_processor_kwargs - elif 
self.encoder_seq is not None: - return self.encoder_seq.mm_processor_kwargs - return {} - @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -997,12 +987,9 @@ class SequenceGroupMetadata( computed_block_nums: Optional[list[int]] = None state: Optional[SequenceGroupState] = msgspec.field( default_factory=lambda: SequenceGroupState()) - # "MultiModalDataDict" types. We have to use Any due to msgspec - # doesn't allow to have union of 2 different dicts. token_type_ids: Optional[list[int]] = None - multi_modal_data: Optional[Any] = None + multi_modal_data: Optional[MultiModalKwargs] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - mm_processor_kwargs: Optional[dict[str, Any]] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[list[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3ad9b499332754fcac45e24f99f2e204fd662af5..24095ef2a56750db88de4926e4edcbb58cc5a0d5 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -295,7 +295,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): if not self.is_driver_worker: return [] # Sample the next token. 
- output = self.model.sample( + output = self.model_runner.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index e370182a3ae4ce27cadc107eee95fcf29a301306..d7b939a419ba376e5690eb7747bd29f242f04b83 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -8,6 +8,7 @@ import torch from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) +from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available @@ -91,14 +92,14 @@ class AsyncMetricsCollector: self._rank = rank if isinstance(device_type, torch.device): device_type = device_type.type - if device_type == 'cuda': - self._copy_stream = torch.cuda.Stream() + stream = current_platform.Stream + if stream is not None: + self._copy_stream = stream() def maybe_collect_rejsample_metrics( self, k: int) -> Optional[SpecDecodeWorkerMetrics]: - # currently using cuda.Event, skip for any non_cuda_alike platform - from vllm.platforms import current_platform - if not current_platform.is_cuda_alike(): + # Skip for any platform that doesn't have device Event + if current_platform.Event is None: return None # If a copy was initiated in the previous call, collect and return. 
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index d0fdff16ce90d93b5b82ae09df21625520238e23..e542a3c983de6e3f21d933d65211a3e749bd156f 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -50,11 +50,10 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase): def set_include_gpu_probs_tensor(self) -> None: # Need include_gpu_probs_tensor for MultiStepWorker - self.model_runner.model.sampler.include_gpu_probs_tensor = True + self.model_runner.sampler.include_gpu_probs_tensor = True def set_should_modify_greedy_probs_inplace(self) -> None: - self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( - True) + self.model_runner.sampler.should_modify_greedy_probs_inplace = True @torch.inference_mode() def sampler_output( diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 45e6d6fe86df58102c8c960572a7e3b37e193262..fbf162998b9dc022d7727bae559ef43b47c9c4da 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -430,7 +430,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): NOTE(cade): This will require a special check if the proposer worker does not have a sampler (e.g. ngram speculation). 
""" - (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor + (self.scorer_worker.model_runner.sampler.include_gpu_probs_tensor ) = True # tree_style decoding modify probs in _verify_tokens diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index fe0319c9b033ed37c5ae47bc860dd0089979e3a1..e062afd682087654eb5b79480cb898e3919847d5 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -33,10 +33,10 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, EAGLEConfig, ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, - MedusaConfig, MllamaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, NVLM_D_Config, - Olmo2Config, RWConfig, + KimiVLConfig, MedusaConfig, + MllamaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + NVLM_D_Config, RWConfig, SkyworkR1VChatConfig, SolarConfig, Telechat2Config, UltravoxConfig) # yapf: enable @@ -62,6 +62,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { "cohere2": Cohere2Config, "dbrx": DbrxConfig, "deepseek_vl_v2": DeepseekVLV2Config, + "kimi_vl": KimiVLConfig, "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) @@ -74,7 +75,6 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, - "olmo2": Olmo2Config, "solar": SolarConfig, "skywork_chat": SkyworkR1VChatConfig, "telechat": Telechat2Config, @@ -220,8 +220,7 @@ def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None: logger.warning("Replacing legacy rope_type 'mrope' with 'default'") -def uses_mrope(config: PretrainedConfig) -> bool: - """Detect if the model with this config uses M-ROPE.""" +def _uses_mrope(config: PretrainedConfig) -> bool: rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is None: return False @@ -229,6 +228,24 @@ def 
uses_mrope(config: PretrainedConfig) -> bool: return "mrope_section" in rope_scaling +def uses_mrope(config: PretrainedConfig) -> bool: + """Detect if the model with this config uses M-ROPE.""" + return _uses_mrope(config) or thinker_uses_mrope(config) + + +def thinker_uses_mrope(config: PretrainedConfig) -> bool: + """Detect if the model contains a thinker config and it uses M-ROPE.""" + thinker_config = getattr(config, "thinker_config", None) + if thinker_config is None: + return False + + thinker_text_config = getattr(thinker_config, "text_config", None) + if thinker_text_config is None: + return False + + return uses_mrope(thinker_text_config) + + def is_encoder_decoder(config: PretrainedConfig) -> bool: """Detect if the model with this config is used as an encoder/decoder.""" text_config = getattr(config, "text_config", None) @@ -633,6 +650,11 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], config_file_name = "params.json" config_dict = get_hf_file_to_dict(config_file_name, model, revision) + if config_dict is None: + raise ValueError( + f"Failed to load mistral '{config_file_name}' config for model " + f"{model}. 
Please check if the model is a mistral-format model " + f"and if the config file exists.") assert isinstance(config_dict, dict) config_mapping = { @@ -671,6 +693,9 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], "quant_method": "fp8", "activation_scheme": "static" } + elif quantization.get("quant_method") == "compressed-tensors": + # Pass through the quantization config to compressed-tensors + quantization_config = quantization else: raise ValueError( f"Found unknown quantization='{quantization}' in config") @@ -688,6 +713,7 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], if config_type == "multimodal": multimodal_config = config_dict.pop("vision_encoder") + quantization_config = config_dict.get("quantization_config", {}) config_dict = { "text_config": config_dict, @@ -695,6 +721,8 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], } config_dict["architectures"] = ["PixtralForConditionalGeneration"] config_dict["model_type"] = "pixtral" + if quantization_config: + config_dict["quantization_config"] = quantization_config config_dict.update(kwargs) @@ -732,14 +760,22 @@ def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. No op for pure text models. """ - if hasattr(config, "text_config"): + # This block should be unnecessary after https://github.com/huggingface/transformers/pull/37517 + if hasattr(config, "thinker_config"): + # TODO(suyang.fy): Refactor code. + # For Qwen2.5-Omni, change hf_text_config to + # thinker_config.text_config. + return config.thinker_config.text_config + + text_config = config.get_text_config() + + if text_config is not config: # The code operates under the assumption that text_config should have # `num_attention_heads` (among others). Assert here to fail early # if transformers config doesn't align with this assumption. 
- assert hasattr(config.text_config, "num_attention_heads") - return config.text_config - else: - return config + assert hasattr(text_config, "num_attention_heads") + + return text_config def try_get_generation_config( diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 53699341bfba8f2af6c2dd1fcf469eaadba91948..8812d4c484b17a7b1252d6d01186c73a5e3f08f2 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -13,13 +13,14 @@ from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig +from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig +from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config -from vllm.transformers_utils.configs.olmo2 import Olmo2Config from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.telechat2 import Telechat2Config @@ -40,9 +41,10 @@ __all__ = [ "ExaoneConfig", "MllamaConfig", "MLPSpeculatorConfig", + "MoonViTConfig", + "KimiVLConfig", "NemotronConfig", "NVLM_D_Config", - "Olmo2Config", "SkyworkR1VChatConfig", "SolarConfig", "Telechat2Config", diff --git a/vllm/transformers_utils/configs/kimi_vl.py b/vllm/transformers_utils/configs/kimi_vl.py new file mode 100644 index 
0000000000000000000000000000000000000000..97ff44bb9c1c99f60fa77059936600091958261a --- /dev/null +++ b/vllm/transformers_utils/configs/kimi_vl.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py +from typing import Optional, Union + +from transformers.configuration_utils import PretrainedConfig + +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config +from vllm.transformers_utils.configs.moonvit import MoonViTConfig + + +class KimiVLConfig(PretrainedConfig): + model_type = "kimi_vl" + + def __init__(self, + vision_config: Optional[Union[dict, MoonViTConfig]] = None, + text_config: Optional[Union[dict, DeepseekV2Config]] = None, + ignore_index: int = -100, + media_placeholder_token_id: int = 163605, + pad_token_id: int = 0, + **kwargs): + if vision_config is None: + vision_config = MoonViTConfig() + elif isinstance(vision_config, dict): + vision_config = MoonViTConfig(**vision_config) + self.vision_config = vision_config + + if text_config is None: + text_config = DeepseekV2Config() + elif isinstance(text_config, dict): + text_config = DeepseekV2Config(**text_config) + self.text_config = text_config + + self.ignore_index = ignore_index + self.media_placeholder_token_id = media_placeholder_token_id + + super().__init__(pad_token_id=pad_token_id, **kwargs) diff --git a/vllm/transformers_utils/configs/moonvit.py b/vllm/transformers_utils/configs/moonvit.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b4059a63efb69f0bb1f88033a57047a724b6d4 --- /dev/null +++ b/vllm/transformers_utils/configs/moonvit.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py +from transformers.configuration_utils import PretrainedConfig + + +class MoonViTConfig(PretrainedConfig): + model_type = "moonvit" + + def __init__( + 
self, + patch_size: int = 14, + init_pos_emb_height: int = 64, + init_pos_emb_width: int = 64, + num_attention_heads: int = 16, + num_hidden_layers: int = 27, + hidden_size: int = 1152, + intermediate_size: int = 4304, + merge_kernel_size: tuple[int, int] = (2, 2), + **kwargs, + ): + super().__init__(**kwargs) + self.patch_size = patch_size + # Positional embedding config + self.init_pos_emb_height = init_pos_emb_height + self.init_pos_emb_width = init_pos_emb_width + # Transformer config + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + # Patch merger config + self.merge_kernel_size = merge_kernel_size diff --git a/vllm/transformers_utils/configs/olmo2.py b/vllm/transformers_utils/configs/olmo2.py deleted file mode 100644 index c6e446333b43d0aed0dd133bcea5cb463b1e0c15..0000000000000000000000000000000000000000 --- a/vllm/transformers_utils/configs/olmo2.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# yapf: disable -# ruff: noqa: E501 -# coding=utf-8 -# Copied from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/configuration_olmo2.py -"""OLMo 2 configuration.""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class Olmo2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2 - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 50304): - Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Olmo2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). 
Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 1): - Padding token id. - bos_token_id (`int`, *optional*): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 50279): - End of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is - `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. 
- - ```python - >>> from transformers import Olmo2Model, Olmo2Config - - >>> # Initializing a Olmo2 7B style configuration - >>> configuration = Olmo2Config() - - >>> # Initializing a model from the Olmo2 7B style configuration - >>> model = Olmo2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ``` - """ - - model_type = "olmo2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=50304, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - use_cache=True, - pad_token_id=1, - bos_token_id=None, - eos_token_id=50279, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - rms_norm_eps=1e-5, - **kwargs, - ): - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - - self.rms_norm_eps = rms_norm_eps - - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 9d1d4bb92e4ab5043007f79b69ccd14ef380e51b..991d5631e64e348d439d009afed1d0b181477eb6 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -8,13 +8,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, from .detokenizer_utils import (convert_prompt_ids_to_tokens, detokenize_incrementally) from .tokenizer import AnyTokenizer -from .tokenizer_group import BaseTokenizerGroup +from .tokenizer_group import TokenizerGroup class Detokenizer: """Provides methods to decode the output of a model into text.""" - def __init__(self, tokenizer_group: BaseTokenizerGroup): + def __init__(self, tokenizer_group: TokenizerGroup): self.tokenizer_group = tokenizer_group def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 1d09b99d50c06428c329a63bdde373772845f59a..4f06950c42e292d918f14faa5c668fbd6c455d43 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -111,20 +111,20 @@ def 
cached_processor_from_config( ) -def get_image_processor( +def get_feature_extractor( processor_name: str, *args: Any, trust_remote_code: bool = False, **kwargs: Any, ): - """Load an image processor for the given model name via HuggingFace.""" + """Load an audio feature extractor for the given model name + via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() - from transformers import AutoImageProcessor - from transformers.image_processing_utils import BaseImageProcessor - + from transformers import AutoFeatureExtractor + from transformers.feature_extraction_utils import FeatureExtractionMixin try: - processor = AutoImageProcessor.from_pretrained( + feature_extractor = AutoFeatureExtractor.from_pretrained( processor_name, *args, trust_remote_code=trust_remote_code, @@ -135,61 +135,75 @@ def get_image_processor( # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors if not trust_remote_code: err_msg = ( - "Failed to load the image processor. If the image processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " + "Failed to load the feature extractor. 
If the feature " + "extractor is a custom extractor not yet available in the " + "HuggingFace transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " "`--trust-remote-code` flag in the CLI.") raise RuntimeError(err_msg) from e else: raise e + return cast(FeatureExtractionMixin, feature_extractor) - return cast(BaseImageProcessor, processor) +cached_get_feature_extractor = lru_cache(get_feature_extractor) -cached_get_image_processor = lru_cache(get_image_processor) - -def cached_image_processor_from_config( +def cached_feature_extractor_from_config( model_config: "ModelConfig", **kwargs: Any, ): - return cached_get_image_processor( + return cached_get_feature_extractor( model_config.model, trust_remote_code=model_config.trust_remote_code, **_merge_mm_kwargs(model_config, **kwargs), ) -def get_video_processor( +def get_image_processor( processor_name: str, *args: Any, trust_remote_code: bool = False, **kwargs: Any, ): - """Load a video processor for the given model name via HuggingFace.""" + """Load an image processor for the given model name via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() + from transformers import AutoImageProcessor from transformers.image_processing_utils import BaseImageProcessor - processor = get_processor( - processor_name, - *args, - trust_remote_code=trust_remote_code, - **kwargs, - ) + try: + processor = AutoImageProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the image processor. 
If the image processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e - return cast(BaseImageProcessor, processor.video_processor) + return cast(BaseImageProcessor, processor) -cached_get_video_processor = lru_cache(get_video_processor) +cached_get_image_processor = lru_cache(get_image_processor) -def cached_video_processor_from_config( +def cached_image_processor_from_config( model_config: "ModelConfig", **kwargs: Any, ): - return cached_get_video_processor( + return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, **_merge_mm_kwargs(model_config, **kwargs), diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2293d062600c9c27bf25f687e063906a5f06e273..ea6afc4db739f397d3085eea6496bacbb2f422dd 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import copy import os import warnings from functools import lru_cache @@ -70,18 +71,17 @@ def encode_tokens( def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: - """Get tokenizer with cached properties. - - This will patch the tokenizer object in place. - + """ By default, transformers will recompute multiple tokenizer properties - each time they are called, leading to a significant slowdown. This - function caches these properties for faster access.""" + each time they are called, leading to a significant slowdown. + This proxy caches these properties for faster access. 
+ """ + cached_tokenizer = copy.copy(tokenizer) - tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_ids = tokenizer.all_special_ids + tokenizer_all_special_tokens = tokenizer.all_special_tokens tokenizer_all_special_tokens_extended = ( tokenizer.all_special_tokens_extended) - tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) tokenizer_vocab = tokenizer.get_vocab() tokenizer_len = len(tokenizer) @@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: class CachedTokenizer(tokenizer.__class__): # type: ignore @property - def all_special_ids(self): + def all_special_ids(self) -> list[int]: return tokenizer_all_special_ids @property - def all_special_tokens(self): + def all_special_tokens(self) -> list[str]: return tokenizer_all_special_tokens @property - def all_special_tokens_extended(self): + def all_special_tokens_extended(self) -> list[str]: return tokenizer_all_special_tokens_extended @property - def max_token_id(self): + def max_token_id(self) -> int: return max_token_id - def get_vocab(self): + def get_vocab(self) -> dict[str, int]: return tokenizer_vocab - def __len__(self): + def __len__(self) -> int: return tokenizer_len + def __reduce__(self): + return get_cached_tokenizer, (tokenizer, ) + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" - tokenizer.__class__ = CachedTokenizer - return tokenizer + cached_tokenizer.__class__ = CachedTokenizer + return cached_tokenizer def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None: diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py index bb5ddaf88b219949d500b8f370b9ec534c36d8db..b4eb081c9b99d170ab3184038950ca983ff1143f 100644 --- a/vllm/transformers_utils/tokenizer_base.py +++ b/vllm/transformers_utils/tokenizer_base.py @@ -2,7 +2,7 @@ import importlib from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from 
typing import TYPE_CHECKING, Any, Optional, Union if TYPE_CHECKING: from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @@ -12,17 +12,17 @@ class TokenizerBase(ABC): @property @abstractmethod - def all_special_tokens_extended(self) -> List[str]: + def all_special_tokens_extended(self) -> list[str]: raise NotImplementedError() @property @abstractmethod - def all_special_tokens(self) -> List[str]: + def all_special_tokens(self) -> list[str]: raise NotImplementedError() @property @abstractmethod - def all_special_ids(self) -> List[int]: + def all_special_ids(self) -> list[int]: raise NotImplementedError() @property @@ -66,7 +66,7 @@ class TokenizerBase(ABC): @abstractmethod def __call__( self, - text: Union[str, List[str], List[int]], + text: Union[str, list[str], list[int]], text_pair: Optional[str] = None, add_special_tokens: bool = False, truncation: bool = False, @@ -75,11 +75,11 @@ class TokenizerBase(ABC): raise NotImplementedError() @abstractmethod - def get_vocab(self) -> Dict[str, int]: + def get_vocab(self) -> dict[str, int]: raise NotImplementedError() @abstractmethod - def get_added_vocab(self) -> Dict[str, int]: + def get_added_vocab(self) -> dict[str, int]: raise NotImplementedError() @abstractmethod @@ -88,44 +88,44 @@ class TokenizerBase(ABC): text: str, truncation: bool = False, max_length: Optional[int] = None, - ) -> List[int]: + ) -> list[int]: raise NotImplementedError() @abstractmethod def encode(self, text: str, - add_special_tokens: Optional[bool] = None) -> List[int]: + add_special_tokens: Optional[bool] = None) -> list[int]: raise NotImplementedError() @abstractmethod def apply_chat_template(self, - messages: List["ChatCompletionMessageParam"], - tools: Optional[List[Dict[str, Any]]] = None, - **kwargs) -> List[int]: + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs) -> list[int]: raise NotImplementedError() @abstractmethod - def convert_tokens_to_string(self, tokens: 
List[str]) -> str: + def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError() @abstractmethod def decode(self, - ids: Union[List[int], int], + ids: Union[list[int], int], skip_special_tokens: bool = True) -> str: raise NotImplementedError() @abstractmethod def convert_ids_to_tokens( self, - ids: List[int], + ids: list[int], skip_special_tokens: bool = True, - ) -> List[str]: + ) -> list[str]: raise NotImplementedError() class TokenizerRegistry: # Tokenizer name -> (tokenizer module, tokenizer class) - REGISTRY: Dict[str, Tuple[str, str]] = {} + REGISTRY: dict[str, tuple[str, str]] = {} @staticmethod def register(name: str, module: str, class_name: str) -> None: diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py similarity index 84% rename from vllm/transformers_utils/tokenizer_group/tokenizer_group.py rename to vllm/transformers_utils/tokenizer_group.py index b6e9005bcd241ef117e22c739ce52d70cd2fb5c1..a829985cb4592e2f736cea0f15901c65f88e610a 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -2,7 +2,7 @@ from typing import List, Optional -from vllm.config import TokenizerPoolConfig +from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens, get_lora_tokenizer, @@ -10,10 +10,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens, get_tokenizer) from vllm.utils import LRUCache -from .base_tokenizer_group import BaseTokenizerGroup - -class TokenizerGroup(BaseTokenizerGroup): +class TokenizerGroup: """A group of tokenizers that can be used for LoRA adapters.""" def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, @@ -27,15 +25,6 @@ class TokenizerGroup(BaseTokenizerGroup): self.lora_tokenizers = LRUCache[int, AnyTokenizer]( 
capacity=max(max_loras, max_num_seqs) if enable_lora else 0) - @classmethod - def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> "TokenizerGroup": - return cls(**init_kwargs) - - def ping(self) -> bool: - """Check if the tokenizer group is alive.""" - return True - def get_max_input_len(self, lora_request: Optional[LoRARequest] = None ) -> Optional[int]: @@ -104,3 +93,18 @@ class TokenizerGroup(BaseTokenizerGroup): return tokenizer else: return self.lora_tokenizers[lora_request.lora_int_id] + + +def init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig]): + return TokenizerGroup( + tokenizer_id=model_config.tokenizer, + enable_lora=bool(lora_config), + max_num_seqs=scheduler_config.max_num_seqs, + max_loras=lora_config.max_loras if lora_config else 0, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision, + truncation_side=model_config.truncation_side) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py deleted file mode 100644 index 9d2209575bd366ff573099fb2c5e7920319c90ed..0000000000000000000000000000000000000000 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional, Type - -from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, TokenizerPoolConfig) -from vllm.executor.ray_utils import ray - -from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup -from .tokenizer_group import TokenizerGroup - -if ray: - from .ray_tokenizer_group import RayTokenizerGroupPool -else: - RayTokenizerGroupPool = None # type: ignore - - -def init_tokenizer_from_configs(model_config: ModelConfig, - scheduler_config: SchedulerConfig, - parallel_config: 
ParallelConfig, - lora_config: Optional[LoRAConfig]): - init_kwargs = dict(tokenizer_id=model_config.tokenizer, - enable_lora=bool(lora_config), - max_num_seqs=scheduler_config.max_num_seqs, - max_loras=lora_config.max_loras if lora_config else 0, - max_input_length=None, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision, - truncation_side=model_config.truncation_side) - - return get_tokenizer_group(parallel_config.tokenizer_pool_config, - **init_kwargs) - - -def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> BaseTokenizerGroup: - tokenizer_cls: Type[BaseTokenizerGroup] - if tokenizer_pool_config is None: - tokenizer_cls = TokenizerGroup - elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass( - tokenizer_pool_config.pool_type, BaseTokenizerGroup): - tokenizer_cls = tokenizer_pool_config.pool_type - elif tokenizer_pool_config.pool_type == "ray": - if RayTokenizerGroupPool is None: - raise ImportError( - "RayTokenizerGroupPool is not available. 
Please install " - "the ray package to use the Ray tokenizer group pool.") - tokenizer_cls = RayTokenizerGroupPool - else: - raise ValueError( - f"Unknown pool type: {tokenizer_pool_config.pool_type}") - return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs) - - -__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py deleted file mode 100644 index c5108a7fc6ebc99ed1e47712d9ddb3abd52250cf..0000000000000000000000000000000000000000 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from abc import ABC, abstractmethod -from typing import List, Optional - -from vllm.config import TokenizerPoolConfig -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import AnyTokenizer - - -class BaseTokenizerGroup(ABC): - """A group of tokenizers that can be used for LoRA adapters.""" - - @classmethod - @abstractmethod - def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> "BaseTokenizerGroup": - pass - - @abstractmethod - def ping(self) -> bool: - """Check if the tokenizer group is alive.""" - pass - - @abstractmethod - def get_max_input_len( - self, - lora_request: Optional[LoRARequest] = None, - ) -> Optional[int]: - """Get the maximum input length for the LoRA request.""" - pass - - @abstractmethod - def encode(self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group.""" - pass - - @abstractmethod - async def encode_async( - self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group.""" - pass - - @abstractmethod - def 
get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - """Get a tokenizer for a LoRA request.""" - pass - - @abstractmethod - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - """Get a tokenizer for a LoRA request.""" - pass - - def check_health(self): - """Raise exception if the tokenizer group is unhealthy.""" - return diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py deleted file mode 100644 index b048b8094174a13cbbbf63f3b533bd2ddc5041a9..0000000000000000000000000000000000000000 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import asyncio -import os -from typing import List, Optional - -try: - from ray.exceptions import ActorDiedError # type: ignore -except ImportError: - # For older versions of Ray - from ray.exceptions import RayActorError as ActorDiedError # type: ignore -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy - -from vllm.config import TokenizerPoolConfig -from vllm.executor.ray_utils import ray -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import AnyTokenizer - -from .base_tokenizer_group import BaseTokenizerGroup -from .tokenizer_group import TokenizerGroup - -logger = init_logger(__name__) - - -class RayTokenizerGroupPool(BaseTokenizerGroup): - """A Ray-based pool of TokenizerGroups for async tokenization.""" - - # Class to use for workers making up the pool. 
- _worker_cls = TokenizerGroup - - @classmethod - def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> "RayTokenizerGroupPool": - if not tokenizer_pool_config: - raise ValueError("tokenizer_pool_config must not be None.") - ray_actor_options = (tokenizer_pool_config.extra_config or { - "num_cpus": 0 - }) - ray_actor_options.setdefault( - "scheduling_strategy", - NodeAffinitySchedulingStrategy( - node_id=ray.get_runtime_context().get_node_id(), soft=True)) - - # Carry over the env vars to the actors. - # This is necessary for API keys and such. - ray_actor_options.setdefault("runtime_env", {}) - _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"]) - - init_kwargs["num_actors"] = tokenizer_pool_config.pool_size - init_kwargs["ray_actor_options"] = ray_actor_options - - return cls(**init_kwargs) - - def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], num_actors: int, - ray_actor_options: dict, **tokenizer_config): - # Store a local copy of the TokenizerGroup for quick access - # to underlying HF tokenizers. - self._tokenizer_config = { - "tokenizer_id": tokenizer_id, - "enable_lora": enable_lora, - "max_num_seqs": max_num_seqs, - "max_input_length": max_input_length, - **tokenizer_config - } - self._local_tokenizer_group = self._worker_cls( - **self._tokenizer_config, ) - - self._ray_tokenizer_group_cls = ray.remote( - self._worker_cls).options(**ray_actor_options) # type: ignore - self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)] - self._idle_actors: Optional[asyncio.Queue] = None - - # If set, actor is unhealthy. Will reraise on the next - # check_health call. 
- self._exception: Optional[ActorDiedError] = None - - def _init_actor(self) -> ray.ObjectRef: - return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config) - - @property - def pool_size(self) -> int: - return len(self.tokenizer_actors) - - def ping(self): - return ray.get([ - actor.ping.remote() # type: ignore - for actor in self.tokenizer_actors - ]) - - def _ensure_queue_initialized(self): - if self._idle_actors is None: - self._idle_actors = asyncio.Queue() - for actor in self.tokenizer_actors: - self._idle_actors.put_nowait(actor) - - def _finalize_encode(self, actor: ray.ObjectRef, - original_actor: ray.ObjectRef, actor_is_alive: bool): - assert self._idle_actors is not None - # Cleanup the dead actor. - if not actor_is_alive or original_actor is not actor: - self.tokenizer_actors.remove(original_actor) - if actor_is_alive: - # Put the actor back in the queue. - # This is done in a finally block to ensure that the actor is - # always put back in the queue, even if an exception/cancellation - # is raised. - self._idle_actors.put_nowait(actor) - # Add back the new actor. - if original_actor is not actor: - self.tokenizer_actors.append(actor) - - def encode(self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group. - - We pick an idle actor and use it to encode the prompt. - The actor is then put back in the queue for future use. - This is blocking. - """ - self.check_health() - self._ensure_queue_initialized() - assert self._idle_actors is not None - - if self._idle_actors.empty(): - raise RuntimeError("No idle actors available.") - actor = self._idle_actors.get_nowait() - actor_is_alive = True - original_actor = actor - try: - ret = ray.get( - actor.encode.remote(prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens)) - except ActorDiedError as e: - # If the actor is dead, we first try to reinitialize it. 
- logger.warning("%s died with ActorDiedError, reinitializing.", - actor, - exc_info=e) - actor = self._init_actor() - try: - ret = ray.get( - actor.encode.remote(prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens)) - except ActorDiedError as e: - logger.error( - "%s died for second time in a row, marking " - "RayTokenizerGroupPool as unhealthy.", actor) - actor_is_alive = False - if not self._exception: - self._exception = e - self.check_health() - finally: - self._finalize_encode(actor, original_actor, actor_is_alive) - return ret - - async def encode_async( - self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group. - - We pick an idle actor and use it to encode the prompt. - If there are no idle actors, we wait until one becomes - available. - The actor is then put back in the queue for future use. - This is non-blocking. - """ - self.check_health() - self._ensure_queue_initialized() - assert self._idle_actors is not None - - actor = await self._idle_actors.get() - actor_is_alive = True - original_actor = actor - try: - ret = await actor.encode.remote( - prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens) - except ActorDiedError as e: - # If the actor is dead, we first try to reinitialize it. 
- logger.warning("%s died with ActorDiedError, reinitializing.", - actor, - exc_info=e) - actor = self._init_actor() - try: - ret = await actor.encode.remote( - prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens) - except ActorDiedError as e: - logger.error( - "%s died for second time in a row, marking " - "RayTokenizerGroupPool as unhealthy.", actor) - actor_is_alive = False - if not self._exception: - self._exception = e - self.check_health() - finally: - self._finalize_encode(actor, original_actor, actor_is_alive) - return ret - - def get_max_input_len(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - """Get the maximum input length for the LoRA request.""" - return self._local_tokenizer_group.get_max_input_len(lora_request) - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self._local_tokenizer_group.get_lora_tokenizer(lora_request) - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return await self._local_tokenizer_group.get_lora_tokenizer_async( - lora_request) - - def check_health(self): - if self._exception: - raise RuntimeError( - "TokenizerGroupPool is unhealthy.") from self._exception - - -def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: - """Copy over all current process environment variables to the runtime_env. - - The variables in runtime_env will take precedence over the current process - environment variables. 
- - runtime_env will be modified in place.""" - env_vars = os.environ.copy() - runtime_env.setdefault("env_vars", {}) - env_vars.update(runtime_env["env_vars"]) - runtime_env["env_vars"] = env_vars diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 58a114fa3a32feb3482132bed7e462b15f034a7e..296149a4569588036e21178491392cd1a1535d7c 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -257,7 +257,7 @@ class MistralTokenizer(TokenizerBase): # the following attributes are set to fit vLLM's design and are used # by the guided structured output backends. @property - def all_special_tokens_extended(self) -> List[str]: + def all_special_tokens_extended(self) -> list[str]: from mistral_common.tokens.tokenizers.base import SpecialTokens # tekken defines its own extended special tokens list @@ -271,11 +271,11 @@ class MistralTokenizer(TokenizerBase): ] @property - def all_special_tokens(self) -> List[str]: + def all_special_tokens(self) -> list[str]: return self.all_special_tokens_extended @property - def all_special_ids(self) -> List[int]: + def all_special_ids(self) -> list[int]: return [ self.all_special_tokens.index(t) for t in self.all_special_tokens ] @@ -335,12 +335,12 @@ class MistralTokenizer(TokenizerBase): input_ids = self.encode_one(text, truncation, max_length) return Encoding(input_ids=input_ids) - def get_vocab(self) -> Dict[str, int]: + def get_vocab(self) -> dict[str, int]: # NB: the dictionary form of the vocabulary collapses token ids that map # to the same string but have different bytes return self._vocab_dict - def get_added_vocab(self) -> Dict[str, int]: + def get_added_vocab(self) -> dict[str, int]: # Mistral tokenizers have no added vocabulary return {} diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 43918bcd7c5566b1ebd67d1a675385a7ab0ea35f..bffc56a2e75ca86dc648b622aa35a17bc7a8242e 100644 --- 
a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -2,4 +2,4 @@ from vllm.triton_utils.importing import HAS_TRITON -__all__ = ["HAS_TRITON"] \ No newline at end of file +__all__ = ["HAS_TRITON"] diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index a20700248c26872202e703b3283289b7e99785a9..fa29efbf6b2d439ff9753f9b58f09c5b1f09c330 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -1,17 +1,53 @@ # SPDX-License-Identifier: Apache-2.0 +import sys +import types from importlib.util import find_spec from vllm.logger import init_logger -from vllm.platforms import current_platform logger = init_logger(__name__) HAS_TRITON = ( find_spec("triton") is not None - and not current_platform.is_xpu() # Not compatible + or find_spec("pytorch-triton-xpu") is not None # Not compatible ) if not HAS_TRITON: logger.info("Triton not installed or not compatible; certain GPU-related" " functions will not be available.") + + class TritonPlaceholder(types.ModuleType): + + def __init__(self): + super().__init__("triton") + self.jit = self._dummy_decorator("jit") + self.autotune = self._dummy_decorator("autotune") + self.heuristics = self._dummy_decorator("heuristics") + self.language = TritonLanguagePlaceholder() + logger.warning_once( + "Triton is not installed. Using dummy decorators. 
" + "Install it via `pip install triton` to enable kernel" + "compilation.") + + def _dummy_decorator(self, name): + + def decorator(func=None, **kwargs): + if func is None: + return lambda f: f + return func + + return decorator + + class TritonLanguagePlaceholder(types.ModuleType): + + def __init__(self): + super().__init__("triton.language") + self.constexpr = None + self.dtype = None + + sys.modules['triton'] = TritonPlaceholder() + sys.modules['triton.language'] = TritonLanguagePlaceholder() + +if 'triton' in sys.modules: + logger.info("Triton module has been replaced with a placeholder.") diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 2ee3f9104d19710bb7e90f93c43d2dad28439f77..67b834533b7d60ad127625b7f4fcdd90b24eca8a 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -19,6 +19,7 @@ import torch import vllm.envs as envs from vllm.connections import global_http_connection +from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties from vllm.version import __version__ as VLLM_VERSION _config_home = envs.VLLM_CONFIG_ROOT @@ -168,12 +169,20 @@ class UsageMessage: # Platform information from vllm.platforms import current_platform if current_platform.is_cuda_alike(): - device_property = torch.cuda.get_device_properties(0) - self.gpu_count = torch.cuda.device_count() - self.gpu_type = device_property.name - self.gpu_memory_per_device = device_property.total_memory + self.gpu_count = cuda_device_count_stateless() + self.gpu_type, self.gpu_memory_per_device = ( + cuda_get_device_properties(0, ("name", "total_memory"))) if current_platform.is_cuda(): self.cuda_runtime = torch.version.cuda + if current_platform.is_tpu(): + try: + import torch_xla + self.gpu_count = torch_xla.runtime.world_size() + self.gpu_type = torch_xla.tpu.get_tpu_type() + self.gpu_memory_per_device = ( + torch_xla.core.xla_model.get_memory_info()["bytes_limit"]) + except Exception: + pass self.provider = _detect_cloud_provider() 
self.architecture = platform.machine() self.platform = platform.platform() diff --git a/vllm/utils.py b/vllm/utils.py index 95f99b3e53cf33a88c5b63ef9567b5cc3313ff3f..aa40d096bee7541f84a671004771e2c9cf2aacd2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -38,11 +38,13 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, Iterable, Iterator, KeysView, Mapping) +from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps from types import MappingProxyType from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, - Optional, Tuple, Type, TypeVar, Union, cast, overload) + Optional, Sequence, Tuple, Type, TypeVar, Union, cast, + overload) from uuid import uuid4 import cachetools @@ -61,6 +63,9 @@ from torch.library import Library from typing_extensions import Never, ParamSpec, TypeIs, assert_never import vllm.envs as envs +# NOTE: import triton_utils to make TritonPlaceholderModule work +# if triton is unavailable +import vllm.triton_utils # noqa: F401 from vllm.logger import enable_trace_function_call, init_logger import json @@ -237,6 +242,12 @@ class CacheInfo(NamedTuple): return self.hits / self.total + def __sub__(self, other: CacheInfo): + return CacheInfo( + hits=self.hits - other.hits, + total=self.total - other.total, + ) + class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): @@ -244,15 +255,26 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): capacity: float, getsizeof: Optional[Callable[[_V], float]] = None): super().__init__(capacity, getsizeof) + self.pinned_items = set[_K]() - self.capacity = capacity self._hits = 0 self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + def __getitem__(self, key: _K, *, update_info: bool = True) -> _V: + value = super().__getitem__(key) + + 
if update_info: + self._hits += 1 + self._total += 1 + + return value def __delitem__(self, key: _K) -> None: run_on_remove = key in self - value = self.__getitem__(key) + value = self.__getitem__(key, + update_info=False) # type: ignore[call-arg] super().__delitem__(key) if key in self.pinned_items: # Todo: add warning to inform that del pinned item @@ -272,11 +294,38 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): """Return the internal order dictionary (read-only).""" return MappingProxyType(self._LRUCache__order) # type: ignore - def stat(self) -> CacheInfo: - return CacheInfo(hits=self._hits, total=self._total) + @property + def capacity(self) -> float: + return self.maxsize + + @property + def usage(self) -> float: + if self.maxsize == 0: + return 0 + + return self.currsize / self.maxsize + + def stat(self, *, delta: bool = False) -> CacheInfo: + """ + Gets the cumulative number of hits and queries against this cache. + + If :code:`delta=True`, instead gets these statistics + since the last call that also passed :code:`delta=True`. 
+ """ + info = CacheInfo(hits=self._hits, total=self._total) + + if delta: + info_delta = info - self._last_info + self._last_info = info + info = info_delta + + return info def touch(self, key: _K) -> None: - self._LRUCache__update(key) # type: ignore + try: + self._LRUCache__order.move_to_end(key) # type: ignore + except KeyError: + self._LRUCache__order[key] = None # type: ignore @overload def get(self, key: _K, /) -> Optional[_V]: @@ -293,7 +342,8 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): _T]] = None) -> Optional[Union[_V, _T]]: value: Optional[Union[_V, _T]] if key in self: - value = self.__getitem__(key) + value = self.__getitem__( + key, update_info=False) # type: ignore[call-arg] self._hits += 1 else: @@ -318,8 +368,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): if key not in self: return default - value = self[key] - del self[key] + value = self.__getitem__(key, + update_info=False) # type: ignore[call-arg] + self.__delitem__(key) return value def put(self, key: _K, value: _V) -> None: @@ -354,10 +405,6 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): while self.currsize > self.capacity: self.remove_oldest() - def clear(self) -> None: - while len(self) > 0: - self.remove_oldest(remove_pinned=True) - def popitem(self, remove_pinned: bool = False): """Remove and return the `(key, value)` pair least recently used.""" if not remove_pinned: @@ -373,6 +420,14 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): value = self.pop(cast(_K, lru_key)) return (lru_key, value) + def clear(self) -> None: + while len(self) > 0: + self.remove_oldest(remove_pinned=True) + + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + class PyObjectCache: """Used to cache python objects to avoid object allocations @@ -579,12 +634,12 @@ def get_open_port() -> int: process. Currently it uses 2 ports. 
""" if "VLLM_DP_MASTER_PORT" in os.environ: - dp_port = envs.VLLM_DP_MASTER_PORT + dp_master_port = envs.VLLM_DP_MASTER_PORT + reserved_port_range = range(dp_master_port, dp_master_port + 10) while True: - port = _get_open_port() - if dp_port <= port < dp_port + 10: - continue - return port + candidate_port = _get_open_port() + if candidate_port not in reserved_port_range: + return candidate_port return _get_open_port() @@ -711,21 +766,28 @@ def create_kv_caches_with_random_flash( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: Optional[int] = None, device: Optional[str] = "cuda", + cache_layout: Optional[str] = "NHD", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: from vllm.platforms import current_platform current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) - key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + assert cache_layout in ("NHD", "HND") + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, + 4) + + kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] + for i in stride_order) scale = head_size**-0.5 key_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = [] for _ in range(num_layers): - key_value_cache = torch.empty(size=key_value_cache_shape, + key_value_cache = torch.empty(size=kv_cache_allocation_shape, dtype=torch_dtype, - device=device) + device=device).permute(*stride_order) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_value_cache.uniform_(-scale, scale) elif cache_dtype == 'fp8': @@ -1189,6 +1251,22 @@ def cuda_is_initialized() -> bool: return torch.cuda.is_initialized() +def cuda_get_device_properties(device, + names: Sequence[str], + init_cuda=False) -> tuple[Any, ...]: + """Get specified CUDA device property values without initializing CUDA in + the current process.""" + if init_cuda or cuda_is_initialized(): + props 
= torch.cuda.get_device_properties(device) + return tuple(getattr(props, name) for name in names) + + # Run in subprocess to avoid initializing CUDA as a side effect. + mp_ctx = multiprocessing.get_context("fork") + with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor: + return executor.submit(cuda_get_device_properties, device, names, + True).result() + + def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: """Make an instance method that weakly references its associated instance and no-ops once that diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b4c7708daab919af36ec0848c1e3014be6374321..41bb9aba2995398d90d9cbfecbbc367176e25c02 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -11,11 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -23,7 +23,8 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner if current_platform.is_cuda(): - from vllm.vllm_flash_attn import flash_attn_varlen_func + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + get_scheduler_metadata) logger = init_logger(__name__) @@ -63,10 +64,6 @@ class FlashAttentionBackend(AttentionBackend): raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return 
use_cascade_attention(*args, **kwargs) - @dataclass class FlashAttentionMetadata: @@ -93,6 +90,10 @@ class FlashAttentionMetadata: prefix_kv_lens: Optional[torch.Tensor] suffix_kv_lens: Optional[torch.Tensor] + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None + # For logging. num_input_tokens: int = 0 # Number of tokens including padding. @@ -104,6 +105,7 @@ class FlashAttentionMetadata: local_block_table: torch.Tensor local_max_query_len: int local_max_seq_len: int + local_scheduler_metadata: Optional[torch.Tensor] local_attn_metadata: Optional[LocalAttentionMetadata] = None @@ -277,7 +279,16 @@ def make_local_attention_virtual_batches( class FlashAttentionMetadataBuilder: def __init__(self, runner: "GPUModelRunner"): + model_config = runner.model_config + self.runner = runner + self.aot_schedule = (get_flash_attn_version() == 3) + self.num_heads_q = model_config.get_num_attention_heads( + runner.parallel_config) + self.num_heads_kv = model_config.get_num_kv_heads( + runner.parallel_config) + self.headdim = model_config.get_head_size() + self.page_size = self.runner.block_size def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -296,6 +307,23 @@ class FlashAttentionMetadataBuilder: slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() + def schedule(batch_size, cu_query_lens, max_query_len, seqlens, + max_seq_len, causal): + if self.aot_schedule: + return get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + cache_seqlens=seqlens, + num_heads_q=self.num_heads_q, + num_heads_kv=self.num_heads_kv, + headdim=self.headdim, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + ) + return None + # for local attention local_attn_metadata = None if self.runner.attention_chunk_size is not None: @@ 
-307,18 +335,31 @@ class FlashAttentionMetadataBuilder: block_table, self.runner.block_size, ) + local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( + self.runner.device, non_blocking=True) + local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( + self.runner.device, non_blocking=True) + local_max_query_len = seqlens_q_local_np.max() + local_max_seq_len = virt_k_seqlens_np.max() + local_scheduler_metadata = schedule( + batch_size=local_query_start_loc.shape[0] - 1, + cu_query_lens=local_query_start_loc, + max_query_len=local_max_query_len, + seqlens=local_seqused_k, + max_seq_len=local_max_seq_len, + causal=True) + local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( - local_query_start_loc=torch.from_numpy( - virt_q_cu_seqlens_np).to(self.runner.device, - non_blocking=True), - local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to( - self.runner.device, non_blocking=True), + local_query_start_loc=local_query_start_loc, + local_seqused_k=local_seqused_k, local_block_table=virt_block_table, - local_max_query_len=seqlens_q_local_np.max(), - local_max_seq_len=virt_k_seqlens_np.max(), + local_max_query_len=local_max_query_len, + local_max_seq_len=local_max_seq_len, + local_scheduler_metadata=local_scheduler_metadata, ) use_cascade = common_prefix_len > 0 + if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32, @@ -330,10 +371,31 @@ class FlashAttentionMetadataBuilder: common_prefix_len) suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( self.runner.device) + prefix_scheduler_metadata = schedule( + batch_size=1, + cu_query_lens=cu_prefix_query_lens, + max_query_len=num_actual_tokens, + seqlens=prefix_kv_lens, + max_seq_len=common_prefix_len, + causal=False) + scheduler_metadata = schedule(batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - + common_prefix_len, + causal=True) else: cu_prefix_query_lens = None 
prefix_kv_lens = None suffix_kv_lens = None + prefix_scheduler_metadata = None + scheduler_metadata = schedule(batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=True) attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -345,13 +407,18 @@ class FlashAttentionMetadataBuilder: slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, + scheduler_metadata=scheduler_metadata, cu_prefix_query_lens=cu_prefix_query_lens, prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, ) return attn_metadata + def use_cascade_attention(self, *args, **kwargs) -> bool: + return use_cascade_attention(*args, **kwargs) + class FlashAttentionImpl(AttentionImpl): @@ -491,12 +558,14 @@ class FlashAttentionImpl(AttentionImpl): max_seqlen_q = local_metadata.local_max_query_len max_seqlen_k = local_metadata.local_max_seq_len block_table = local_metadata.local_block_table + scheduler_metadata = local_metadata.local_scheduler_metadata else: cu_seqlens_q = attn_metadata.query_start_loc seqused_k = attn_metadata.seq_lens max_seqlen_q = attn_metadata.max_query_len max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table + scheduler_metadata = attn_metadata.scheduler_metadata descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) @@ -515,6 +584,7 @@ class FlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, block_table=block_table, softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, fa_version=self.vllm_flash_attn_version, q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), @@ -543,6 +613,8 @@ class FlashAttentionImpl(AttentionImpl): block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, 
fa_version=self.vllm_flash_attn_version, + prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, + suffix_scheduler_metadata=attn_metadata.scheduler_metadata, q_descale=layer._q_scale, k_descale=layer._k_scale, v_descale=layer._v_scale, @@ -636,6 +708,8 @@ def cascade_attention( block_table: torch.Tensor, common_prefix_len: int, fa_version: int, + prefix_scheduler_metadata: Optional[torch.Tensor] = None, + suffix_scheduler_metadata: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, @@ -667,6 +741,7 @@ def cascade_attention( block_table=block_table[:1], softcap=logits_soft_cap, return_softmax_lse=True, + scheduler_metadata=prefix_scheduler_metadata, fa_version=fa_version, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, @@ -693,6 +768,7 @@ def cascade_attention( block_table=block_table[:, num_common_kv_blocks:], softcap=logits_soft_cap, return_softmax_lse=True, + scheduler_metadata=suffix_scheduler_metadata, fa_version=fa_version, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py new file mode 100755 index 0000000000000000000000000000000000000000..bce446bd2b827aa027a95c57ace5c8a619df1451 --- /dev/null +++ b/vllm/v1/attention/backends/flashinfer.py @@ -0,0 +1,638 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with FlashInfer.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch +from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper) + +import vllm.envs as envs +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionType) +from vllm.attention.layer import Attention +from vllm.config 
import (VllmConfig, get_current_vllm_config, + get_layers_from_vllm_config) +from vllm.logger import init_logger +from vllm.v1.attention.backends.flash_attn import use_cascade_attention + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 + +logger = init_logger(__name__) + + +class FlashInferBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 128, 256] + + @staticmethod + def get_name() -> str: + return "FLASHINFER_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type[FlashInferImpl]: + return FlashInferImpl + + @staticmethod + def get_metadata_cls() -> type[FlashInferMetadata]: + return FlashInferMetadata + + @staticmethod + def get_builder_cls() -> type[FlashInferMetadataBuilder]: + return FlashInferMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig) -> dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. 
+ """ + + layers = get_layers_from_vllm_config(vllm_config, Attention) + per_layer_params: dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + impl = layer.impl + assert isinstance(impl, FlashInferImpl) + + # Infer hyperparameters from the attention layer + window_size = impl.sliding_window + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = impl.logits_soft_cap + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." + + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + +@dataclass +class FlashInferMetadata: + + num_actual_tokens: int # Number of tokens excluding padding. + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. 
+ qo_indptr: torch.Tensor + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: torch.Tensor + # The page indices of the paged kv cache + paged_kv_indices: torch.Tensor + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: torch.Tensor + # The number of query/output heads + num_qo_heads: int + # The number of key/value heads + num_kv_heads: int + # The dimension of the attention heads + head_dim: int + # Block size of vllm + page_size: int + # The data type of the paged kv cache + data_type: torch.dtype + # The data type of the query + q_data_type: torch.dtype + + slot_mapping: torch.Tensor + + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + + # For cascade attention. + use_cascade: bool + shared_qo_indptr: Optional[torch.Tensor] = None + shared_kv_page_indptr: Optional[torch.Tensor] = None + shared_kv_page_indices: Optional[torch.Tensor] = None + shared_kv_last_page_len: Optional[torch.Tensor] = None + + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + @property + def query_start_loc(self): + # The GPUModelRunner expects to be able to access this property. 
+ return self.qo_indptr + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f" received {self.head_dim}.") + + +class FlashInferMetadataBuilder: + + def __init__(self, runner: GPUModelRunner): + self.runner = runner + self._workspace_buffer = None + self._prefill_wrapper = None # Wrapper for prefill/append + self._decode_wrapper = None # Wrapper for decode + self._cascade_wrapper = None # Wrapper for cascade attention + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + + def reorder_batch(self, input_batch: InputBatch, + scheduler_output: SchedulerOutput) -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the back, using the least + # amount of swaps possible. 
(NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if it's not, + # we should update this to something like < 8 in the future but + # currently the decode run only supports num_tokens = 1 + if num_tokens == 1: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new requests are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reordered batch with + # prefills from the front of the batch. 
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: + break + + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + self._num_decodes = num_decodes + self._num_prefills = num_prefills + self._num_decode_tokens = num_decode_tokens + self._num_prefill_tokens = num_prefill_tokens + + return modified_batch + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + return self._prefill_wrapper + + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + "NHD", + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + + def _get_cascade_wrapper(self): + if self._cascade_wrapper is None: + self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( + 2, self._get_workspace_buffer(), "NHD") + return 
self._cascade_wrapper + + def _plan(self, attn_metadata: FlashInferMetadata): + if self.global_hyperparameters is None: + self.global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + if attn_metadata.use_cascade: + attn_metadata.cascade_wrapper = self._get_cascade_wrapper() + attn_metadata.cascade_wrapper.plan( + [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr], + [ + attn_metadata.shared_kv_page_indptr, + attn_metadata.paged_kv_indptr + ], + [ + attn_metadata.shared_kv_page_indices, + attn_metadata.paged_kv_indices + ], + [ + attn_metadata.shared_kv_last_page_len, + attn_metadata.paged_kv_last_page_len + ], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters.logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + ) + else: + # Regular attention (common case). + # Decodes are at the front and prefills are at the back, + # according to reorder_batch() + if self._num_prefills > 0: + # Decodes are first so prefills start after the last decode + prefill_start = self._num_decodes + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + assert attn_metadata.qo_indptr[prefill_start:].shape[ + 0] == self._num_prefills + 1 + assert attn_metadata.paged_kv_indptr[prefill_start:].shape[ + 0] == self._num_prefills + 1 + assert attn_metadata.paged_kv_last_page_len[ + prefill_start:].shape[0] == self._num_prefills + # Since prefill_wrapper.run() will be called with + # query[num_decode_tokens:] we need to adjust the qo_indptr + # to be relative to the start of the prefill queries. 
+ qo_indptr = attn_metadata.qo_indptr[ + prefill_start:] - attn_metadata.qo_indptr[prefill_start] + attn_metadata.prefill_wrapper.plan( + qo_indptr, + attn_metadata.paged_kv_indptr[prefill_start:], + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len[prefill_start:], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. + logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.data_type, + ) + + if self._num_decodes > 0: + attn_metadata.decode_wrapper = self._get_decode_wrapper() + attn_metadata.decode_wrapper.plan( + attn_metadata.paged_kv_indptr[:self._num_decodes + 1], + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len[:self._num_decodes], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. 
+ logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.data_type, + ) + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int): + assert self._num_decodes + self._num_prefills == num_reqs + assert (self._num_decode_tokens + + self._num_prefill_tokens == num_actual_tokens) + page_size = self.runner.block_size + device = self.runner.device + qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to( + self.runner.device, non_blocking=True) + seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, + non_blocking=True) + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + self.runner.device, non_blocking=True).long() + + block_table_bounds = (seq_lens + page_size - 1) // page_size + + use_cascade = common_prefix_len > 0 + if use_cascade: + # Grab the blocks of the shared prefix from the first request. + assert common_prefix_len % page_size == 0 + num_common_kv_blocks = common_prefix_len // page_size + shared_qo_indptr = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device=device) + shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], + dtype=torch.int32, + device=device) + shared_kv_page_indices = block_table[0, :num_common_kv_blocks] + shared_kv_last_page_len = torch.tensor([page_size], + dtype=torch.int32, + device=device) + # Remove the blocks of the shared prefix from all requests. 
+ block_table = block_table[:, num_common_kv_blocks:] + block_table_bounds -= num_common_kv_blocks + else: + shared_qo_indptr = None + shared_kv_page_indptr = None + shared_kv_page_indices = None + shared_kv_last_page_len = None + + mask = (torch.arange(block_table.size(1), + dtype=block_table.dtype, + device=block_table.device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = block_table[mask] + + paged_kv_indptr = torch.cat([ + torch.zeros(1, + dtype=block_table_bounds.dtype, + device=block_table_bounds.device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32) + ]) + + paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, + page_size, paged_kv_last_page_len) + + attn_metadata = FlashInferMetadata( + num_actual_tokens=num_actual_tokens, + qo_indptr=qo_indptr, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=self.runner.num_query_heads, + num_kv_heads=self.runner.num_kv_heads, + head_dim=self.runner.head_size, + page_size=page_size, + data_type=self.runner.kv_cache_dtype, + q_data_type=self.runner.dtype, + slot_mapping=slot_mapping, + num_decodes=self._num_decodes, + num_decode_tokens=self._num_decode_tokens, + num_prefills=self._num_prefills, + num_prefill_tokens=self._num_prefill_tokens, + use_cascade=use_cascade, + shared_qo_indptr=shared_qo_indptr, + shared_kv_page_indptr=shared_kv_page_indptr, + shared_kv_page_indices=shared_kv_page_indices, + shared_kv_last_page_len=shared_kv_last_page_len, + ) + + self._plan(attn_metadata) + + return attn_metadata + + def use_cascade_attention(self, *args, **kwargs) -> bool: + if self.runner.kv_cache_dtype != self.runner.model_config.dtype: + # TODO: The cascade wrapper currently does not support setting + # kv cache dtype to something different from query dtype. 
+ return False + return use_cascade_attention(*args, **kwargs) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashInfer. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output + + # IMPORTANT! 
+ # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + # Reshape the input keys and values and store them in the cache. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] and + # value[:num_actual_tokens] because the reshape_and_cache_flash op uses + # the slot_mapping's shape to determine the number of actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + window_left = (self.sliding_window[0] + if self.sliding_window is not None else -1) + + # Inputs and outputs may be padded for CUDA graphs + query = query[:num_actual_tokens] + output_padded = output + output = output[:num_actual_tokens] + + if attn_metadata.use_cascade: + # Cascade attention (rare case). + assert attn_metadata.cascade_wrapper is not None + output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) + return output + + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + + # Regular attention (common case). 
+ # Decodes are at the front and prefills are at the back, + # according to reorder_batch() + if prefill_wrapper := attn_metadata.prefill_wrapper: + prefill_query = query[num_decode_tokens:] + assert prefill_query.shape[0] == num_prefill_tokens + assert prefill_wrapper is not None + assert prefill_wrapper._causal + assert prefill_wrapper._window_left == window_left + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap + or 0.0) + assert prefill_wrapper._sm_scale == self.scale + prefill_wrapper.run( + prefill_query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[num_decode_tokens:], + ) + + if decode_wrapper := attn_metadata.decode_wrapper: + decode_query = query[:num_decode_tokens] + assert decode_query.shape[0] == num_decode_tokens + assert decode_wrapper is not None + assert decode_wrapper._window_left == window_left + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap + or 0.0) + assert decode_wrapper._sm_scale == self.scale + decode_wrapper.run( + decode_query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) + + return output_padded diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 8c7179ba0a8af406a0561e9a570077f17233d7d9..e6e483bae2bc8eba72e9de539499ca37dafc6136 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -195,7 +195,9 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) +from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, @@ -203,13 +205,14 @@ from 
vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down -from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version try: from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True except ImportError: # For rocm use upstream flash attention from flash_attn import flash_attn_varlen_func + is_vllm_fa = False if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -248,10 +251,6 @@ class MLACommonBackend(AttentionBackend): def get_supported_head_sizes() -> list[int]: return [576] - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return False - @dataclass class MLACommonPrefillMetadata: @@ -350,6 +349,14 @@ class MLACommonMetadataBuilder(Generic[M]): model_config = runner.model_config cache_config = runner.cache_config self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + self.num_heads = model_config.get_num_attention_heads( + runner.parallel_config) + self.mla_dims = get_mla_dims(model_config) + self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + + # Dont try to access the runner on AMD + if self.aot_schedule: + self.page_size = self.runner.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -375,7 +382,6 @@ class MLACommonMetadataBuilder(Generic[M]): dtype=model_config.dtype, device=runner.device, ) - self.page_size = self.runner.block_size def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -415,20 +421,18 @@ class MLACommonMetadataBuilder(Generic[M]): # the above loop num_decodes = len(decodes) num_prefills = len(prefills) - first_prefill = 0 modified_batch = False for i in range(1, min(num_decodes, num_prefills) + 1): # If the decode is at the "back" of the batch, i, we can swap it # with the prefill closest to the 
front of the batch - if decodes[num_decodes - i] >= num_decodes: - input_batch.swap_states(prefills[first_prefill], - decodes[num_decodes - i]) - first_prefill += 1 - modified_batch = True - else: + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: break + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + # Save for next `build` call # TODO(lucas): this is a bit of a hack, we should probably have a # better way of doing this @@ -466,7 +470,6 @@ class MLACommonMetadataBuilder(Generic[M]): seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] seq_lens = seq_lens_cpu.to(device, non_blocking=True) - max_query_len = seq_lens_cpu.max().item() prefill_metadata = None if self._num_prefills > 0: @@ -477,6 +480,8 @@ class MLACommonMetadataBuilder(Generic[M]): num_computed_tokens_cpu_tensor[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None if self.chunked_prefill_enabled and self._num_prefills > 0 \ @@ -539,8 +544,7 @@ class MLACommonMetadataBuilder(Generic[M]): prefill_metadata = MLACommonPrefillMetadata( input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], - query_start_loc=query_start_loc[reqs_start:] - - query_start_loc[reqs_start], + query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, ) @@ -566,6 +570,9 @@ class MLACommonMetadataBuilder(Generic[M]): decode=decode_metadata, ) + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ @@ -630,11 +637,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # and the one from vllm_flash_attn. 
The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() if self.vllm_flash_attn_version is not None: self.flash_attn_varlen_func = \ functools.partial(flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version) + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results + lse = None + if isinstance(attn_out, tuple): + attn_out, lse = attn_out[0], attn_out[1] + + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. 
+ if return_softmax_lse: + return attn_out, lse + return attn_out + def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -747,16 +799,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], max_seqlen_q=prefill_metadata.max_query_len, @@ -803,15 +850,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - output = self.flash_attn_varlen_func( + output = self._flash_attn_varlen_diff_headdims( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=attn_metadata.prefill.query_start_loc, cu_seqlens_k=attn_metadata.prefill.query_start_loc, max_seqlen_q=attn_metadata.prefill.max_query_len, @@ -835,12 +877,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): suffix_lse=suffix_lse, ) - # slice by `:v.shape[-1]` in order to remove v headdim padding - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(output)[0] + return self.o_proj(output.flatten(start_dim=-2))[0] @abstractmethod def _forward_decode( diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 
3e8149a24ebf734f22be4c600c2645ca14039962..05b97172bc6c0928ae5ca989c1cb957062090ab6 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -10,7 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.utils import cdiv logger = init_logger(__name__) @@ -50,6 +52,19 @@ class PallasAttentionBackend(AttentionBackend): ) -> None: raise RuntimeError("swap_blocks is not used for the TPU backend.") + # In recent TPU generations, up to v6e, the SMEM size is 1MB. The + # block_tables within the PallasMetadata constitute almost the entire SMEM + # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here + # we simply make sure that the size is smaller than half of SMEM capacity. + @staticmethod + def get_min_page_size(vllm_config: VllmConfig) -> int: + max_num_page_per_req = (1024 * 1024 // 2 // + vllm_config.scheduler_config.max_num_seqs // 4) + min_page_size = cdiv(vllm_config.model_config.max_model_len, + max_num_page_per_req) + min_page_size = 1 << (min_page_size - 1).bit_length() + return min_page_size + @dataclass class PallasMetadata: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 33761cf7f9c01ed3923fa9493a5c7016a78c1b78..0830d8433d89ea186def033618b3a3db14f6f7f3 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -25,7 +25,7 @@ class KVCacheManager: max_model_len: int, enable_caching: bool = True, caching_hash_algo: str = "builtin", - num_preallocate_tokens: int = 64, + use_eagle: bool = False, log_stats: bool = False, ) -> None: assert len(kv_cache_config.kv_cache_groups) == 1, ( @@ -39,24 +39,12 @@ class KVCacheManager: self.enable_caching = enable_caching 
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash - # FIXME: make prefix cache stats conditional on log_stats + self.use_eagle = use_eagle self.log_stats = log_stats - # NOTE(woosuk): To avoid frequent block allocation, we preallocate some - # blocks for each request. For example, when a request reaches the end - # of its block table, we preallocate N blocks in advance. This way, we - # reduce the overhead of updating free_block_ids and ref_cnts for each - # request every step (at the cost of some memory waste). - # NOTE(woosuk): This is different from the "lookahead" slots since this - # does not guarantee that the request always has N empty blocks. After - # the request gets N empty blocks, it starts to use the blocks without - # further allocation. When it uses up all the N empty blocks, it gets - # N new empty blocks. - self.num_preallocate_tokens = num_preallocate_tokens - self.num_preallocate_blocks = cdiv(num_preallocate_tokens, - self.block_size) + # FIXME: make prefix cache stats conditional on log_stats + self.prefix_cache_stats = PrefixCacheStats() if log_stats else None self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) - self.specialized_manager = get_specialized_manager( kv_cache_spec=kv_cache_spec, block_pool=self.block_pool, @@ -79,7 +67,6 @@ class KVCacheManager: # This is only used to track the RUNNING requests, we do not track the # data for reempted ones. self.num_cached_block: dict[str, int] = {} - self.prefix_cache_stats = PrefixCacheStats() @property def usage(self) -> float: @@ -90,12 +77,14 @@ class KVCacheManager: """ return self.block_pool.get_usage() - def make_prefix_cache_stats(self) -> PrefixCacheStats: + def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: """Get (and reset) the prefix cache stats. Returns: - The current prefix caching stats. + The current prefix caching stats, or None if logging is disabled. 
""" + if not self.log_stats: + return None stats = self.prefix_cache_stats self.prefix_cache_stats = PrefixCacheStats() return stats @@ -125,7 +114,9 @@ class KVCacheManager: self.block_size, request) self.req_to_block_hashes[request.request_id] = block_hashes - self.prefix_cache_stats.requests += 1 + if self.log_stats: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.requests += 1 # When the request requires prompt logprobs, we skip prefix caching. if request.sampling_params.prompt_logprobs is not None: return [], 0 @@ -145,8 +136,18 @@ class KVCacheManager: computed_blocks = ( self.specialized_manager.find_longest_cache_hit(block_hashes)) - self.prefix_cache_stats.queries += len(block_hashes) - self.prefix_cache_stats.hits += len(computed_blocks) + + if self.use_eagle and len(computed_blocks) > 0: + # Drop the last matched block if (1) eagle is enabled and + # (2) there is a cache hit. + # This is to recompute the last block to get the required + # hidden states for eagle drafting head. + computed_blocks.pop() + + if self.log_stats: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.queries += len(block_hashes) + self.prefix_cache_stats.hits += len(computed_blocks) if last_block_hash is not None: # Add back the last block hash if it was removed. @@ -171,8 +172,9 @@ class KVCacheManager: Args: request: The request to allocate slots. - num_tokens: The number of tokens to allocate. Note that this does - not include the tokens that have already been computed. + num_tokens: The number of tokens to allocate, including external + tokens. Note that this does not include tokens that have + already been computed locally (i.e. new_computed_blocks). new_computed_blocks: A list of new computed blocks just hitting the prefix caching. num_lookahead_tokens: The number of speculative tokens to allocate. @@ -249,13 +251,9 @@ class KVCacheManager: # No new block is needed. 
new_blocks = [] else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_preallocate_blocks = max( - 0, self.num_preallocate_blocks - - num_lookahead_tokens // self.block_size) + # Get new blocks from the free block pool. num_new_blocks = min( - num_new_blocks + num_preallocate_blocks, + num_new_blocks, self.block_pool.get_num_free_blocks(), # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape @@ -316,17 +314,19 @@ class KVCacheManager: def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF - flows to invalid prefix caching after the weights are updated, + flows to invalidate prefix caching after the weights are updated, or used for resetting prefix caching status for benchmarking. Returns: bool: True if the prefix cache is successfully reset, False otherwise. """ - if self.block_pool.reset_prefix_cache(): + if not self.block_pool.reset_prefix_cache(): + return False + if self.log_stats: + assert self.prefix_cache_stats is not None self.prefix_cache_stats.reset = True - return True - return False + return True def get_num_common_prefix_blocks( self, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index bd0e01d045d17fb8f7b995e358e96a1c6121e065..3026ecc1c968292157569eeb644116ac23268789 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -43,19 +43,19 @@ class BlockHashType(NamedTuple): # This aligns with the behavior of Python's hash() function, which also uses # a random seed if PYTHONHASHSEED is not set. NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( - 'PYTHONHASHSEED') is not None else sha256(os.getenv('PYTHONHASHSEED')) + 'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED')) class PrefixCachingMetrics: - """Metrics for prefix caching with a hit rate of the most recent N requests. 
+ """Metrics for prefix caching with a hit rate of the max recent N requests. Args: - interval: The number of the most recent requests to aggregate. + max_recent_requests: The number of the max recent requests to aggregate. Defaults to 1000. """ - def __init__(self, interval: int = 1000): - self.interval = interval + def __init__(self, max_recent_requests: int = 1000): + self.max_recent_requests = max_recent_requests # The current aggregated values. self.aggregated_requests = 0 self.aggregated_query_total = 0 @@ -70,7 +70,7 @@ class PrefixCachingMetrics: are being scheduled and are looking for computed blocks. When there are more than `interval` requests, the oldest set of - requestsare removed from the metrics. + requests are removed from the metrics. Args: stats: The prefix cache stats. @@ -87,7 +87,7 @@ class PrefixCachingMetrics: self.aggregated_query_hit += stats.hits # Remove the oldest stats if the number of requests exceeds. - if self.aggregated_requests > self.interval: + if self.aggregated_requests > self.max_recent_requests: old_requests, old_queries, old_hits = self.query_queue.popleft() self.aggregated_requests -= old_requests self.aggregated_query_total -= old_queries diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index bfed44f9d58c82601d19436e2307a39c3e5ef068..1de236d42f02540f5241518cbd99c4b15f77b7b5 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -117,11 +117,6 @@ class SchedulerInterface(ABC): not yet returned in SchedulerOutputs.""" return self.has_unfinished_requests() or self.has_finished_requests() - @abstractmethod - def get_num_unscheduled_requests(self) -> int: - """Number of requests that are not being processed by the executor.""" - raise NotImplementedError - @abstractmethod def reset_prefix_cache(self) -> bool: """Reset the prefix cache for KV cache. 
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index dc0d2d59fea7f3cd6cbc2c3d90d4e708541a2f69..928fb231a1f2d888b66dc850c9fb51af564ae010 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: import numpy as np import numpy.typing as npt + from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata) from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -20,7 +22,6 @@ class NewRequestData: req_id: str prompt_token_ids: list[int] - prompt: Optional[str] mm_inputs: list[MultiModalKwargs] mm_hashes: list[str] mm_positions: list[PlaceholderRange] @@ -38,7 +39,6 @@ class NewRequestData: return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, - prompt=request.prompt, mm_inputs=request.mm_inputs, mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, @@ -121,3 +121,6 @@ class SchedulerOutput: structured_output_request_ids: dict[str, int] # the bitmask for the whole batch grammar_bitmask: Optional[npt.NDArray[np.int32]] + + # KV Cache Connector metadata. 
+ kv_connector_metadata: Optional[KVConnectorMetadata] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a81574875a5c1b31e37906353c53f5f761714f1b..21711c9292f9fd7cb5d03941e33bcff02335004e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -3,12 +3,14 @@ from __future__ import annotations import time -from collections import deque +from collections import defaultdict, deque from collections.abc import Iterable from typing import Optional, Union -from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, @@ -34,20 +36,17 @@ class Scheduler(SchedulerInterface): def __init__( self, - scheduler_config: SchedulerConfig, - model_config: ModelConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, structured_output_manager: StructuredOutputManager, - speculative_config: SpeculativeConfig = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, log_stats: bool = False, ) -> None: - self.scheduler_config = scheduler_config - self.cache_config = cache_config - self.lora_config = lora_config + self.vllm_config = vllm_config + self.scheduler_config = vllm_config.scheduler_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config self.kv_cache_config = kv_cache_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager @@ -64,13 +63,17 @@ class Scheduler(SchedulerInterface): 
self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len - # Create the KV cache manager. - self.kv_cache_manager = KVCacheManager( - kv_cache_config=kv_cache_config, - max_model_len=self.max_model_len, - enable_caching=cache_config.enable_prefix_caching, - caching_hash_algo=self.cache_config.prefix_caching_hash_algo, - log_stats=self.log_stats) + # Create KVConnector for the Scheduler. Note that each Worker + # will have a corresponding KVConnector with Role=WORKER. + # KV Connector pushes/pull of remote KVs for P/D and offloading. + self.connector = None + if self.vllm_config.kv_transfer_config is not None: + self.connector = KVConnectorFactory.create_connector_v1( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + + num_gpu_blocks = self.cache_config.num_gpu_blocks + assert num_gpu_blocks is not None and num_gpu_blocks > 0 + self.block_size = self.cache_config.block_size # req_id -> Request @@ -78,9 +81,6 @@ class Scheduler(SchedulerInterface): # Priority queues for requests. self.waiting: deque[Request] = deque() self.running: list[Request] = [] - # The requests that have been scheduled and are being executed - # by the executor. - self.scheduled_req_ids: set[str] = set() # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished @@ -90,8 +90,9 @@ class Scheduler(SchedulerInterface): # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - # Request id -> CachedRequestData - self._cached_reqs_data: dict[str, CachedRequestData] = {} + # Request id -> deque of CachedRequestData + self._cached_reqs_data: dict[ + str, deque[CachedRequestData]] = defaultdict(deque) # Encoder-related. # Calculate encoder cache size if applicable @@ -99,8 +100,8 @@ class Scheduler(SchedulerInterface): # This can be changed when we make encoder cache for embedding caching # across requests. 
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, mm_registry=mm_registry, ) @@ -114,10 +115,24 @@ class Scheduler(SchedulerInterface): self.encoder_cache_manager = EncoderCacheManager( cache_size=encoder_cache_size) - self.num_lookahead_tokens = 0 - if speculative_config and speculative_config.method == "eagle": - self.num_lookahead_tokens = \ - speculative_config.num_speculative_tokens + speculative_config = vllm_config.speculative_config + + self.use_eagle = False + self.num_spec_tokens = self.num_lookahead_tokens = 0 + if speculative_config: + self.num_spec_tokens = speculative_config.num_speculative_tokens + if speculative_config.use_eagle(): + self.use_eagle = True + self.num_lookahead_tokens = self.num_spec_tokens + + # Create the KV cache manager. + self.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=self.max_model_len, + enable_caching=self.cache_config.enable_prefix_caching, + caching_hash_algo=self.cache_config.prefix_caching_hash_algo, + use_eagle=self.use_eagle, + log_stats=self.log_stats) def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: @@ -160,10 +175,6 @@ class Scheduler(SchedulerInterface): req_index = 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - if request.request_id in self.scheduled_req_ids: - # This request has already been scheduled. - req_index += 1 - continue num_new_tokens = (request.num_tokens_with_spec - request.num_computed_tokens) @@ -172,26 +183,35 @@ class Scheduler(SchedulerInterface): num_new_tokens = ( self.scheduler_config.long_prefill_token_threshold) num_new_tokens = min(num_new_tokens, token_budget) - assert num_new_tokens > 0 + + # Make sure the input position does not exceed the max model len. 
+ # This is necessary when using spec decoding. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - request.num_computed_tokens) # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget) = self._try_schedule_encoder_inputs( request, request.num_computed_tokens, num_new_tokens, encoder_budget) - if num_new_tokens == 0: - # The request cannot be scheduled because the encoder budget - # or the encoder cache is exhausted. - # NOTE(woosuk): By using `continue` instead of `break` here, - # we intentionally relax the strict FCFS scheduling policy - # to allow lower-priority requests to be scheduled when a - # higher-priority request is blocked by encoder constraints. - req_index += 1 - continue - else: - encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when PP>1 and + # we have already scheduled all prompt tokens but they are + # not finished yet. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue while True: new_blocks = self.kv_cache_manager.allocate_slots( @@ -225,7 +245,6 @@ class Scheduler(SchedulerInterface): # Schedule the request. scheduled_running_reqs.append(request) - self.scheduled_req_ids.add(request.request_id) if request.use_structured_output: # PERF: in case of chunked prefill, # request might not include any new tokens. @@ -303,7 +322,18 @@ class Scheduler(SchedulerInterface): # Get already-cached tokens. 
computed_blocks, num_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks(request) + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + num_external_tokens = ( + 0 if self.connector is None else + self.connector.get_num_new_matched_tokens( + request, num_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens += num_external_tokens + # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed requests, @@ -330,18 +360,30 @@ class Scheduler(SchedulerInterface): new_encoder_budget = encoder_budget new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, + num_new_tokens + num_external_tokens, + computed_blocks, + num_lookahead_tokens=self.num_lookahead_tokens, + ) if new_blocks is None: # The request cannot be scheduled. break + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + num_external_tokens, + ) + self.waiting.popleft() if request.use_structured_output: structured_output_request_ids[ request.request_id] = req_index req_index += 1 self.running.append(request) - self.scheduled_req_ids.add(request.request_id) if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) @@ -443,6 +485,14 @@ class Scheduler(SchedulerInterface): grammar_bitmask=grammar_bitmask, ) + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. 
Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + # Advance the number of computed tokens for the request AFTER # the request is scheduled. # 1. The scheduler_output of the current step has to include the @@ -472,18 +522,21 @@ class Scheduler(SchedulerInterface): num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens new_token_ids = request.all_token_ids[ num_computed_tokens:num_computed_tokens + num_regular_tokens] - req_data = self._cached_reqs_data.get(request.request_id) - if req_data is not None: + + req_data_queue = self._cached_reqs_data.get(request.request_id) + if req_data_queue: + req_data = req_data_queue.popleft() req_data.resumed_from_preemption = resumed_from_preemption req_data.new_token_ids = new_token_ids req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens else: + # No cached request data, or all cached request data has been + # used by the scheduled requests. req_data = CachedRequestData.from_request(request, resumed_from_preemption, new_token_ids, new_block_ids) - self._cached_reqs_data[request.request_id] = req_data return req_data def _try_schedule_encoder_inputs( @@ -508,7 +561,12 @@ class Scheduler(SchedulerInterface): If an encoder input cannot be scheduled due to cache or budget limitations, the method adjusts `num_new_tokens` to schedule only the decoder tokens up to just before the unschedulable encoder input. + + Note that num_computed_tokens includes both locally cached + blocks and externally cached blocks (via KVConnector). 
""" + if num_new_tokens == 0 or not request.has_encoder_inputs: + return [], num_new_tokens, encoder_budget encoder_inputs_to_schedule: list[int] = [] mm_positions = request.mm_positions assert mm_positions is not None @@ -676,10 +734,13 @@ class Scheduler(SchedulerInterface): # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors - self.scheduled_req_ids.remove(req_id) if not stopped: new_running.append(request) + # Return the cached request data to the queue so they can be reused. + for req_data in scheduler_output.scheduled_cached_reqs: + self._cached_reqs_data[req_data.req_id].append(req_data) + self.running = new_running engine_core_outputs = EngineCoreOutputs( outputs=outputs, @@ -722,7 +783,6 @@ class Scheduler(SchedulerInterface): if request.status == RequestStatus.RUNNING: self.running.remove(request) - self.scheduled_req_ids.discard(request.request_id) else: self.waiting.remove(request) request.status = finished_status @@ -743,10 +803,6 @@ class Scheduler(SchedulerInterface): def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 - def get_num_unscheduled_requests(self) -> int: - """Number of requests that are not being processed by the executor.""" - return self.get_num_unfinished_requests() - len(self.scheduled_req_ids) - def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() @@ -756,11 +812,13 @@ class Scheduler(SchedulerInterface): ) -> Optional[SchedulerStats]: if not self.log_stats: return None + prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() + assert prefix_cache_stats is not None return SchedulerStats( num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting), gpu_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(), + prefix_cache_stats=prefix_cache_stats, spec_decoding_stats=spec_decoding_stats, ) @@ -773,7 +831,8 @@ class Scheduler(SchedulerInterface): if not 
self.log_stats: return None if spec_decoding_stats is None: - spec_decoding_stats = SpecDecodingStats() - spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) + spec_decoding_stats.observe_draft( + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens) return spec_decoding_stats diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 1264e43c79d9ee667abeccd483b026b9d6d622a8..0474669610cdcd2b456d62de26d4211121587a5c 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -49,9 +49,6 @@ class EngineCoreRequest( # due to circular imports and typing we have in data.py request_id: str - # NOTE(ywang96): original text prompt is needed when a request is added to - # Detokenizer, but set to None when it is added to EngineCoreClient. - prompt: Optional[str] prompt_token_ids: list[int] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_hashes: Optional[list[str]] @@ -61,6 +58,11 @@ class EngineCoreRequest( arrival_time: float lora_request: Optional[LoRARequest] + # Used in DP case to indicate which wave of requests this is expected to + # belong to, to cover a race condition where the request is sent before + # a wave finished notification is received. + current_wave: int = 0 + class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" @@ -139,8 +141,12 @@ class EngineCoreOutputs( utility_output: Optional[UtilityOutput] = None finished_requests: Optional[set[str]] = None - # In DP case, used to signal that the engine is paused. - engine_paused: bool = False + # In DP case, used to signal that the current wave of requests + # has finished and the engines are paused. + wave_complete: Optional[int] = None + # In DP case, used to signal that a request was received for an + # "old" wave, so the next wave needs to be started in other engines. 
+ start_wave: Optional[int] = None def __post_init__(self): if self.timestamp == 0.0: @@ -154,5 +160,7 @@ class EngineCoreRequestType(enum.Enum): """ ADD = b'\x00' ABORT = b'\x01' - START_DP = b'\x02' + START_DP_WAVE = b'\x02' UTILITY = b'\x03' + # Sentinel used within EngineCoreProc. + EXECUTOR_FAILED = b'\x04' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b77a6824cddbdcafba4a2abbe4366a137ae335aa..1334fb789aa4c9781a0d52642d454795af76317e 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,8 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - import asyncio -import logging -import os from collections.abc import AsyncGenerator, Mapping from copy import copy from typing import Optional, Union @@ -26,16 +23,17 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, cdiv, kill_process_tree +from vllm.utils import Device, cdiv from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.output_processor import (OutputProcessor, RequestOutputCollector) from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, - StatLoggerBase) +from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, + setup_default_loggers) from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -53,7 +51,28 @@ class AsyncLLM(EngineClient): use_cached_outputs: bool = False, log_requests: bool = True, 
start_engine_loop: bool = True, + stat_loggers: Optional[list[StatLoggerFactory]] = None, ) -> None: + """ + Create an AsyncLLM. + + Args: + vllm_config: global configuration. + executor_class: an Executor impl, e.g. MultiprocExecutor. + log_stats: Whether to log stats. + usage_context: Usage context of the LLM. + mm_registry: Multi-modal registry. + use_cached_outputs: Whether to use cached outputs. + log_requests: Whether to log requests. + start_engine_loop: Whether to start the engine loop. + stat_loggers: customized stat loggers for the engine. + If not provided, default stat loggers will be used. + PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE + IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE. + + Returns: + None + """ if not envs.VLLM_USE_V1: raise ValueError( "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " @@ -61,31 +80,24 @@ class AsyncLLM(EngineClient): "AsyncLLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") - assert start_engine_loop - self.model_config = vllm_config.model_config - + self.vllm_config = vllm_config self.log_requests = log_requests self.log_stats = log_stats # Set up stat loggers; independent set for each DP rank. - self.stat_loggers: list[list[StatLoggerBase]] = [] - if self.log_stats: - for i in range(vllm_config.parallel_config.data_parallel_size): - loggers: list[StatLoggerBase] = [] - if logger.isEnabledFor(logging.INFO): - loggers.append(LoggingStatLogger(engine_index=i)) - loggers.append( - PrometheusStatLogger(vllm_config, engine_index=i)) - self.stat_loggers.append(loggers) + self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers( + vllm_config=vllm_config, + log_stats=self.log_stats, + engine_num=vllm_config.parallel_config.data_parallel_size, + custom_stat_loggers=stat_loggers, + ) # Tokenizer (+ ensure liveness if running in another process). 
self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) - self.tokenizer.ping() # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( @@ -99,15 +111,23 @@ class AsyncLLM(EngineClient): log_stats=self.log_stats) # EngineCore (starts the engine in background process). - self.engine_core = EngineCoreClient.make_client( - multiprocess_mode=True, - asyncio_mode=True, + core_client_class = AsyncMPClient if ( + vllm_config.parallel_config.data_parallel_size + == 1) else DPAsyncMPClient + + self.engine_core = core_client_class( vllm_config=vllm_config, executor_class=executor_class, log_stats=self.log_stats, ) self.output_handler: Optional[asyncio.Task] = None + try: + # Start output handler eagerly if we are in the asyncio eventloop. + asyncio.get_running_loop() + self._run_output_handler() + except RuntimeError: + pass @classmethod def from_vllm_config( @@ -115,7 +135,7 @@ class AsyncLLM(EngineClient): vllm_config: VllmConfig, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, ) -> "AsyncLLM": @@ -126,17 +146,12 @@ class AsyncLLM(EngineClient): "AsyncLLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") - # FIXME(rob): refactor VllmConfig to include the StatLoggers - # include StatLogger in the Oracle decision. - if stat_loggers is not None: - raise ValueError("Custom StatLoggers are not yet supported on V1. " - "Explicitly set VLLM_USE_V1=0 to disable V1.") - # Create the LLMEngine. 
return cls( vllm_config=vllm_config, executor_class=Executor.get_class(vllm_config), start_engine_loop=start_engine_loop, + stat_loggers=stat_loggers, log_requests=not disable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, @@ -148,6 +163,7 @@ class AsyncLLM(EngineClient): engine_args: AsyncEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[list[StatLoggerFactory]] = None, ) -> "AsyncLLM": """Create an AsyncLLM from the EngineArgs.""" @@ -163,8 +179,12 @@ class AsyncLLM(EngineClient): log_stats=not engine_args.disable_log_stats, start_engine_loop=start_engine_loop, usage_context=usage_context, + stat_loggers=stat_loggers, ) + def __del__(self): + self.shutdown() + def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" @@ -187,6 +207,9 @@ class AsyncLLM(EngineClient): ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" + if self.errored: + raise EngineDeadError() + assert isinstance(params, SamplingParams), \ "Pooling is not supported in V1" @@ -194,14 +217,12 @@ class AsyncLLM(EngineClient): queue = RequestOutputCollector(output_kind=params.output_kind) # Convert Input --> Request. - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + prompt_str, request = self.processor.process_inputs( + request_id, prompt, params, arrival_time, lora_request, + trace_headers, prompt_adapter_request, priority) if params.n == 1: - await self._add_request(request, None, 0, queue) + await self._add_request(request, prompt_str, None, 0, queue) return queue # Fan out child requests (for n>1). 
@@ -211,15 +232,18 @@ class AsyncLLM(EngineClient): child_request = request if idx == params.n - 1 else copy(request) child_request.request_id = request_id child_request.sampling_params = params - await self._add_request(child_request, parent_request, idx, queue) + await self._add_request(child_request, prompt_str, parent_request, + idx, queue) return queue async def _add_request(self, request: EngineCoreRequest, + prompt: Optional[str], parent_req: Optional[ParentRequest], index: int, queue: RequestOutputCollector): # Add the request to OutputProcessor (this process). - self.output_processor.add_request(request, parent_req, index, queue) + self.output_processor.add_request(request, prompt, parent_req, index, + queue) # Add the EngineCoreRequest to EngineCore (separate process). await self.engine_core.add_request_async(request) @@ -261,9 +285,7 @@ class AsyncLLM(EngineClient): # We start the output_handler on the first call to generate() so # we can call __init__ before the event loop, which enables us # to handle startup failure gracefully in the OpenAI server. - if self.output_handler is None: - self.output_handler = asyncio.create_task( - self._run_output_handler()) + self._run_output_handler() q = await self.add_request( request_id, @@ -288,62 +310,96 @@ class AsyncLLM(EngineClient): finished = out.finished yield out - # If the request is disconnected by the client, the - # generate() task will be canceled. So, we abort the - # request if we end up here. + # If the request is disconnected by the client, generate() + # is cancelled. So, we abort the request if we end up here. except asyncio.CancelledError: await self.abort(request_id) + if self.log_requests: + logger.info("Request %s aborted.", request_id) raise - async def _run_output_handler(self): - """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" + # Engine is dead. Do not abort since we shut down. 
+ except EngineDeadError: + if self.log_requests: + logger.info("Request %s failed (engine dead).", request_id) + raise - try: - while True: - # 1) Pull EngineCoreOutputs from the EngineCore. - outputs = await self.engine_core.get_output_async() - num_outputs = len(outputs.outputs) - - iteration_stats = IterationStats() if ( - self.log_stats and num_outputs) else None - - # Split outputs into chunks of at most - # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the - # event loop for too long. - if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: - slices = (outputs.outputs, ) - else: - slices = np.array_split( - outputs.outputs, - cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) - - for i, outputs_slice in enumerate(slices): - # 2) Process EngineCoreOutputs. - processed_outputs = self.output_processor.process_outputs( - outputs_slice, outputs.timestamp, iteration_stats) - # NOTE: RequestOutputs are pushed to their queues. - assert not processed_outputs.request_outputs - - # Allow other asyncio tasks to run between chunks - if i + 1 < len(slices): - await asyncio.sleep(0) - - # 3) Abort any reqs that finished due to stop strings. - await self.engine_core.abort_requests_async( - processed_outputs.reqs_to_abort) - - # 4) Logging. - # TODO(rob): make into a coroutine and launch it in - # background thread once Prometheus overhead is non-trivial. - self._record_stats( - engine_index=outputs.engine_index, - scheduler_stats=outputs.scheduler_stats, - iteration_stats=iteration_stats, - ) + # Request validation error. + except ValueError: + if self.log_requests: + logger.info("Request %s failed (bad request).", request_id) + raise + # Unexpected error in the generate() task (possibly recoverable). 
except Exception as e: - logger.exception("EngineCore output handler hit an error: %s", e) - kill_process_tree(os.getpid()) + await self.abort(request_id) + if self.log_requests: + logger.info("Request %s failed.", request_id) + raise EngineGenerateError() from e + + def _run_output_handler(self): + """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" + + if self.output_handler is not None: + return + + # Ensure that the task doesn't have a circular ref back to the AsyncLLM + # object, or else it won't be garbage collected and cleaned up properly. + engine_core = self.engine_core + output_processor = self.output_processor + log_stats = self.log_stats + stat_loggers = self.stat_loggers if log_stats else None + + async def output_handler(): + try: + while True: + # 1) Pull EngineCoreOutputs from the EngineCore. + outputs = await engine_core.get_output_async() + num_outputs = len(outputs.outputs) + + iteration_stats = IterationStats() if ( + log_stats and num_outputs) else None + + # Split outputs into chunks of at most + # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the + # event loop for too long. + if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: + slices = (outputs.outputs, ) + else: + slices = np.array_split( + outputs.outputs, + cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) + + for i, outputs_slice in enumerate(slices): + # 2) Process EngineCoreOutputs. + processed_outputs = output_processor.process_outputs( + outputs_slice, outputs.timestamp, iteration_stats) + # NOTE: RequestOutputs are pushed to their queues. + assert not processed_outputs.request_outputs + + # Allow other asyncio tasks to run between chunks + if i + 1 < len(slices): + await asyncio.sleep(0) + + # 3) Abort any reqs that finished due to stop strings. + await engine_core.abort_requests_async( + processed_outputs.reqs_to_abort) + + # 4) Logging. 
+ # TODO(rob): make into a coroutine and launch it in + # background thread once Prometheus overhead is non-trivial. + if stat_loggers: + assert outputs.scheduler_stats is not None + AsyncLLM._record_stats( + stat_loggers[outputs.engine_index], + scheduler_stats=outputs.scheduler_stats, + iteration_stats=iteration_stats, + ) + except Exception as e: + logger.exception("AsyncLLM output_handler failed.") + output_processor.propagate_error(e) + + self.output_handler = asyncio.create_task(output_handler()) async def abort(self, request_id: str) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" @@ -354,17 +410,15 @@ class AsyncLLM(EngineClient): if self.log_requests: logger.info("Aborted request %s.", request_id) + @staticmethod def _record_stats( - self, - scheduler_stats: Optional[SchedulerStats], + stat_loggers: list[StatLoggerBase], + scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats], - engine_index: int = 0, ): - if not self.log_stats: - return - - assert scheduler_stats is not None - for stat_logger in self.stat_loggers[engine_index]: + """static so that it can be used from the output_handler task + without a circular ref to AsyncLLM.""" + for stat_logger in stat_loggers: stat_logger.record(scheduler_stats=scheduler_stats, iteration_stats=iteration_stats) @@ -379,6 +433,9 @@ class AsyncLLM(EngineClient): ): raise ValueError("Not Supported on V1 yet.") + async def get_vllm_config(self) -> VllmConfig: + return self.vllm_config + async def get_model_config(self) -> ModelConfig: return self.model_config @@ -446,18 +503,30 @@ class AsyncLLM(EngineClient): """Prevent an adapter from being evicted.""" return await self.engine_core.pin_lora_async(lora_id) + async def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + """ + Perform a collective RPC call to the given path. 
+ """ + return await self.engine_core.collective_rpc_async( + method, timeout, args, kwargs) + @property def is_running(self) -> bool: - return True + # Is None before the loop is started. + return self.output_handler is None or not self.output_handler.done() @property def is_stopped(self) -> bool: - return False + return self.errored @property def errored(self) -> bool: - return False + return self.engine_core.resources.engine_dead or not self.is_running @property def dead_error(self) -> BaseException: - return Exception() # TODO: implement + return EngineDeadError() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 702dc19d7d632a256ccaa8395991bbc2da450de5..5b5532d37048ad92b403e26831693268e80b99e2 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -5,15 +5,14 @@ import signal import sys import threading import time +from collections import deque from concurrent.futures import Future from inspect import isclass, signature from logging import DEBUG from typing import Any, Callable, Optional, TypeVar, Union import msgspec -import psutil import zmq -import zmq.asyncio from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group @@ -22,8 +21,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, - zmq_socket_ctx) +from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -50,12 +48,11 @@ _R = TypeVar('_R') # Return type for collective_rpc class EngineCore: """Inner loop of vLLM's Engine.""" - def __init__( - self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - ): + def __init__(self, + 
vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + executor_fail_callback: Optional[Callable] = None): assert vllm_config.model_config.runner_type != "pooling" logger.info("Initializing a V1 LLM engine (v%s) with config: %s", @@ -65,6 +62,9 @@ class EngineCore: # Setup Model. self.model_executor = executor_class(vllm_config) + if executor_fail_callback is not None: + self.model_executor.register_failure_callback( + executor_fail_callback) # Setup KV Caches and update CacheConfig after profiling. num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ @@ -93,12 +93,8 @@ class EngineCore: vllm_config.scheduler_config.scheduler_cls) self.scheduler: SchedulerInterface = Scheduler( - scheduler_config=vllm_config.scheduler_config, - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - lora_config=vllm_config.lora_config, + vllm_config=vllm_config, kv_cache_config=kv_cache_config, - speculative_config=vllm_config.speculative_config, structured_output_manager=self.structured_output_manager, include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, @@ -215,10 +211,10 @@ class EngineCore: Note that if nothing to output in this step, None is returned. The execution flow is as follows: - 1. Try to schedule a new batch if there are unscheduled requests - and the job queue is not full. If a new batch is scheduled, directly - return an empty engine core output. In other words, we won't check - and return model outputs before the batch queue is full. + 1. Try to schedule a new batch if the batch queue is not full. + If a new batch is scheduled, directly return an empty engine core + output. In other words, fulfilling the batch queue has a higher priority + than getting model outputs. 2. If there is no new scheduled batch, meaning that the batch queue is full or no other requests can be scheduled, we block until the first batch in the job queue is finished. 
@@ -228,10 +224,10 @@ class EngineCore: engine_core_outputs = None scheduler_output = None - # If there are unscheduled requests and the job queue - # is not full, schedule a new batch. Note that this is not blocking. - if (self.scheduler.get_num_unscheduled_requests() > 0 - and not self.batch_queue.full()): + # Try to schedule a new batch if the batch queue is not full, but + # the scheduler may return an empty batch if all requests are scheduled. + # Note that this is not blocking. + if not self.batch_queue.full(): scheduler_output = self.scheduler.schedule() if scheduler_output.total_num_scheduled_tokens > 0: future = self.model_executor.execute_model(scheduler_output) @@ -243,6 +239,10 @@ class EngineCore: # If no more requests can be scheduled and the job queue is not empty, # block until the first batch in the job queue is finished. + # TODO(comaniac): Ideally we should peek the first batch in the + # job queue to check if it's finished before scheduling a new batch, + # but peeking the first element in a queue is not thread-safe, + # so we need more work. if not scheduled_batch and not self.batch_queue.empty(): future, scheduler_output = self.batch_queue.get_nowait() # Blocking until the first result is available. 
@@ -254,7 +254,9 @@ class EngineCore: return engine_core_outputs def shutdown(self): - self.model_executor.shutdown() + self.structured_output_manager.clear_backend() + if self.model_executor: + self.model_executor.shutdown() def profile(self, is_start: bool = True): self.model_executor.profile(is_start) @@ -308,6 +310,8 @@ class EngineCore: class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" + ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD' + def __init__( self, input_path: str, @@ -317,27 +321,33 @@ class EngineCoreProc(EngineCore): log_stats: bool, engine_index: int = 0, ): - super().__init__(vllm_config, executor_class, log_stats) + input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() + + executor_fail_callback = lambda: input_queue.put_nowait( + (EngineCoreRequestType.EXECUTOR_FAILED, b'')) + + super().__init__(vllm_config, executor_class, log_stats, + executor_fail_callback) self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) - - self.global_unfinished_reqs = False + self.engines_running = False # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, # and to overlap some serialization/deserialization with the # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
- self.input_queue: queue.Queue[tuple[EngineCoreRequestType, - Any]] = queue.Queue() - self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + self.input_queue = input_queue + self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() threading.Thread(target=self.process_input_socket, args=(input_path, engine_index), daemon=True).start() - threading.Thread(target=self.process_output_socket, - args=(output_path, engine_index), - daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_socket, + args=(output_path, engine_index), + daemon=True) + self.output_thread.start() @staticmethod def run_engine_core(*args, @@ -364,7 +374,6 @@ class EngineCoreProc(EngineCore): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) - parent_process = psutil.Process().parent() engine_core: Optional[EngineCoreProc] = None try: parallel_config: ParallelConfig = kwargs[ @@ -380,13 +389,15 @@ class EngineCoreProc(EngineCore): engine_core.run_busy_loop() except SystemExit: - logger.debug("EngineCore interrupted.") - - except Exception: - traceback = get_exception_traceback() - logger.error("EngineCore hit an exception: %s", traceback) - parent_process.send_signal(signal.SIGUSR1) - + logger.debug("EngineCore exiting.") + raise + except Exception as e: + if engine_core is None: + logger.exception("EngineCore failed to start.") + else: + logger.exception("EngineCore encountered a fatal error.") + engine_core._send_engine_dead() + raise e finally: if engine_core is not None: engine_core.shutdown() @@ -405,8 +416,7 @@ class EngineCoreProc(EngineCore): """Exits when an engine step needs to be performed.""" waited = False - while not self.global_unfinished_reqs and not ( - self.scheduler.has_requests()): + while not self.engines_running and not (self.scheduler.has_requests()): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -414,10 
+424,7 @@ class EngineCoreProc(EngineCore): self._handle_client_request(*req) if waited: - logger.debug( - "EngineCore loop active - local unfinished: %s, finished: %s.", - self.scheduler.has_unfinished_requests(), - self.scheduler.has_finished_requests()) + logger.debug("EngineCore loop active.") # Handle any more client requests. while not self.input_queue.empty(): @@ -441,10 +448,6 @@ class EngineCoreProc(EngineCore): self.add_request(request) elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) - elif request_type == EngineCoreRequestType.START_DP: - if not self.global_unfinished_reqs: - logger.debug("EngineCore starting idle loop.") - self.global_unfinished_reqs = True elif request_type == EngineCoreRequestType.UTILITY: call_id, method_name, args = request output = UtilityOutput(call_id) @@ -458,6 +461,11 @@ class EngineCoreProc(EngineCore): f" failed: {str(e)}") self.output_queue.put_nowait( EngineCoreOutputs(utility_output=output)) + elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: + raise RuntimeError("Executor failed.") + else: + logger.error("Unrecognized input request type encountered: %s", + request_type) @staticmethod def _convert_msgspec_args(method, args): @@ -473,6 +481,18 @@ class EngineCoreProc(EngineCore): and not isinstance(v, p.annotation) else v for v, p in zip(args, arg_types)) + def _send_engine_dead(self): + """Send EngineDead status to the EngineCoreClient.""" + + # Put ENGINE_CORE_DEAD in the queue. + self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD) + + # Wait until msg sent by the daemon before shutdown. + self.output_thread.join(timeout=5.0) + if self.output_thread.is_alive(): + logger.fatal("vLLM shutdown signal from EngineCore failed " + "to send. Please report this issue.") + def process_input_socket(self, input_path: str, engine_index: int): """Input socket IO thread.""" @@ -508,18 +528,40 @@ class EngineCoreProc(EngineCore): # Msgpack serialization encoding. 
encoder = MsgpackEncoder() - # Reuse send buffer. - buffer = bytearray() - - with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: + # Send buffers to reuse. + reuse_buffers: list[bytearray] = [] + # Keep references to outputs and buffers until zmq is finished + # with them (outputs may contain tensors/np arrays whose + # backing buffers were extracted for zero-copy send). + pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]() + + # We must set linger to ensure the ENGINE_CORE_DEAD + # message is sent prior to closing the socket. + with zmq_socket_ctx(output_path, zmq.constants.PUSH, + linger=4000) as socket: while True: outputs = self.output_queue.get() + if outputs == EngineCoreProc.ENGINE_CORE_DEAD: + socket.send(outputs, copy=False) + break + assert not isinstance(outputs, bytes) outputs.engine_index = engine_index - buffers = encoder.encode_into(outputs, buffer) - socket.send_multipart(buffers, copy=False) + # Reclaim buffers that zmq is finished with. + while pending and pending[-1][0].done: + reuse_buffers.append(pending.pop()[2]) -ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True) + buffer = reuse_buffers.pop() if reuse_buffers else bytearray() + buffers = encoder.encode_into(outputs, buffer) + tracker = socket.send_multipart(buffers, + copy=False, + track=True) + if not tracker.done: + ref = outputs if len(buffers) > 1 else None + pending.appendleft((tracker, ref, buffer)) + elif len(reuse_buffers) < 2: + # Keep at most 2 buffers to reuse. + reuse_buffers.append(buffer) class DPEngineCoreProc(EngineCoreProc): @@ -558,7 +600,9 @@ class DPEngineCoreProc(EngineCoreProc): for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * tp_size)) + self.local_dp_rank = local_dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + self.current_wave = 0 # Initialize the engine after setting up environment. 
super().__init__(input_path, output_path, vllm_config, executor_class, @@ -573,6 +617,31 @@ class DPEngineCoreProc(EngineCoreProc): if dp_group := getattr(self, "dp_group", None): stateless_destroy_torch_distributed_process_group(dp_group) + def add_request(self, request: EngineCoreRequest): + if request.current_wave != self.current_wave: + if request.current_wave > self.current_wave: + self.current_wave = request.current_wave + elif not self.engines_running: + # Request received for an already-completed wave, notify + # front-end that we need to start the next one. + self.output_queue.put_nowait( + EngineCoreOutputs(start_wave=self.current_wave)) + + super().add_request(request) + + def _handle_client_request(self, request_type: EngineCoreRequestType, + request: Any) -> None: + if request_type == EngineCoreRequestType.START_DP_WAVE: + new_wave: int = request + if new_wave >= self.current_wave: + self.current_wave = new_wave + if not self.engines_running: + logger.debug("EngineCore starting idle loop for wave %d.", + new_wave) + self.engines_running = True + else: + super()._handle_client_request(request_type, request) + def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" @@ -599,7 +668,7 @@ class DPEngineCoreProc(EngineCoreProc): # up-to-date state is returned in the engine outputs. self._process_engine_step() - if not self.global_unfinished_reqs: + if not self.engines_running: # All engines are idle. continue @@ -608,18 +677,23 @@ class DPEngineCoreProc(EngineCoreProc): self.execute_dummy_batch() # 3) All-reduce operation to determine global unfinished reqs. - self.global_unfinished_reqs = self._has_global_unfinished_reqs( + self.engines_running = self._has_global_unfinished_reqs( local_unfinished_reqs) - if not self.global_unfinished_reqs: - # Notify client that we are pausing the loop. 
- self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS) + if not self.engines_running: + if self.local_dp_rank == 0: + # Notify client that we are pausing the loop. + logger.debug("Wave %d finished, pausing engine loop.", + self.current_wave) + self.output_queue.put_nowait( + EngineCoreOutputs(wave_complete=self.current_wave)) + self.current_wave += 1 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: - # Optimization - only perform finish-sync all-reduce every 16 steps. + # Optimization - only perform finish-sync all-reduce every 24 steps. self.counter += 1 - if self.counter != 16: + if self.counter != 24: return True self.counter = 0 diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index b43c9e5b0f03d4a7aca2cd67d00c116704a297df..dd51909961964f685605045716811b5d6785b3e7 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 - import asyncio -import os +import contextlib import queue -import signal -import threading import uuid import weakref from abc import ABC, abstractmethod -from collections.abc import Awaitable +from collections import deque +from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass, field from threading import Thread @@ -21,10 +19,11 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, - kill_process_tree, make_zmq_socket) + make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc +from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr from vllm.v1.utils import BackgroundProcHandle @@ 
-305,14 +304,23 @@ class BackgroundResources: core_engines: list[CoreEngine] = field(default_factory=list) output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None + output_queue_task: Optional[asyncio.Task] = None shutdown_path: Optional[str] = None + # Set if any of the engines are dead. Here so that the output + # processing threads can access it without holding a ref to the client. + engine_dead: bool = False + def __call__(self): """Clean up background resources.""" + self.engine_dead = True for core_engine in self.core_engines: core_engine.close() + if self.output_queue_task is not None: + self.output_queue_task.cancel() + # ZMQ context termination can hang if the sockets # aren't explicitly closed first. if self.output_socket is not None: @@ -327,6 +335,12 @@ class BackgroundResources: # Send shutdown signal. shutdown_sender.send(b'') + def validate_alive(self, frames: Sequence[zmq.Frame]): + if len(frames) == 1 and (frames[0].buffer + == EngineCoreProc.ENGINE_CORE_DEAD): + self.engine_dead = True + raise EngineDeadError() + class MPClient(EngineCoreClient): """ @@ -348,27 +362,6 @@ class MPClient(EngineCoreClient): executor_class: type[Executor], log_stats: bool, ): - # The child processes will send SIGUSR1 when unrecoverable - # errors happen. We kill the process tree here so that the - # stack trace is very evident. - # TODO(rob): rather than killing the main process, we should - # figure out how to raise an AsyncEngineDeadError and - # handle at the API server level so we can return a better - # error code to the clients calling vLLM. - def sigusr1_handler(signum, frame): - logger.fatal("Got fatal signal from worker processes, shutting " - "down. 
See stack trace above for root cause issue.") - kill_process_tree(os.getpid()) - - if threading.current_thread() == threading.main_thread(): - signal.signal(signal.SIGUSR1, sigusr1_handler) - else: - logger.warning("SIGUSR1 handler not installed because we are not " - "running in the main thread. In this case the " - "forked engine process may not be killed when " - "an exception is raised, and you need to handle " - "the engine process shutdown manually.") - # Serialization setup. self.encoder = MsgpackEncoder() self.decoder = MsgpackDecoder(EngineCoreOutputs) @@ -378,32 +371,43 @@ class MPClient(EngineCoreClient): self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx # This will ensure resources created so far are closed - # when the client is garbage collected, even if an + # when the client is garbage collected, even if an # exception is raised mid-construction. self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) - - # Paths and sockets for IPC. - self.output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - self.input_socket = make_zmq_socket(self.ctx, - input_path, - zmq.ROUTER, - bind=True) - self.resources.input_socket = self.input_socket - - new_core_engine = lambda index, local_dp_rank=None: CoreEngine( - vllm_config, executor_class, log_stats, input_path, self. - output_path, index, local_dp_rank) - - # Start engine core process(es). - self._init_core_engines(vllm_config, new_core_engine, - self.resources.core_engines) - - # Wait for engine core process(es) to start. - self._wait_for_engine_startup() - - self.utility_results: dict[int, AnyFuture] = {} + success = False + try: + # Paths and sockets for IPC. 
+ self.output_path = get_open_zmq_ipc_path() + input_path = get_open_zmq_ipc_path() + self.input_socket = make_zmq_socket(self.ctx, + input_path, + zmq.ROUTER, + bind=True) + self.resources.input_socket = self.input_socket + + new_core_engine = lambda index, local_dp_rank=None: CoreEngine( + vllm_config, executor_class, log_stats, input_path, self. + output_path, index, local_dp_rank) + + # Start engine core process(es). + self._init_core_engines(vllm_config, new_core_engine, + self.resources.core_engines) + + # Wait for engine core process(es) to start. + self._wait_for_engine_startup() + + self.utility_results: dict[int, AnyFuture] = {} + + # Request objects which may contain pytorch-allocated tensors + # that we need to keep references to until zmq is done with the + # underlying data. + self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]() + + success = True + finally: + if not success: + self._finalizer() def _wait_for_engine_startup(self): # Get a sync handle to the socket which can be sync or async. @@ -443,16 +447,34 @@ class MPClient(EngineCoreClient): ) -> None: # Default case - single core engine. - dp_rank = vllm_config.parallel_config.data_parallel_rank - local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local core_engine = new_core_engine( - dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank) + vllm_config.parallel_config.data_parallel_rank, + vllm_config.parallel_config.data_parallel_rank_local, + ) core_engines.append(core_engine) self.core_engine = core_engine def shutdown(self): + # Terminate background resources. 
self._finalizer() + def _format_exception(self, e: Exception) -> Exception: + """If errored, use EngineDeadError so root cause is clear.""" + return EngineDeadError( + suppress_context=True) if self.resources.engine_dead else e + + def ensure_alive(self): + if self.resources.engine_dead: + raise EngineDeadError() + + def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any): + if not tracker.done: + self.pending_messages.appendleft((tracker, msg)) + + def free_pending_messages(self): + while self.pending_messages and self.pending_messages[-1][0].done: + self.pending_messages.pop() + def _process_utility_output(output: UtilityOutput, utility_results: dict[int, AnyFuture]): @@ -476,7 +498,7 @@ class SyncMPClient(MPClient): log_stats=log_stats, ) - self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]() # Ensure that the outputs socket processing thread does not have # a ref to the client which prevents gc. @@ -487,7 +509,8 @@ class SyncMPClient(MPClient): outputs_queue = self.outputs_queue shutdown_path = get_open_zmq_inproc_path() - self.resources.shutdown_path = shutdown_path + resources = self.resources + resources.shutdown_path = shutdown_path def process_outputs_socket(): shutdown_socket = ctx.socket(zmq.PAIR) @@ -506,12 +529,15 @@ class SyncMPClient(MPClient): break frames = out_socket.recv_multipart(copy=False) + resources.validate_alive(frames) outputs = decoder.decode(frames) if outputs.utility_output: _process_utility_output(outputs.utility_output, utility_results) else: outputs_queue.put_nowait(outputs) + except Exception as e: + outputs_queue.put_nowait(e) finally: # Close sockets. 
shutdown_socket.close(linger=0) @@ -524,13 +550,28 @@ class SyncMPClient(MPClient): self.output_queue_thread.start() def get_output(self) -> EngineCoreOutputs: - return self.outputs_queue.get() + # If an exception arises in process_outputs_socket task, + # it is forwarded to the outputs_queue so we can raise it + # from this (run_output_handler) task to shut down the server. + outputs = self.outputs_queue.get() + if isinstance(outputs, Exception): + raise self._format_exception(outputs) from None + return outputs def _send_input(self, request_type: EngineCoreRequestType, request: Any): + self.ensure_alive() + self.free_pending_messages() # (Identity, RequestType, SerializedRequest) msg = (self.core_engine.identity, request_type.value, *self.encoder.encode(request)) - self.input_socket.send_multipart(msg, copy=False) + + if len(msg) <= 3: + # No auxiliary buffers => no tensor backing buffers in request. + self.input_socket.send_multipart(msg, copy=False) + return + + tracker = self.input_socket.send_multipart(msg, copy=False, track=True) + self.add_pending_message(tracker, request) def call_utility(self, method: str, *args) -> Any: call_id = uuid.uuid1().int >> 64 @@ -542,13 +583,10 @@ class SyncMPClient(MPClient): return future.result() def add_request(self, request: EngineCoreRequest) -> None: - # NOTE: text prompt is not needed in the core engine as it has been - # tokenized. 
- request.prompt = None self._send_input(EngineCoreRequestType.ADD, request) def abort_requests(self, request_ids: list[str]) -> None: - if len(request_ids) > 0: + if request_ids and not self.resources.engine_dead: self._send_input(EngineCoreRequestType.ABORT, request_ids) def profile(self, is_start: bool = True) -> None: @@ -608,71 +646,111 @@ class AsyncMPClient(MPClient): log_stats=log_stats, ) - self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None - self.queue_task: Optional[asyncio.Task] = None - - self.outputs_handler: Optional[Callable[ - [AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None + self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, + Exception]]() + try: + # If we are running in an asyncio event loop, start the queue task. + # Otherwise, it will be started lazily. If it is not started here, + # we could miss EXECUTOR_FAILED messages from engine core if they + # occur prior to any requests being sent. + asyncio.get_running_loop() + self._ensure_output_queue_task() + except RuntimeError: + pass def _ensure_output_queue_task(self): - if self.outputs_queue is not None: + resources = self.resources + if resources.output_queue_task is not None: return # Perform IO in separate task to parallelize as much as possible. # Avoid task having direct reference back to the client. 
- self.outputs_queue = asyncio.Queue() decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue - output_handler = self.outputs_handler + output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs], + Awaitable[None]]] = getattr( + self.__class__, + "process_engine_outputs", None) _self_ref = weakref.ref(self) if output_handler else None output_path = self.output_path output_socket = make_zmq_socket(self.ctx, output_path, zmq.constants.PULL) - self.resources.output_socket = output_socket + resources.output_socket = output_socket async def process_outputs_socket(): - while True: - frames = await output_socket.recv_multipart(copy=False) - outputs: EngineCoreOutputs = decoder.decode(frames) - if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) - continue - - if output_handler is not None: - assert _self_ref is not None - _self = _self_ref() - if not _self: - # Client has been garbage collected, abort. - return - await output_handler(_self, outputs) - - if outputs.outputs or outputs.scheduler_stats: - outputs_queue.put_nowait(outputs) - - self.queue_task = asyncio.create_task(process_outputs_socket(), - name="EngineCoreOutputQueueTask") + try: + while True: + frames = await output_socket.recv_multipart(copy=False) + resources.validate_alive(frames) + outputs: EngineCoreOutputs = decoder.decode(frames) + if outputs.utility_output: + _process_utility_output(outputs.utility_output, + utility_results) + continue + + if output_handler is not None: + assert _self_ref is not None + _self = _self_ref() + if not _self: + # Client has been garbage collected, abort. 
+ return + await output_handler(_self, outputs) + + if outputs.outputs or outputs.scheduler_stats: + outputs_queue.put_nowait(outputs) + except Exception as e: + outputs_queue.put_nowait(e) + + resources.output_queue_task = asyncio.create_task( + process_outputs_socket(), name="EngineCoreOutputQueueTask") async def get_output_async(self) -> EngineCoreOutputs: self._ensure_output_queue_task() + # If an exception arises in process_outputs_socket task, + # it is forwarded to the outputs_queue so we can raise it + # from this (run_output_handler) task to shut down the server. assert self.outputs_queue is not None - return await self.outputs_queue.get() + outputs = await self.outputs_queue.get() + if isinstance(outputs, Exception): + raise self._format_exception(outputs) from None + return outputs def _send_input(self, request_type: EngineCoreRequestType, request: Any, - engine: Optional[CoreEngine] = None) -> Awaitable[None]: + engine: Optional[CoreEngine] = None) -> Awaitable[Any]: + self.ensure_alive() if engine is None: engine = self.core_engine message = (request_type.value, *self.encoder.encode(request)) - return self._send_input_message(message, engine) + return self._send_input_message(message, engine, request) + + def _send_input_message(self, message: tuple[bytestr, + ...], engine: CoreEngine, + objects: Any) -> Awaitable[Any]: + """ + objects is a reference to retain until zmq is finished with the + buffers, in case they were extracted from tensors in the request. + """ + self.ensure_alive() + self.free_pending_messages() + + msg = (engine.identity, ) + message + if not objects or len(msg) <= 3: + # No auxiliary buffers => no tensor backing buffers in request. 
+ return self.input_socket.send_multipart(msg, copy=False) + + future: asyncio.Future[zmq.MessageTracker] + future = self.input_socket.send_multipart(msg, copy=False, track=True) - def _send_input_message(self, message: tuple[bytestr, ...], - engine: CoreEngine) -> Awaitable[None]: - message = (engine.identity, ) + message - return self.input_socket.send_multipart(message, copy=False) + def add_pending(f: asyncio.Future[zmq.MessageTracker]): + with contextlib.suppress(BaseException): + self.add_pending_message(f.result(), objects) + + future.add_done_callback(add_pending) + return future async def call_utility_async(self, method: str, *args) -> Any: return await self._call_utility_async(method, @@ -686,19 +764,16 @@ class AsyncMPClient(MPClient): self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( (call_id, method, args))) - await self._send_input_message(message, engine) + await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future async def add_request_async(self, request: EngineCoreRequest) -> None: - # NOTE: text prompt is not needed in the core engine as it has been - # tokenized. - request.prompt = None await self._send_input(EngineCoreRequestType.ADD, request) self._ensure_output_queue_task() async def abort_requests_async(self, request_ids: list[str]) -> None: - if len(request_ids) > 0: + if request_ids and not self.resources.engine_dead: await self._send_input(EngineCoreRequestType.ABORT, request_ids) async def profile_async(self, is_start: bool = True) -> None: @@ -754,18 +829,14 @@ class DPAsyncMPClient(AsyncMPClient): def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool): - super().__init__(vllm_config, executor_class, log_stats) - - assert len(self.core_engines) > 1 - - # Control message used for triggering dp idle mode loop. 
- self.start_dp_msg = (EngineCoreRequestType.START_DP.value, - *self.encoder.encode(None)) - self.num_engines_running = 0 + self.current_wave = 0 + self.engines_running = False self.reqs_in_flight: dict[str, CoreEngine] = {} - self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment] + super().__init__(vllm_config, executor_class, log_stats) + + assert len(self.core_engines) > 1 def _init_core_engines( self, @@ -790,26 +861,23 @@ class DPAsyncMPClient(AsyncMPClient): ]))[0] async def add_request_async(self, request: EngineCoreRequest) -> None: - # NOTE: text prompt is not needed in the core engine as it has been - # tokenized. - request.prompt = None - - msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request)) + request.current_wave = self.current_wave chosen_engine = self.get_core_engine_for_request() self.reqs_in_flight[request.request_id] = chosen_engine chosen_engine.num_reqs_in_flight += 1 - if self.num_engines_running >= len(self.core_engines): - await self._send_input_message(msg, chosen_engine) - else: + + to_await = self._send_input(EngineCoreRequestType.ADD, request, + chosen_engine) + if not self.engines_running: # Send request to chosen engine and dp start loop # control message to all other engines. 
- self.num_engines_running += len(self.core_engines) - await asyncio.gather(*[ - self._send_input_message( - msg if engine is chosen_engine else self.start_dp_msg, - engine) for engine in self.core_engines - ]) + self.engines_running = True + to_await = asyncio.gather( + to_await, # type: ignore[assignment] + *self._start_wave_coros(exclude_index=chosen_engine.index)) + + await to_await self._ensure_output_queue_task() @@ -824,21 +892,31 @@ class DPAsyncMPClient(AsyncMPClient): if engine := self.reqs_in_flight.pop(req_id, None): engine.num_reqs_in_flight -= 1 - if outputs.engine_paused: - assert self.num_engines_running >= 1 - self.num_engines_running -= 1 - if not self.num_engines_running and self.reqs_in_flight: - # If there are requests in flight here, they must have - # been sent after the engines paused. We must make - # sure to start the other engines: - self.num_engines_running = len(self.core_engines) - coros = [ - self._send_input_message(self.start_dp_msg, engine) - for engine in self.core_engines - if not engine.num_reqs_in_flight - ] - if coros: - await asyncio.gather(*coros) + if outputs.wave_complete is not None: + # Current wave is complete, move to next wave number + # and mark engines as paused. + if self.current_wave <= outputs.wave_complete: + self.current_wave = outputs.wave_complete + 1 + self.engines_running = False + + elif outputs.start_wave is not None and ( + outputs.start_wave > self.current_wave or + (outputs.start_wave == self.current_wave + and not self.engines_running)): + # Engine received request for a non-current wave so we must ensure + # that other engines progress to the next wave. 
+ self.current_wave = outputs.start_wave + self.engines_running = True + await asyncio.gather(*self._start_wave_coros( + exclude_index=outputs.engine_index)) + + def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]: + logger.debug("Sending start DP wave %d.", self.current_wave) + return [ + self._send_input(EngineCoreRequestType.START_DP_WAVE, + self.current_wave, engine) + for engine in self.core_engines if engine.index != exclude_index + ] async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids: @@ -859,5 +937,6 @@ class DPAsyncMPClient(AsyncMPClient): async def _abort_requests(self, request_ids: list[str], engine: CoreEngine) -> None: - await self._send_input(EngineCoreRequestType.ABORT, request_ids, - engine) \ No newline at end of file + if not self.resources.engine_dead: + await self._send_input(EngineCoreRequestType.ABORT, request_ids, + engine) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index bf06a17507b216743c5203f411bda80c96c7c3ed..dca327cc5d07bde3fc490249aabff55c362888e5 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 - -from dataclasses import dataclass, field +from abc import ABC, abstractmethod from typing import Optional +import tokenizers +from packaging import version +from tokenizers import Tokenizer +from tokenizers.decoders import DecodeStream +from transformers import PreTrainedTokenizerFast + from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.transformers_utils.detokenizer_utils import ( @@ -12,39 +17,22 @@ from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) -@dataclass class IncrementalDetokenizer: - # Generation data - token_ids: list[int] - output_text: str = "" - tokens: list[str] = field(default_factory=list) - prompt_len: int = 0 - - # Stop strings - stop: list[str] = 
field(default_factory=list) - include_stop_str_in_output: bool = False - - # Metadata for incremental detokenization - prefix_offset: int = 0 - read_offset: int = 0 - - # Parameters for detokenization - skip_special_tokens: bool = True - spaces_between_special_tokens: bool = True - - # Tokenizer for this request, - # None if detokenization is disabled. - tokenizer: Optional[AnyTokenizer] = None - - # Accounting for stop string buffering - stop_buffer_length: int = 0 - _last_output_text_offset: int = 0 + def __init__(self): + self.token_ids: list[int] = [] @property def output_token_ids(self) -> list[int]: - return self.token_ids if not self.prompt_len else ( - self.token_ids[self.prompt_len:]) + return self.token_ids + + def update(self, new_token_ids: list[int], + stop_terminated: bool) -> Optional[str]: + self.token_ids.extend(new_token_ids) + return None + + def get_next_output_text(self, finished: bool, delta: bool) -> str: + return "" @classmethod def from_new_request( @@ -54,39 +42,39 @@ class IncrementalDetokenizer: ) -> "IncrementalDetokenizer": if tokenizer is None: - return cls(token_ids=[]) + # No tokenizer => skipping detokenization. + return IncrementalDetokenizer() + + if (isinstance(tokenizer, PreTrainedTokenizerFast) and version.parse( + tokenizers.__version__) >= version.parse("0.21.1")): + # Fast tokenizer => use tokenizers library DecodeStream. + # And only tokenizers >= 0.21.1 supports Fast Detokenizer. + return FastIncrementalDetokenizer(tokenizer, request) + + # Fall back to slow python-based incremental detokenization. 
+ return SlowIncrementalDetokenizer(tokenizer, request) + + +class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): + + def __init__(self, request: EngineCoreRequest): + super().__init__() - tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( - tokenizer=tokenizer, - prompt_ids=request.prompt_token_ids, - skip_special_tokens=request.sampling_params.skip_special_tokens, - ) + # Stop strings + params = request.sampling_params + self.stop = stop = params.stop + self.include_stop_str_in_output = params.include_stop_str_in_output - stops = request.sampling_params.stop # Number of chars to hold back when stop strings are to be excluded # from streamed output. - if stops and not request.sampling_params.include_stop_str_in_output: - stop_buffer_length = max(len(s) for s in stops) - 1 + if stop and not self.include_stop_str_in_output: + self.stop_buffer_length = max(len(s) for s in stop) - 1 else: - stop_buffer_length = 0 - - return cls( - tokens=tokens, - # Detokenizer mutates this list, so need a unique copy. - # NOTE(Nick): could we take ownership of it though? - token_ids=request.prompt_token_ids.copy(), - stop=stops, - include_stop_str_in_output=request.sampling_params. - include_stop_str_in_output, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=request.sampling_params.skip_special_tokens, - spaces_between_special_tokens=request.sampling_params. - spaces_between_special_tokens, - prompt_len=len(request.prompt_token_ids), - tokenizer=tokenizer, - stop_buffer_length=stop_buffer_length, - ) + self.stop_buffer_length = 0 + self._last_output_text_offset: int = 0 + + # Generation data + self.output_text = "" def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: @@ -98,11 +86,7 @@ class IncrementalDetokenizer: Return matched stop string or None. 
""" if not new_token_ids: - # Skip detokenization if no new token ids - return None - if self.tokenizer is None: - # Skip detokenization if no tokenizer - self.token_ids.extend(new_token_ids) + # Skip detokenization if no new token ids. return None if stop_terminated and not self.include_stop_str_in_output: @@ -116,34 +100,16 @@ class IncrementalDetokenizer: # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of # new_token_ids is more than 1. We need to optimize this. - decoded_text = "" + offset_before = len(self.output_text) for new_token_id in new_token_ids: self.token_ids.append(new_token_id) - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=self.token_ids, - prev_tokens=self.tokens, - prefix_offset=self.prefix_offset, - read_offset=self.read_offset, - skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=self. - spaces_between_special_tokens, - ) - - self.tokens.extend(new_tokens) - self.prefix_offset = prefix_offset - self.read_offset = read_offset - - decoded_text += new_decoded_token_text - - self.output_text += decoded_text + self.output_text += self.decode_next(new_token_id) if stop_terminated: if skipped_stop_token_id is not None: - # Cleanup after skipping detokenization + # Cleanup after skipping detokenization. self.token_ids.append(skipped_stop_token_id) - # Stop token triggered; skip stop string check + # Stop token triggered; skip stop string check. return None # 2) Evaluate stop strings. 
@@ -151,7 +117,7 @@ class IncrementalDetokenizer: if self.stop: stop = StopChecker.check_stop_strings( output_text=self.output_text, - new_char_count=len(decoded_text), + new_char_count=len(self.output_text) - offset_before, stop=self.stop, include_in_output=self.include_stop_str_in_output, ) @@ -162,6 +128,10 @@ class IncrementalDetokenizer: return stop_string + @abstractmethod + def decode_next(self, next_token_id: int) -> str: + raise NotImplementedError + def get_next_output_text(self, finished: bool, delta: bool) -> str: """If delta is True, only new text since the last call to this method is returned""" @@ -177,3 +147,114 @@ class IncrementalDetokenizer: self._last_output_text_offset = length return self.output_text[last_offset:length] return "" + + +class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): + + def __init__(self, tokenizer: PreTrainedTokenizerFast, + request: EngineCoreRequest): + super().__init__(request) + + sampling_params = request.sampling_params + self.stream = DecodeStream( + skip_special_tokens=sampling_params.skip_special_tokens) + + self.tokenizer: Tokenizer = tokenizer._tokenizer + + # Find a safe place to start. + prompt_suffix = request.prompt_token_ids + prompt_len = len(prompt_suffix) + if prompt_len > 4: + for i in range(4, min(prompt_len + 1, 24)): + suffix = request.prompt_token_ids[-i:] + if '�' not in self.tokenizer.decode(suffix): + prompt_suffix = suffix + break + + # Prime the stream. + for tid in prompt_suffix: + self.stream.step(self.tokenizer, tid) + + self.spaces_between_special_tokens = ( + sampling_params.skip_special_tokens + or sampling_params.spaces_between_special_tokens) + + if not self.spaces_between_special_tokens: + # Store dict of added token ids so that we can suppress + # the spaces between them. 
+ if (added_token_ids := getattr(self.tokenizer, "added_token_ids", + None)) is None: + self.tokenizer.added_token_ids = added_token_ids = { + tid: tok.content + for tid, tok in + self.tokenizer.get_added_tokens_decoder().items() + } + + if added_token_ids: + self.last_special = False + self.added_token_ids = added_token_ids + else: + # No added tokens. + self.spaces_between_special_tokens = True + + def decode_next(self, next_token_id: int) -> str: + token = self.stream.step(self.tokenizer, next_token_id) + + if not self.spaces_between_special_tokens: + special_token = self.added_token_ids.get(next_token_id) + is_special = special_token is not None + if is_special and self.last_special: + # Return raw token string without any prefixed spaces. + token = special_token + self.last_special = is_special + + return token or "" + + +class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): + + def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest): + super().__init__(request) + + self.tokenizer = tokenizer + + # Metadata for incremental detokenization. + self.tokens, self.prefix_offset, self.read_offset = ( + convert_prompt_ids_to_tokens( + tokenizer=tokenizer, + prompt_ids=request.prompt_token_ids, + skip_special_tokens=request.sampling_params. 
+ skip_special_tokens, + )) + + self.token_ids.extend(request.prompt_token_ids) + self.prompt_len = len(request.prompt_token_ids) + + params = request.sampling_params + self.skip_special_tokens = params.skip_special_tokens + self.spaces_between_special_tokens = ( + params.spaces_between_special_tokens) + + @property + def output_token_ids(self) -> list[int]: + return self.token_ids if not self.prompt_len else ( + self.token_ids[self.prompt_len:]) + + def decode_next(self, next_token_id: int) -> str: + new_tokens, decoded_text, prefix_offset, read_offset = ( + detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=self.token_ids, + prev_tokens=self.tokens, + prefix_offset=self.prefix_offset, + read_offset=self.read_offset, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self. + spaces_between_special_tokens, + )) + + self.tokens.extend(new_tokens) + self.prefix_offset = prefix_offset + self.read_offset = read_offset + + return decoded_text diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..97dd31d5e5218f7353d59cc1d17f7dbb2f826114 --- /dev/null +++ b/vllm/v1/engine/exceptions.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +class EngineGenerateError(Exception): + """Raised when a AsyncLLM.generate() fails. Recoverable.""" + pass + + +class EngineDeadError(Exception): + """Raised when the EngineCore dies. Unrecoverable.""" + + def __init__(self, *args, suppress_context: bool = False, **kwargs): + ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause." # noqa: E501 + + super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs) + # Make stack trace clearer when using with LLMEngine by + # silencing irrelevant ZMQError. 
+ self.__suppress_context__ = suppress_context diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 4c67186f70401c2f1ebcba3dc1cede2e17563e35..85da58451c78773505f9e6ff667ae0d3d0523664 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -10,7 +10,6 @@ import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -20,7 +19,7 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) + TokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.utils import Device from vllm.v1.engine.core_client import EngineCoreClient @@ -28,10 +27,10 @@ from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.loggers import StatLoggerFactory logger = init_logger(__name__) -_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _R = TypeVar("_R", default=Any) @@ -44,7 +43,7 @@ class LLMEngine: executor_class: type[Executor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, multiprocess_mode: bool = False, @@ -56,6 +55,11 @@ class LLMEngine: 
"LLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") + if stat_loggers is not None: + raise NotImplementedError( + "Passing StatLoggers to LLMEngine in V1 is not yet supported. " + "Set VLLM_USE_V1=0 and file and issue on Github.") + self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -73,9 +77,7 @@ class LLMEngine: self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) - self.tokenizer.ping() # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config=vllm_config, @@ -104,14 +106,9 @@ class LLMEngine: cls, vllm_config: VllmConfig, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_stats: bool = False, ) -> "LLMEngine": - if stat_loggers is not None: - raise NotImplementedError( - "Passing StatLoggers to V1 is not yet supported. " - "Set VLLM_USE_V1=0 and file and issue on Github.") - return cls(vllm_config=vllm_config, executor_class=Executor.get_class(vllm_config), log_stats=(not disable_log_stats), @@ -124,7 +121,7 @@ class LLMEngine: cls, engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" @@ -183,17 +180,15 @@ class LLMEngine: priority: int = 0, ) -> None: # Process raw inputs into the request. 
- request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + prompt_str, request = self.processor.process_inputs( + request_id, prompt, params, arrival_time, lora_request, + trace_headers, prompt_adapter_request, priority) n = params.n if isinstance(params, SamplingParams) else 1 if n == 1: # Make a new RequestState and queue. - self.output_processor.add_request(request, None, 0) + self.output_processor.add_request(request, prompt_str, None, 0) # Add the request to EngineCore. self.engine_core.add_request(request) return @@ -207,7 +202,8 @@ class LLMEngine: child_request.sampling_params = params # Make a new RequestState and queue. - self.output_processor.add_request(child_request, parent_req, idx) + self.output_processor.add_request(child_request, prompt_str, + parent_req, idx) # Add the request to EngineCore. self.engine_core.add_request(child_request) @@ -230,6 +226,9 @@ class LLMEngine: return processed_outputs.request_outputs + def get_vllm_config(self): + return self.vllm_config + def get_model_config(self): return self.model_config @@ -251,21 +250,12 @@ class LLMEngine: def is_sleeping(self) -> bool: return self.engine_core.is_sleeping() - def get_tokenizer_group( - self, - group_type: type[_G] = BaseTokenizerGroup, - ) -> _G: - tokenizer_group = self.tokenizer - - if tokenizer_group is None: + def get_tokenizer_group(self) -> TokenizerGroup: + if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " "skip_tokenizer_init is True") - if not isinstance(tokenizer_group, group_type): - raise TypeError("Invalid type of tokenizer group. 
" - f"Expected type: {group_type}, but " - f"found type: {type(tokenizer_group)}") - return tokenizer_group + return self.tokenizer def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index ef5a2e5acb152b4c0251b858d04186227d8447d8..c765c1bbffcf31d02fe680be43bb54a613ff5a4d 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -50,7 +50,7 @@ class MirroredProcessingCache: full_mm_inputs = list[Optional[MultiModalKwargs]]() for mm_input, mm_hash in zip(mm_inputs, mm_hashes): - if mm_hash in self.mm_cache: + if self.mm_cache.get(mm_hash) is not None: mm_input = None else: self.mm_cache[mm_hash] = mm_input diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 70f072d3c9399f83c29efca11553aaaef6b1fe40..f76c44cb8bca7264e7d846db083675bb3d11a124 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,7 +8,7 @@ from typing import Optional, Union from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor @@ -28,32 +28,37 @@ class RequestOutputCollector: def __init__(self, output_kind: RequestOutputKind): self.aggregate = output_kind == RequestOutputKind.DELTA - self.output: Optional[RequestOutput] = None + self.output: Optional[Union[RequestOutput, Exception]] = None self.ready = asyncio.Event() - def put(self, output: RequestOutput) -> None: - if self.output is None: + def put(self, 
output: Union[RequestOutput, Exception]) -> None: + """Non-blocking put operation.""" + if self.output is None or isinstance(output, Exception): self.output = output self.ready.set() - elif self.aggregate: - # Coalesce the outputs in delta case. - self.output.add(output) - else: - # Just replace latest in non-delta case. - self.output = output + elif isinstance(self.output, RequestOutput): + # This ensures that request outputs with different request indexes + # (if n > 1) do not override each other. + self.output.add(output, aggregate=self.aggregate) async def get(self) -> RequestOutput: + """Get operation blocks on put event.""" while (output := self.output) is None: await self.ready.wait() self.output = None self.ready.clear() + if isinstance(output, Exception): + raise output return output def get_nowait(self) -> Optional[RequestOutput]: + """Non-blocking get operation.""" output = self.output if output is not None: self.output = None self.ready.clear() + if isinstance(output, Exception): + raise output return output @@ -104,6 +109,7 @@ class RequestState: cls, tokenizer: AnyTokenizer, request: EngineCoreRequest, + prompt: Optional[str], parent_req: Optional[ParentRequest], request_index: int, queue: Optional[RequestOutputCollector], @@ -118,7 +124,7 @@ class RequestState: lora_name=(request.lora_request.name if request.lora_request is not None else None), output_kind=request.sampling_params.output_kind, - prompt=request.prompt, + prompt=prompt, prompt_token_ids=request.prompt_token_ids, logprobs_processor=LogprobsProcessor.from_new_request( tokenizer=tokenizer, @@ -220,7 +226,7 @@ class OutputProcessor: def __init__( self, - tokenizer: BaseTokenizerGroup, + tokenizer: TokenizerGroup, log_stats: bool, ): self.log_stats = log_stats @@ -235,6 +241,13 @@ class OutputProcessor: def has_unfinished_requests(self) -> bool: return len(self.request_states) > 0 + def propagate_error(self, e: Exception): + """Propagate error to all generate() tasks.""" + + for _, state in 
self.request_states.items(): + assert state.queue is not None + state.queue.put(e) + def abort_requests( self, request_ids: Iterable[str], @@ -255,6 +268,7 @@ class OutputProcessor: def add_request( self, request: EngineCoreRequest, + prompt: Optional[str], parent_req: Optional[ParentRequest] = None, request_index: int = 0, queue: Optional[RequestOutputCollector] = None, @@ -266,6 +280,7 @@ class OutputProcessor: req_state = RequestState.from_new_request( tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), request=request, + prompt=prompt, parent_req=parent_req, request_index=request_index, queue=queue, diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 6d3290f16565347ed7667a3df2893fe60517af7d..fa334302e781d59c159120cd9b4324142992bdd4 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -17,13 +17,13 @@ from vllm.multimodal.utils import merge_and_sort_multimodal_metadata from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) -from vllm.v1.structured_output.utils import ( - validate_structured_output_request_xgrammar) +from vllm.v1.structured_output.backend_xgrammar import ( + validate_xgrammar_grammar) class Processor: @@ -31,7 +31,7 @@ class Processor: def __init__( self, vllm_config: VllmConfig, - tokenizer: BaseTokenizerGroup, + tokenizer: TokenizerGroup, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): @@ -145,48 +145,52 @@ class Processor: if not params.guided_decoding or not self.decoding_config: return - supported_backends = [ - "xgrammar", 
"xgrammar:disable-any-whitespace", "guidance", - "guidance:disable-any-whitespace", "auto" - ] engine_level_backend = self.decoding_config.guided_decoding_backend - if engine_level_backend not in supported_backends: - raise ValueError(f"Only {supported_backends} structured output is " - "supported in V1.") if params.guided_decoding.backend: - if params.guided_decoding.backend != engine_level_backend: - raise ValueError("Request-level structured output backend " - "must match engine-level backend. " - f"{params.guided_decoding.backend}" - f" != {engine_level_backend}") + # Request-level backend selection is not supported in V1. + # The values may differ if `params` is reused and was set + # to a specific backend based on `auto` behavior in a previous + # request. We remember that it was set as a result of `auto` + # using the `_auto` option set on the backend in the params. + if (params.guided_decoding.backend != engine_level_backend + and not (engine_level_backend == "auto" and "_auto" + in params.guided_decoding.backend_options())): + raise ValueError( + "Request-level structured output backend selection is no " + "longer supported. The request specified " + f"'{params.guided_decoding.backend}', but vLLM was " + f"initialised with '{engine_level_backend}'. 
This error " + "can be resolved by removing backend selection from the " + "request.") else: params.guided_decoding.backend = engine_level_backend # Request content validation if engine_level_backend.startswith("xgrammar"): # xgrammar with no fallback - validate_structured_output_request_xgrammar(params) - params.guided_decoding.backend = engine_level_backend - elif engine_level_backend == "auto": + validate_xgrammar_grammar(params) + elif engine_level_backend.startswith("guidance"): + # TODO: ideally we would have the LLTokenizer here as Lark syntax + # allows <|special_token|> and similar, see + # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens + # Without tokenizer these are disallowed in grammars. + validate_guidance_grammar(params, tokenizer=None) + else: + # NOTE: engine_level_backend must be "auto" here, because we have + # checked supported_backends above. # "auto" is an opt-in to opinionated behavior where we try to # choose a backend based on request contents. This is not the # default as it is less predictable and subject to change # between releases as feature support changes. try: - validate_structured_output_request_xgrammar(params) + validate_xgrammar_grammar(params) params.guided_decoding.backend = "xgrammar" except ValueError: # The request includes some jsonschema feature(s) that # are not supported in xgrammar. Fall back to guidance. params.guided_decoding.backend = "guidance" - - if engine_level_backend.startswith("guidance"): - # TODO ideally we would have the LLTokenizer here as Lark syntax - # allows <|special_token|> and similar, see - # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens - # Without tokenizer these are disallowed in grammars. 
- validate_guidance_grammar(params, tokenizer=None) - params.guided_decoding.backend = engine_level_backend + # Remember that this backend was set automatically + params.guided_decoding.add_option("_auto") def process_inputs( self, @@ -198,16 +202,10 @@ class Processor: trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> EngineCoreRequest: + ) -> tuple[Optional[str], EngineCoreRequest]: # TODO(woosuk): Support pooling models. # TODO(woosuk): Support encoder-decoder models. - - from vllm.platforms import current_platform - current_platform.validate_request( - prompt=prompt, - params=params, - ) self._validate_lora(lora_request) self._validate_params(params) if priority != 0: @@ -231,6 +229,12 @@ class Processor: prompt_adapter_request=prompt_adapter_request, return_mm_hashes=self.use_hash, ) + from vllm.platforms import current_platform + current_platform.validate_request( + prompt=prompt, + params=params, + processed_inputs=processed_inputs, + ) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) self._validate_model_inputs(processed_inputs, lora_request) @@ -302,9 +306,8 @@ class Processor: else: sorted_mm_inputs = orig_sorted_mm_inputs - return EngineCoreRequest( + return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, - prompt=decoder_inputs.get("prompt"), prompt_token_ids=decoder_inputs["prompt_token_ids"], mm_inputs=sorted_mm_inputs, mm_hashes=sorted_mm_hashes, @@ -351,7 +354,7 @@ class Processor: raise ValueError(f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len - if len(prompt_ids) >= max_prompt_len: + if len(prompt_ids) > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: mm_registry = self.input_preprocessor.mm_registry mm_processor = mm_registry.create_processor( diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 
e3a4cd98c1f81df7bc82b171ab47d53034105a30..3b9feb0d32980c2946595c48b8d3ecd4ee520852 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from concurrent.futures import Future -from typing import Union +from typing import Callable, Union import torch import torch.distributed as dist @@ -15,6 +15,8 @@ from vllm.executor.uniproc_executor import ( # noqa from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput +FailureCallback = Callable[[], None] + class Executor(ExecutorBase): """ @@ -62,6 +64,13 @@ class Executor(ExecutorBase): args=(kv_cache_configs, )) self.collective_rpc("compile_or_warm_up_model") + def register_failure_callback(self, callback: FailureCallback): + """ + Register a function to be called if the executor enters a permanent + failed state. + """ + pass + def determine_available_memory(self) -> list[int]: # in bytes output = self.collective_rpc("determine_available_memory") return output diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e854c2a44ff94d5ae782aea8e3eaaec577ff6040..cb125bf4bf1739170f6d0550d476b62f19bbfec0 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -1,21 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 - +import multiprocessing import os import pickle import signal import sys +import threading import time import traceback import weakref +from concurrent.futures import Future from dataclasses import dataclass from enum import Enum, auto from functools import partial +from multiprocessing.connection import Connection from multiprocessing.process import BaseProcess -from typing import Any, Callable, Optional, Union +from threading import Thread +from typing import Any, Callable, Optional, Union, cast import cloudpickle -import psutil -import zmq from vllm.config import VllmConfig from vllm.distributed import 
(destroy_distributed_environment, @@ -26,8 +28,9 @@ from vllm.executor.multiproc_worker_utils import ( _add_prefix, set_multiprocessing_worker_envs) from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_mp_context, - get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx) -from vllm.v1.executor.abstract import Executor + get_open_port) +from vllm.v1.executor.abstract import Executor, FailureCallback +from vllm.v1.outputs import ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -35,6 +38,8 @@ logger = init_logger(__name__) POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 +EXECUTE_MODEL_TIMEOUT_S = 40 + class MultiprocExecutor(Executor): @@ -42,19 +47,9 @@ class MultiprocExecutor(Executor): # Call self.shutdown at exit to clean up # and ensure workers will be terminated. self._finalizer = weakref.finalize(self, self.shutdown) - - # The child processes will send SIGUSR1 when unrecoverable - # errors happen. - def sigusr1_handler(signum, frame): - logger.fatal( - "MulitprocExecutor got fatal signal from worker processes, " - "shutting down. See stack trace above for root cause issue.") - # Propagate error up to parent process. 
- parent_process = psutil.Process().parent() - parent_process.send_signal(signal.SIGUSR1) - self.shutdown() - - signal.signal(signal.SIGUSR1, sigusr1_handler) + self.is_failed = False + self.shutdown_event = threading.Event() + self.failure_callback: Optional[FailureCallback] = None self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size @@ -78,26 +73,92 @@ class MultiprocExecutor(Executor): scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers - self.workers: list[WorkerProcHandle] = [] - for rank in range(self.world_size): - worker = WorkerProc.make_worker_process(self.vllm_config, rank, - rank, - distributed_init_method, - scheduler_output_handle) - self.workers.append(worker) - - # Ensure message queues are ready. Will deadlock if re-ordered - # Must be kept consistent with the WorkerProc - self.rpc_broadcast_mq.wait_until_ready() - for w in self.workers: - w.worker_response_mq.wait_until_ready() + unready_workers: list[UnreadyWorkerProcHandle] = [] + success = False + try: + for rank in range(self.world_size): + unready_workers.append( + WorkerProc.make_worker_process( + vllm_config=self.vllm_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + input_shm_handle=scheduler_output_handle, + )) + + # Workers must be created before wait_for_ready to avoid + # deadlock, since worker.init_device() does a device sync. + self.workers = WorkerProc.wait_for_ready(unready_workers) + + # Ensure message queues are ready. Will deadlock if re-ordered + # Must be kept consistent with the WorkerProc. + self.rpc_broadcast_mq.wait_until_ready() + for w in self.workers: + w.worker_response_mq.wait_until_ready() + + self.start_worker_monitor() + success = True + finally: + if not success: + # Clean up the worker procs if there was a failure. 
+ self._ensure_worker_termination( + [w.proc for w in unready_workers]) + + def start_worker_monitor(self): + workers = self.workers + self_ref = weakref.ref(self) + + # Monitors worker process liveness. If any die unexpectedly, + # logs an error, shuts down the executor and invokes the failure + # callback to inform the engine. + def monitor_workers(): + sentinels = [h.proc.sentinel for h in workers] + died = multiprocessing.connection.wait(sentinels) + _self = self_ref() + if not _self or getattr(_self, 'shutting_down', False): + return + _self.is_failed = True + proc_name = next(h.proc.name for h in workers + if h.proc.sentinel == died[0]) + logger.error( + "Worker proc %s died unexpectedly, " + "shutting down executor.", proc_name) + _self.shutdown() + callback = _self.failure_callback + if callback is not None: + _self.failure_callback = None + callback() + + Thread(target=monitor_workers, + daemon=True, + name="MultiprocWorkerMonitor").start() + + def register_failure_callback(self, callback: FailureCallback): + if self.is_failed: + callback() + else: + self.failure_callback = callback + + def execute_model( + self, + scheduler_output, + ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: + (output, ) = self.collective_rpc("execute_model", + args=(scheduler_output, ), + rank0_reply_only=True, + timeout=EXECUTE_MODEL_TIMEOUT_S) + return output def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, args: tuple = (), - kwargs: Optional[dict] = None) -> list[Any]: - start_time = time.monotonic() + kwargs: Optional[dict] = None, + rank0_reply_only: bool = False) -> list[Any]: + if self.is_failed: + raise RuntimeError("Executor failed.") + + deadline = None if timeout is None else time.monotonic() + timeout kwargs = kwargs or {} # NOTE: If the args are heterogeneous, then we pack them into a list, @@ -109,30 +170,30 @@ class MultiprocExecutor(Executor): else: send_method = cloudpickle.dumps( method, 
protocol=pickle.HIGHEST_PROTOCOL) - self.rpc_broadcast_mq.enqueue((send_method, args, kwargs)) - - responses = [None] * self.world_size - for w in self.workers: - dequeue_timeout = timeout - (time.monotonic() - start_time - ) if timeout is not None else None + self.rpc_broadcast_mq.enqueue( + (send_method, args, kwargs, rank0_reply_only)) + + workers = (self.workers[0], ) if rank0_reply_only else self.workers + responses = [None] * len(workers) + for w in workers: + dequeue_timeout = None if deadline is None else ( + deadline - time.monotonic()) status, result = w.worker_response_mq.dequeue( - timeout=dequeue_timeout) + timeout=dequeue_timeout, cancel=self.shutdown_event) if status != WorkerProc.ResponseStatus.SUCCESS: raise RuntimeError( - "Worker failed with error %s, please check the" - " stack trace above for the root cause", result) + f"Worker failed with error '{result}', please check the" + " stack trace above for the root cause") responses[w.rank] = result return responses except TimeoutError as e: raise TimeoutError(f"RPC call to {method} timed out.") from e - except Exception as e: - # Re-raise any other exceptions - raise e - def _ensure_worker_termination(self): + @staticmethod + def _ensure_worker_termination(worker_procs: list[BaseProcess]): """Ensure that all worker processes are terminated. Assumes workers have received termination requests. 
Waits for processing, then sends termination and kill signals if needed.""" @@ -150,7 +211,7 @@ class MultiprocExecutor(Executor): return False # Send SIGTERM if still running - active_procs = [w.proc for w in self.workers if w.proc.is_alive()] + active_procs = [proc for proc in worker_procs if proc.is_alive()] for p in active_procs: p.terminate() if not wait_for_termination(active_procs, 4): @@ -159,22 +220,14 @@ class MultiprocExecutor(Executor): for p in active_procs: p.kill() - self._cleanup_sockets() - - def _cleanup_sockets(self): - for w in self.workers: - # Remove the zmq ipc socket file - socket_path = w.ready_path.replace("ipc://", "") - if os and os.path.exists(socket_path): - os.remove(socket_path) - def shutdown(self): """Properly shut down the executor and its workers""" if not getattr(self, 'shutting_down', False): self.shutting_down = True + self.shutdown_event.set() for w in self.workers: w.worker_response_mq = None - self._ensure_worker_termination() + self._ensure_worker_termination([w.proc for w in self.workers]) self.rpc_broadcast_mq = None @@ -183,13 +236,30 @@ class MultiprocExecutor(Executor): return +@dataclass +class UnreadyWorkerProcHandle: + """WorkerProcess handle before READY.""" + proc: BaseProcess + rank: int + ready_pipe: Connection + + @dataclass class WorkerProcHandle: proc: BaseProcess rank: int - ready_path: str worker_response_mq: MessageQueue # The worker process writes to this MQ + @classmethod + def from_unready_handle( + cls, unready_handle: UnreadyWorkerProcHandle, + worker_response_mq: MessageQueue) -> "WorkerProcHandle": + return cls( + proc=unready_handle.proc, + rank=unready_handle.rank, + worker_response_mq=worker_response_mq, + ) + class WorkerProc: """Wrapper that runs one Worker in a separate process.""" @@ -203,7 +273,6 @@ class WorkerProc: rank: int, distributed_init_method: str, input_shm_handle: Handle, - ready_path: str, ): self.rank = rank wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) @@ 
-231,18 +300,8 @@ class WorkerProc: # Initializes a message queue for sending the model output self.worker_response_mq = MessageQueue(1, 1) - worker_response_mq_handle = self.worker_response_mq.export_handle() - - # Send Readiness signal to EngineCore process. - # Set linger here because we want to ensure the message has - # been sent before the context is closed. - with zmq_socket_ctx(ready_path, zmq.constants.PUSH, - linger=10000) as ready_socket: - payload = pickle.dumps(worker_response_mq_handle, - protocol=pickle.HIGHEST_PROTOCOL) - ready_socket.send_string(WorkerProc.READY_STR) - ready_socket.send(payload) + # Initialize device and loads weights self.worker.init_device() self.worker.load_model() @@ -253,12 +312,10 @@ class WorkerProc: rank: int, distributed_init_method: str, input_shm_handle, # Receive SchedulerOutput - ) -> WorkerProcHandle: + ) -> UnreadyWorkerProcHandle: context = get_mp_context() - - # ZMQ path for worker to send ready message and shm_broadcast handle - # back to core process. - ready_path = get_open_zmq_ipc_path() + # (reader, writer) + reader, writer = context.Pipe(duplex=False) process_kwargs = { "vllm_config": vllm_config, @@ -266,24 +323,57 @@ class WorkerProc: "rank": rank, "distributed_init_method": distributed_init_method, "input_shm_handle": input_shm_handle, - "ready_path": ready_path, + "ready_pipe": (reader, writer), } # Run EngineCore busy loop in background process. 
proc = context.Process(target=WorkerProc.worker_main, kwargs=process_kwargs, + name=f"VllmWorker-{rank}", daemon=True) - with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket: - proc.start() - - # Wait for startup - worker_response_mq_handle = WorkerProc.wait_for_startup( - proc, ready_socket) - - worker_response_mq = MessageQueue.create_from_handle( - worker_response_mq_handle, 0) + proc.start() + writer.close() + return UnreadyWorkerProcHandle(proc, rank, reader) - return WorkerProcHandle(proc, rank, ready_path, worker_response_mq) + @staticmethod + def wait_for_ready( + unready_proc_handles: list[UnreadyWorkerProcHandle] + ) -> list[WorkerProcHandle]: + + e = Exception("WorkerProc initialization failed due to " + "an exception in a background process. " + "See stack trace for root cause.") + + pipes = {handle.ready_pipe: handle for handle in unready_proc_handles} + ready_proc_handles: list[Optional[WorkerProcHandle]] = ( + [None] * len(unready_proc_handles)) + while pipes: + ready = multiprocessing.connection.wait(pipes.keys()) + for pipe in ready: + assert isinstance(pipe, Connection) + try: + # Wait until the WorkerProc is ready. + unready_proc_handle = pipes.pop(pipe) + response: dict[str, Any] = pipe.recv() + if response["status"] != "READY": + raise e + + # Extract the message queue handle. + worker_response_mq = MessageQueue.create_from_handle( + response["handle"], 0) + ready_proc_handles[unready_proc_handle.rank] = ( + WorkerProcHandle.from_unready_handle( + unready_proc_handle, worker_response_mq)) + + except EOFError: + e.__suppress_context__ = True + raise e from None + + finally: + # Close connection. 
+ pipe.close() + + return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): self.rpc_broadcast_mq = None @@ -312,51 +402,51 @@ class WorkerProc: signal.signal(signal.SIGINT, signal_handler) worker = None + # tuple[Connection, Connection] + reader, ready_writer = kwargs.pop("ready_pipe") try: + reader.close() worker = WorkerProc(*args, **kwargs) + # Send READY once we know everything is loaded + ready_writer.send({ + "status": + WorkerProc.READY_STR, + "handle": + worker.worker_response_mq.export_handle(), + }) + # Ensure message queues are ready. Will deadlock if re-ordered. # Must be kept consistent with the Executor worker.rpc_broadcast_mq.wait_until_ready() worker.worker_response_mq.wait_until_ready() + ready_writer.close() + ready_writer = None worker.worker_busy_loop() - except SystemExit: - logger.debug("Worker interrupted.") - except Exception: - # worker_busy_loop sends exceptions to Executor - # for shutdown, but if there is an error in startup or an - # error with IPC itself, we need to alert the parent. - psutil.Process().parent().send_signal(signal.SIGUSR1) - raise + # NOTE: if an Exception arises in busy_loop, we send + # a FAILURE message over the MQ RPC to notify the Executor, + # which triggers system shutdown. + # TODO(rob): handle case where the MQ itself breaks. + + if ready_writer is not None: + logger.exception("WorkerProc failed to start.") + else: + logger.exception("WorkerProc failed.") + + # The parent sends a SIGTERM to all worker processes if + # any worker dies. Set this value so we don't re-throw + # SystemExit() to avoid zmq exceptions in __del__. 
+ shutdown_requested = True finally: + if ready_writer is not None: + ready_writer.close() # Clean up once worker exits busy loop if worker is not None: worker.shutdown() - worker = None - - @staticmethod - def wait_for_startup( - proc: BaseProcess, - ready_socket: zmq.Socket, - ) -> Optional[Handle]: - """Wait until the Worker is ready.""" - - # Wait for Worker to send READY. - while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - logger.debug("Waiting for WorkerProc to startup.") - - if not proc.is_alive(): - raise RuntimeError("WorkerProc failed to start.") - - message = ready_socket.recv_string() - assert message == WorkerProc.READY_STR - handle_frame = ready_socket.recv(copy=False) - handle = pickle.loads(handle_frame.buffer) - return handle class ResponseStatus(Enum): SUCCESS = auto() @@ -365,7 +455,7 @@ class WorkerProc: def worker_busy_loop(self): """Main busy loop for Multiprocessing Workers""" while True: - method, args, kwargs = self.rpc_broadcast_mq.dequeue() + method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue() try: if isinstance(method, str): @@ -377,12 +467,14 @@ class WorkerProc: # Notes have been introduced in python 3.11 if hasattr(e, "add_note"): e.add_note(traceback.format_exc()) - logger.exception("WorkerProc hit an exception: %s", exc_info=e) + logger.exception("WorkerProc hit an exception.") # exception might not be serializable, so we convert it to # string, only for logging purpose. 
- self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.FAILURE, str(e))) + if not rank0_only or self.rank == 0: + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.FAILURE, str(e))) continue - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.SUCCESS, output)) + if not rank0_only or self.rank == 0: + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.SUCCESS, output)) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 3959be40b7253cd6013100742f8de3f0efc94b78..7051c681b1a01d1a997cac449b131eada700c940 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +import logging import time from abc import ABC, abstractmethod -from typing import Optional +from typing import Callable, Optional import numpy as np import prometheus_client @@ -12,14 +13,26 @@ from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason from vllm.v1.metrics.stats import IterationStats, SchedulerStats -from vllm.v1.spec_decode.metrics import SpecDecodingMetrics +from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5.0 +StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] + class StatLoggerBase(ABC): + """Interface for logging metrics. + + API users may define custom loggers that implement this interface. + However, note that the `SchedulerStats` and `IterationStats` classes + are not considered stable interfaces and may change in future versions. + """ + + @abstractmethod + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + ... 
@abstractmethod def record(self, scheduler_stats: SchedulerStats, @@ -32,14 +45,16 @@ class StatLoggerBase(ABC): class LoggingStatLogger(StatLoggerBase): - def __init__(self, engine_index: int = 0): + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self._reset(time.monotonic()) self.last_scheduler_stats = SchedulerStats() # Prefix cache metrics. This cannot be reset. # TODO: Make the interval configurable. self.prefix_caching_metrics = PrefixCachingMetrics() - self.spec_decoding_metrics = SpecDecodingMetrics() + self.spec_decoding_logging = SpecDecodingLogging() + self.last_prompt_throughput: float = 0.0 + self.last_generation_throughput: float = 0.0 def _reset(self, now): self.last_log_time = now @@ -68,7 +83,7 @@ class LoggingStatLogger(StatLoggerBase): self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_metrics.observe( + self.spec_decoding_logging.observe( scheduler_stats.spec_decoding_stats) self.last_scheduler_stats = scheduler_stats @@ -83,8 +98,17 @@ class LoggingStatLogger(StatLoggerBase): scheduler_stats = self.last_scheduler_stats + log_fn = logger.info + if not any( + (prompt_throughput, generation_throughput, + self.last_prompt_throughput, self.last_generation_throughput)): + # Avoid log noise on an idle production system + log_fn = logger.debug + self.last_generation_throughput = generation_throughput + self.last_prompt_throughput = prompt_throughput + # Format and print output. 
- logger.info( + log_fn( "Engine %03d: " "Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, " @@ -101,7 +125,7 @@ class LoggingStatLogger(StatLoggerBase): ) if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_metrics.log() + self.spec_decoding_logging.log(log_fn=log_fn) class PrometheusStatLogger(StatLoggerBase): @@ -122,6 +146,9 @@ class PrometheusStatLogger(StatLoggerBase): max_model_len = vllm_config.model_config.max_model_len + self.spec_decoding_prom = SpecDecodingProm( + vllm_config.speculative_config, labelnames, labelvalues) + # # Scheduler state # @@ -205,7 +232,10 @@ class PrometheusStatLogger(StatLoggerBase): prometheus_client.Histogram( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", - buckets=build_cudagraph_buckets(vllm_config), + buckets=[ + 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, + 16384 + ], labelnames=labelnames).labels(*labelvalues) self.histogram_max_num_generation_tokens_request = \ @@ -312,24 +342,6 @@ class PrometheusStatLogger(StatLoggerBase): self.labelname_running_lora_adapters, ]) - # - # Speculative Decoding metrics - # The acceptance rate can be calculated using a PromQL query: - # - # rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / - # rate(vllm:spec_decode_num_draft_tokens_total[$interval]) - # - self.counter_spec_decode_num_draft_tokens = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_draft_tokens_total", - documentation="Number of draft tokens.", - labelnames=labelnames).labels(*labelvalues) - self.counter_spec_decode_num_accepted_tokens = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_accepted_tokens_total", - documentation="Number of accepted tokens.", - labelnames=labelnames).labels(*labelvalues) - # # Cache config info metric # @@ -367,10 +379,8 @@ class PrometheusStatLogger(StatLoggerBase): scheduler_stats.prefix_cache_stats.hits) if 
scheduler_stats.spec_decoding_stats is not None: - self.counter_spec_decode_num_draft_tokens.inc( - scheduler_stats.spec_decoding_stats.num_draft_tokens) - self.counter_spec_decode_num_accepted_tokens.inc( - scheduler_stats.spec_decoding_stats.num_accepted_tokens) + self.spec_decoding_prom.observe( + scheduler_stats.spec_decoding_stats) if iteration_stats is None: return @@ -460,11 +470,29 @@ def build_1_2_5_buckets(max_value: int) -> list[int]: return build_buckets([1, 2, 5], max_value) -def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]: - if not vllm_config.model_config.enforce_eager: - buckets = vllm_config.compilation_config.\ - cudagraph_capture_sizes.copy() - buckets.sort() - return buckets +def setup_default_loggers( + vllm_config: VllmConfig, + log_stats: bool, + engine_num: int, + custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, +) -> list[list[StatLoggerBase]]: + """Setup logging and prometheus metrics.""" + if not log_stats: + return [] + + factories: list[StatLoggerFactory] + if custom_stat_loggers is not None: + factories = custom_stat_loggers else: - return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] + factories = [PrometheusStatLogger] + if logger.isEnabledFor(logging.INFO): + factories.append(LoggingStatLogger) + + stat_loggers: list[list[StatLoggerBase]] = [] + for i in range(engine_num): + per_engine_stat_loggers: list[StatLoggerBase] = [] + for logger_factory in factories: + per_engine_stat_loggers.append(logger_factory(vllm_config, i)) + stat_loggers.append(per_engine_stat_loggers) + + return stat_loggers diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 6be72431dde52a0f650bb90fcc01092e83f7f02a..3b9b666f936a118e37e9b8161c223f219289c79b 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -20,7 +20,6 @@ class Request: def __init__( self, request_id: str, - prompt: Optional[str], prompt_token_ids: list[int], multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: 
Optional[list[str]], @@ -46,7 +45,6 @@ class Request: assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens - self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: list[int] = [] @@ -81,7 +79,6 @@ class Request: return cls( request_id=request.request_id, - prompt=request.prompt, prompt_token_ids=request.prompt_token_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index f69623edd6321506e2ed3f636b5f0da9de4051d6..745b81ded3f119e88ba39da7fcf4f8e6db9f858b 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -72,14 +72,7 @@ class TopKTopPSampler(nn.Module): "best performance, please install FlashInfer.") self.forward = self.forward_native elif current_platform.is_tpu(): - if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: - logger.warning( - "TPU-specific optimization for top-k & top-p sampling are " - "disabled, falling back to PyTorch-native implementation " - "which could be very slow.") - self.forward = self.forward_native - else: - self.forward = self.forward_tpu + self.forward = self.forward_tpu else: self.forward = self.forward_native @@ -146,12 +139,22 @@ def apply_top_k_top_p_tpu( chance of being chosen during final sampling, so we can consider the tie being broken then. """ + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + if k is not None: - logits = apply_top_k_only(logits, k) + top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. 
+ no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) if p is not None: - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) cumprob = torch.cumsum(probs_sort, dim=-1) top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) top_p_mask[:, -1] = False # at least one @@ -224,7 +227,7 @@ def apply_top_k_only( max_top_k = k.max() # topk.values tensor has shape [batch_size, max_top_k]. # Convert top k to 0-based index in range [0, max_top_k). - k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1) + k_index = k.sub_(1).unsqueeze(1) top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) # Handle non-topk rows. top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf")) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 3cf7fde5cd0ecc5e415c11965de919b6bc043e34..9061a64db57c961cf2d173c9d836fd963617c1f7 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -226,7 +226,7 @@ def rejection_sample( is_greedy, max_spec_len, vocab_size, - IS_NGRAM=draft_probs is None, + NO_DRAFT_PROBS=draft_probs is None, num_warps=1, ) return output_token_ids @@ -423,7 +423,7 @@ def sample_recovered_tokens( q, vocab_size, triton.next_power_of_2(vocab_size), - IS_NGRAM=draft_probs is None, + NO_DRAFT_PROBS=draft_probs is None, ) return recovered_token_ids @@ -490,7 +490,7 @@ def rejection_random_sample_kernel( is_greedy_ptr, # [batch_size] max_spec_len, vocab_size, - IS_NGRAM: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, ): req_idx = tl.program_id(0) is_greedy = tl.load(is_greedy_ptr + req_idx) @@ -509,7 +509,7 @@ def rejection_random_sample_kernel( for pos in range(num_draft_tokens): if not rejected: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - if IS_NGRAM: + if 
NO_DRAFT_PROBS: draft_prob = 1 else: draft_prob = tl.load(draft_probs_ptr + @@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel( q_ptr, # [batch_size, vocab_size] vocab_size, PADDED_VOCAB_SIZE: tl.constexpr, - IS_NGRAM: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, ): req_idx = tl.program_id(0) if req_idx == 0: @@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel( return vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) - if IS_NGRAM: + if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) @@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel( recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) - if IS_NGRAM: + if NO_DRAFT_PROBS: # Restore the original probability. tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 3950fda3e5eae33af9a88267c7845241890c39f9..d4ea8c2dee071a63ad26581ee2b5488a721ae494 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -10,8 +10,8 @@ DEFAULT_SAMPLING_PARAMS = dict( temperature=-1.0, min_p=0.0, # strictly disabled for now - # top_k=-1, - # top_p=0.0, + top_k=0, + top_p=1.0, # frequency_penalties=0.0, # presence_penalties=0.0, # repetition_penalties=0.0, @@ -26,11 +26,9 @@ class TPUSupportedSamplingMetadata: temperature: torch.Tensor = None min_p: torch.Tensor = None - # Still too slow on forward_native! top_k: torch.Tensor = None top_p: torch.Tensor = None - # Greedy sampling flag for compiling single xla graph. 
all_greedy: bool = True # unsupported, you need to return an extra tensor of static size BxV @@ -99,11 +97,12 @@ class TPUSupportedSamplingMetadata: fill_slice(input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"]) - # TODO Temporarily disabled until sampling options are enabled - # fill_slice(input_batch.top_p_cpu_tensor) - # fill_slice(input_batch.top_k_cpu_tensor) fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"]) + fill_slice(input_batch.top_k_cpu_tensor, + DEFAULT_SAMPLING_PARAMS["top_k"]) + fill_slice(input_batch.top_p_cpu_tensor, + DEFAULT_SAMPLING_PARAMS["top_p"]) # Slice persistent device tensors to a fixed pre-compiled padded shape. return cls( @@ -111,7 +110,9 @@ class TPUSupportedSamplingMetadata: to(xla_device), all_greedy=input_batch.all_greedy, # TODO enable more and avoid returning None values - top_p=None, # input_batch.top_p[:padded_num_reqs], - top_k=None, # input_batch.top_k[:padded_num_reqs], + top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to( + xla_device), + top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to( + xla_device), min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( xla_device)) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 3af6793fde74c3d194fe4e4c3aad49a7866083c6..a3ad8cb920962fc6f0ecedaab0da8e4aafeb1901 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import dataclasses import pickle from collections.abc import Sequence from inspect import isclass @@ -12,12 +13,26 @@ import torch import zmq from msgspec import msgpack +from vllm import envs +from vllm.multimodal.inputs import (BaseMultiModalField, + MultiModalBatchedField, + MultiModalFieldConfig, MultiModalFieldElem, + MultiModalFlatField, MultiModalKwargs, + MultiModalKwargsItem, + MultiModalSharedField, NestedTensors) + CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 -# TODO calibrate this size 
-MIN_NOCOPY_BUF_SIZE = 512 +# MultiModalField class serialization type map. +# These need to list all possible field types and match them +# to factory methods in `MultiModalFieldConfig`. +MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = { + MultiModalFlatField: "flat", + MultiModalSharedField: "shared", + MultiModalBatchedField: "batched", +} bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] @@ -27,14 +42,20 @@ class MsgpackEncoder: Note that unlike vanilla `msgspec` Encoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. + + By default, arrays below 256B are serialized inline Larger will get sent + via dedicated messages. Note that this is a per-tensor limit. """ - def __init__(self): + def __init__(self, size_threshold: Optional[int] = None): + if size_threshold is None: + size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) # This is used as a local stash of buffers that we can then access from # our custom `msgspec` hook, `enc_hook`. We don't have a way to # pass custom data to the hook otherwise. self.aux_buffers: Optional[list[bytestr]] = None + self.size_threshold = size_threshold def encode(self, obj: Any) -> Sequence[bytestr]: try: @@ -59,12 +80,31 @@ class MsgpackEncoder: def enc_hook(self, obj: Any) -> Any: if isinstance(obj, torch.Tensor): - return self._encode_ndarray(obj.numpy()) + return self._encode_tensor(obj) # Fall back to pickle for object or void kind ndarrays. if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): return self._encode_ndarray(obj) + if isinstance(obj, MultiModalKwargs): + mm: MultiModalKwargs = obj + if not mm.modalities: + # just return the main dict if there are no modalities. + return dict(mm) + + # ignore the main dict, it will be re-indexed. + # Encode a list of MultiModalKwargsItems as plain dicts + # + special handling for .field. + # Any tensors *not* indexed by modality will be ignored. 
+ return [[{ + "modality": elem.modality, + "key": elem.key, + "data": self._encode_nested_tensors(elem.data), + "field": self._encode_mm_field(elem.field), + } for elem in item.values()] + for itemlist in mm._items_by_modality.values() + for item in itemlist] + if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have # problems serializing methods. @@ -77,8 +117,9 @@ class MsgpackEncoder: self, obj: np.ndarray ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None + # If the array is non-contiguous, we need to copy it first arr_data = obj.data if obj.data.c_contiguous else obj.tobytes() - if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE: + if not obj.shape or obj.nbytes < self.size_threshold: # Encode small arrays and scalars inline. Using this extension type # ensures we can avoid copying when decoding. data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data) @@ -92,6 +133,44 @@ class MsgpackEncoder: # backing buffers that we've stashed in `aux_buffers`. return obj.dtype.str, obj.shape, data + def _encode_tensor( + self, obj: torch.Tensor + ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: + assert self.aux_buffers is not None + # this creates a copy of the tensor if it's not already contiguous + obj = obj.contiguous() + # view the tensor as a 1D array of bytes + arr = obj.view((obj.numel(), )).view(torch.uint8).numpy() + if obj.nbytes < self.size_threshold: + # Smaller tensors are encoded inline, just like ndarrays. + data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) + else: + # Otherwise encode index of backing buffer to avoid copy. + data = len(self.aux_buffers) + self.aux_buffers.append(arr.data) + dtype = str(obj.dtype)[6:] # remove 'torch.' 
prefix + return dtype, obj.shape, data + + def _encode_nested_tensors(self, nt: NestedTensors) -> Any: + if isinstance(nt, torch.Tensor): + return self._encode_tensor(nt) + if isinstance(nt, (int, float)): + # Although it violates NestedTensors type, MultiModalKwargs + # values are sometimes floats. + return nt + return [self._encode_nested_tensors(x) for x in nt] + + def _encode_mm_field(self, field: BaseMultiModalField): + # Figure out the factory name for the field type. + name = MMF_CLASS_TO_FACTORY.get(field.__class__) + if not name: + raise TypeError(f"Unsupported field type: {field.__class__}") + # We just need to copy all of the field values in order + # which will be then used to reconstruct the field. + field_values = (getattr(field, f.name) + for f in dataclasses.fields(field)) + return name, *field_values + class MsgpackDecoder: """Decoder with custom torch tensor and numpy array serialization. @@ -125,14 +204,64 @@ class MsgpackDecoder: if issubclass(t, np.ndarray): return self._decode_ndarray(obj) if issubclass(t, torch.Tensor): - return torch.from_numpy(self._decode_ndarray(obj)) + return self._decode_tensor(obj) + if issubclass(t, MultiModalKwargs): + if isinstance(obj, list): + return MultiModalKwargs.from_items( + self._decode_mm_items(obj)) + return MultiModalKwargs({ + k: self._decode_nested_tensors(v) + for k, v in obj.items() + }) return obj def _decode_ndarray(self, arr: Any) -> np.ndarray: dtype, shape, data = arr + # zero-copy decode. We assume the ndarray will not be kept around, + # as it now locks the whole received message buffer in memory. buffer = self.aux_buffers[data] if isinstance(data, int) else data return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) + def _decode_tensor(self, arr: Any) -> torch.Tensor: + dtype, shape, data = arr + # Copy from inline representation, to decouple the memory storage + # of the message from the original buffer. And also make Torch + # not complain about a readonly memoryview. 
+ buffer = self.aux_buffers[data] if isinstance(data, int) \ + else bytearray(data) + # Create numpy wrapper around the bytes + arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), )) + torch_dtype = getattr(torch, dtype) + assert isinstance(torch_dtype, torch.dtype) + # Convert back to proper shape & type + return torch.from_numpy(arr).view(torch_dtype).view(shape) + + def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: + decoded_items = [] + for item in obj: + elems = [] + for v in item: + v["data"] = self._decode_nested_tensors(v["data"]) + # Reconstruct the field processor using MultiModalFieldConfig + factory_meth_name, *field_args = v["field"] + factory_meth = getattr(MultiModalFieldConfig, + factory_meth_name) + v["field"] = factory_meth(None, *field_args).field + elems.append(MultiModalFieldElem(**v)) + decoded_items.append(MultiModalKwargsItem.from_elems(elems)) + return decoded_items + + def _decode_nested_tensors(self, obj: Any) -> NestedTensors: + if isinstance(obj, (int, float)): + # Although it violates NestedTensors type, MultiModalKwargs + # values are sometimes floats. 
+ return obj + if not isinstance(obj, list): + raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") + if obj and isinstance(obj[0], str): + return self._decode_tensor(obj) + return [self._decode_nested_tensors(x) for x in obj] + def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: return data diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 2322463c0713d2744a3e6a067ae9e035019a6096..1de14584d3968406a77fcb9a268d233cdf718f10 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -6,12 +6,18 @@ import triton.language as tl from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context +from vllm.logger import init_logger from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata +logger = init_logger(__name__) + +PADDING_SLOT_ID = -1 + class EagleProposer: @@ -23,6 +29,7 @@ class EagleProposer: self.vllm_config = vllm_config self.num_speculative_tokens = ( vllm_config.speculative_config.num_speculative_tokens) + self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. 
@@ -48,7 +55,7 @@ class EagleProposer: # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 @@ -84,27 +91,25 @@ class EagleProposer: ) with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + hidden_states_logits, hidden_states_fwd = self.model( input_ids=input_ids, hidden_states=target_hidden_states, positions=target_positions, ) - sample_hidden_states = hidden_states[last_token_indices] + sample_hidden_states = hidden_states_logits[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids, draft_probs = compute_probs_and_sample_next_token( - logits, sampling_metadata) + draft_token_ids = logits.argmax(dim=-1) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: - # [batch_size, 1] and [batch_size, 1, vocab_size] - return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1) + # [batch_size, 1] + return draft_token_ids.view(-1, 1) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - draft_probs_list = [draft_probs] positions = target_positions[last_token_indices] - hidden_states = sample_hidden_states + hidden_states = hidden_states_fwd[last_token_indices] attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] @@ -112,34 +117,56 @@ class EagleProposer: # Update the inputs. input_ids = draft_token_ids_list[-1] positions += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. 
Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) + + # Increment the sequence lengths. attn_metadata.max_seq_len += 1 attn_metadata.seq_lens += 1 + # Consider max model length. + attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + self.max_model_len) + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + # Compute the slot mapping. - block_numbers = positions // self.block_size + block_numbers = clamped_positions // self.block_size block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) attn_metadata.slot_mapping = (block_ids * self.block_size + - positions % self.block_size) + clamped_positions % self.block_size) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) # Run the model. 
with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + hidden_states_logits, hidden_states = self.model( input_ids=input_ids, hidden_states=hidden_states, - positions=positions, + positions=clamped_positions, ) - logits = self.model.compute_logits(hidden_states, None) - draft_token_ids, probs = compute_probs_and_sample_next_token( - logits, sampling_metadata) + logits = self.model.compute_logits(hidden_states_logits, None) + draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) - draft_probs_list.append(probs) # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - # [batch_size, num_speculative_tokens, vocab_size] - draft_probs = torch.stack(draft_probs_list, dim=1) - return draft_token_ids, draft_probs + return draft_token_ids @staticmethod def prepare_inputs( @@ -198,17 +225,34 @@ class EagleProposer: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - self.model = EagleLlamaForCausalLM( - model_config=draft_model_config, - start_layer_id=target_layer_num).to(target_device) - - self.model.load_weights( + if self.vllm_config.speculative_config.method == "eagle": + self.model = EagleLlamaForCausalLM( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) + else: + assert self.vllm_config.speculative_config.method == "eagle3" + self.model = Eagle3LlamaForCausalLM( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) + + loaded_weights = self.model.load_weights( loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, self.model)) - self.model.lm_head = target_model.lm_head - - + if self.vllm_config.speculative_config.method == "eagle3": + if "model.embed_tokens.weight" not in loaded_weights: + logger.info( + "Loading EAGLE embedding weights from the target model.") + self.model.model.embed_tokens = 
target_model.model.embed_tokens + else: + logger.info("Loading EAGLE LM head weights from the target model.") + self.model.lm_head = target_model.lm_head + + +# NOTE(woosuk): Currently, the below code is not used and we always use argmax +# to sample the draft tokens. We will use this after we find a way to manage +# the draft prob tensor. +# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details. # FIXME(woosuk): The logic here is duplicated with the main sampling code. # We should refactor this to reuse the same sampling implementation. def compute_probs_and_sample_next_token( @@ -235,7 +279,9 @@ def compute_probs_and_sample_next_token( # TODO(woosuk): Consider seeds. q = torch.empty_like(probs) q.exponential_() - next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) + # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs + # will be used later for rejection sampling. + next_token_ids = probs.div(q).argmax(dim=-1).view(-1) if not sampling_metadata.all_random: greedy_token_ids = probs.argmax(dim=-1) next_token_ids = torch.where( diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 7bb3c209d1dcb331ff830aac8208813217e591f9..33ce98284e20df1c1f904d23fd9ede3cce2abd39 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Optional import numpy as np +import prometheus_client +from vllm.config import SpeculativeConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -11,52 +14,151 @@ logger = init_logger(__name__) @dataclass class SpecDecodingStats: + """Per-step iteration decoding stats from scheduler. + + Each scheduler step, statistics on spec decoding performance are + aggregated across requests by the scheduler and returned to the + frontend in EngineCoreOutputs->SchedulerStats. 
+ """ + + num_spec_tokens: int + num_drafts: int = 0 num_draft_tokens: int = 0 num_accepted_tokens: int = 0 + num_accepted_tokens_per_pos: list[int] = field(default_factory=list) - def take(self): - copied = SpecDecodingStats(self.num_draft_tokens, - self.num_accepted_tokens) - self.reset() - return copied - - def reset(self): - self.num_draft_tokens = 0 - self.num_accepted_tokens = 0 + @classmethod + def new(cls, num_spec_tokens: int) -> "SpecDecodingStats": + return cls(num_spec_tokens=num_spec_tokens, + num_accepted_tokens_per_pos=[0] * num_spec_tokens) - def observe(self, num_draft_tokens: int, num_accepted_tokens: int): + def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): + self.num_drafts += 1 self.num_draft_tokens += num_draft_tokens self.num_accepted_tokens += num_accepted_tokens + assert num_accepted_tokens <= self.num_spec_tokens + for i in range(num_accepted_tokens): + self.num_accepted_tokens_per_pos[i] += 1 + +class SpecDecodingLogging: + """Aggregate and log spec decoding metrics. -class SpecDecodingMetrics: + LoggingStatLogger aggregates per-iteration metrics over a set + time interval using observe() and then logs them using log() + before resetting to zero. 
+ """ def __init__(self): self.reset() def reset(self): + self.num_drafts: list[int] = [] self.num_draft_tokens: list[int] = [] self.num_accepted_tokens: list[int] = [] + self.accepted_tokens_per_pos_lists: list[list[int]] = [] def observe(self, spec_decoding_stats: SpecDecodingStats): + self.num_drafts.append(spec_decoding_stats.num_drafts) self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens) self.num_accepted_tokens.append( spec_decoding_stats.num_accepted_tokens) + self.accepted_tokens_per_pos_lists.append( + spec_decoding_stats.num_accepted_tokens_per_pos) - def log(self): + def log(self, log_fn=logger.info): + num_drafts = np.sum(self.num_drafts) num_draft_tokens = np.sum(self.num_draft_tokens) num_accepted_tokens = np.sum(self.num_accepted_tokens) draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * 100 if num_draft_tokens > 0 else float("nan")) + mean_acceptance_length = (num_accepted_tokens / num_drafts) - logger.info( + pos_matrix = np.array(self.accepted_tokens_per_pos_lists) + acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts + rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates) + + log_fn( "SpecDecoding metrics: " "Draft acceptance rate: %.1f%%, " + "Mean acceptance length: %.2f, " "Accepted: %d tokens, " - "Drafted: %d tokens", + "Drafted: %d tokens, " + "Per-position acceptance rate: %s", draft_acceptance_rate, + mean_acceptance_length, num_accepted_tokens, num_draft_tokens, + rates_str, ) self.reset() + + +class SpecDecodingProm: + """Record spec decoding metrics in Prometheus. 
+ + The acceptance rate can be calculated using a PromQL query: + + rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / + rate(vllm:spec_decode_num_draft_tokens_total[$interval]) + + The mean acceptance length can be calculated using: + + rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / + rate(vllm:spec_decode_num_drafts[$interval]) + + A per-position acceptance rate vector can be computed using + + vllm:spec_decode_num_accepted_tokens_per_pos[$interval] / + vllm:spec_decode_num_drafts[$interval] + """ + + def __init__(self, speculative_config: Optional[SpeculativeConfig], + labelnames: list[str], labelvalues: list[str]): + self.spec_decoding_enabled = speculative_config is not None + if not self.spec_decoding_enabled: + return + + self.counter_spec_decode_num_drafts = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_drafts_total", + documentation="Number of spec decoding drafts.", + labelnames=labelnames).labels(*labelvalues) + self.counter_spec_decode_num_draft_tokens = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames).labels(*labelvalues) + self.counter_spec_decode_num_accepted_tokens = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames).labels(*labelvalues) + + assert speculative_config is not None + num_spec_tokens = (speculative_config.num_speculative_tokens + if self.spec_decoding_enabled else 0) + pos_labelnames = labelnames + ["position"] + base_counter = prometheus_client.Counter( + name="vllm:spec_decode_num_accepted_tokens_per_pos", + documentation="Accepted tokens per draft position.", + labelnames=pos_labelnames) + self.counter_spec_decode_num_accepted_tokens_per_pos: \ + list[prometheus_client.Counter] = [] + for pos in range(num_spec_tokens): + pos_labelvalues = labelvalues + [str(pos)] + 
self.counter_spec_decode_num_accepted_tokens_per_pos.append( + base_counter.labels(*pos_labelvalues)) + + def observe(self, spec_decoding_stats: SpecDecodingStats): + if not self.spec_decoding_enabled: + return + self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts) + self.counter_spec_decode_num_draft_tokens.inc( + spec_decoding_stats.num_draft_tokens) + self.counter_spec_decode_num_accepted_tokens.inc( + spec_decoding_stats.num_accepted_tokens) + for pos, counter in enumerate( + self.counter_spec_decode_num_accepted_tokens_per_pos): + counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 7e548bb48b57ce9607a92a184e02585ec295acbd..704153d43a2b48033d29eeefe9b82c86f472f9a7 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -18,6 +18,9 @@ class NgramProposer: # tokens follow the match, we will return the maximum amount of # tokens until the end. self.k = vllm_config.speculative_config.num_speculative_tokens + # Maximum length of the model. + self.max_model_len = vllm_config.model_config.max_model_len + # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose(np.zeros(1024, dtype=np.int32)) @@ -50,9 +53,14 @@ class NgramProposer: followed that pattern. Here we will return [4,2,3] because we only have three tokens after the match. """ + # Do not generate draft tokens beyond the max model length. + k = min(self.k, self.max_model_len - context_token_ids.shape[0]) + if k <= 0: + return None + # TODO(woosuk): Optimize this. 
for n in range(self.max_n, self.min_n - 1, -1): - result = _find_subarray_kmp(context_token_ids, n, self.k) + result = _find_subarray_kmp(context_token_ids, n, k) if result is not None: return result return None diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 218af43deb6777e146149a519a9d0dd851402bd2..0fd66c0729602bb7f8c98c02949cc59a54700ff0 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -107,3 +107,7 @@ class StructuredOutputManager: # np.ndarray, because that is much more efficient for serialization # and deserialization when sending this to the GPU workers. return bitmask_tensor.numpy() + + def clear_backend(self) -> None: + if self.backend is not None: + self.backend.destroy() diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 9150a28570bdd430c3c460080ac8737f7a389524..1453e284b0132c999ef59f484f8b92f4972d72a9 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 +import copy +import json import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -29,6 +31,29 @@ else: logger = init_logger(__name__) +def _walk_json_for_additional_properties(data: object): + if isinstance(data, dict): + for value in data.values(): + _walk_json_for_additional_properties(value) + if 'additionalProperties' not in data and \ + ('properties' 
in data or 'patternProperties' in data): + data['additionalProperties'] = False + elif isinstance(data, list): + for item in data: + _walk_json_for_additional_properties(item) + + +def process_for_additional_properties( + guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]: + if isinstance(guide_json, str): + guide_json_obj = json.loads(guide_json) + else: + # copy for modifications + guide_json_obj = copy.deepcopy(guide_json) + _walk_json_for_additional_properties(guide_json_obj) + return guide_json_obj + + class GuidanceBackend(StructuredOutputBackend): def __init__(self, vllm_config: VllmConfig): @@ -36,14 +61,23 @@ class GuidanceBackend(StructuredOutputBackend): tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) # type: ignore[arg-type] - tokenizer_group.ping() self.vllm_config = vllm_config self.vocab_size = vllm_config.model_config.get_vocab_size() - self.disable_any_whitespace = ( - "disable-any-whitespace" - in vllm_config.decoding_config.guided_decoding_backend) + + self.disable_any_whitespace = False + self.no_additional_properties = False + backend_options = GuidedDecodingParams( + backend=vllm_config.decoding_config.guided_decoding_backend + ).backend_options() + for option in backend_options: + if option == "disable-any-whitespace": + self.disable_any_whitespace = True + elif option == "no-additional-properties": + self.no_additional_properties = True + else: + raise ValueError( + f"Unsupported option for the guidance backend: {option}") tokenizer = tokenizer_group.get_lora_tokenizer(None) self.ll_tokenizer = llguidance_hf.from_tokenizer( @@ -52,7 +86,8 @@ class GuidanceBackend(StructuredOutputBackend): def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: self.serialized_grammar = serialize_guidance_grammar( - request_type, 
grammar_spec, self.disable_any_whitespace) + request_type, grammar_spec, self.disable_any_whitespace, + self.no_additional_properties) ll_matcher = llguidance.LLMatcher( self.ll_tokenizer, @@ -73,6 +108,9 @@ class GuidanceBackend(StructuredOutputBackend): return llguidance_torch.allocate_token_bitmask( max_num_seqs, self.ll_tokenizer.vocab_size) + def destroy(self): + pass + @dataclass class GuidanceGrammar(StructuredOutputGrammar): @@ -129,10 +167,15 @@ class GuidanceGrammar(StructuredOutputGrammar): self.ll_matcher.reset() -def serialize_guidance_grammar(request_type: StructuredOutputOptions, - grammar_spec: str, - disable_any_whitespace: bool = False) -> str: +def serialize_guidance_grammar( + request_type: StructuredOutputOptions, + grammar_spec: Union[str, dict[str, Any]], + disable_any_whitespace: bool = False, + no_additional_properties: bool = False, +) -> str: if request_type == StructuredOutputOptions.JSON: + if no_additional_properties: + grammar_spec = process_for_additional_properties(grammar_spec) return llguidance.LLMatcher.grammar_from_json_schema( grammar_spec, defaults={ @@ -151,6 +194,9 @@ def serialize_guidance_grammar(request_type: StructuredOutputOptions, tp = "grammar" elif request_type == StructuredOutputOptions.CHOICE: tp = "choice" + elif request_type == StructuredOutputOptions.STRUCTURAL_TAG: + raise ValueError("Structural tag is not supported " + "for guidance backend yet") else: logger.error("Validation should have already occurred. 
" "Please file an issue.") diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 6dc2a92411de0552be11b9f5941a9e158a09b174..6330bcbf20c35836febcdbac78ca7b5e960e4221 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -12,6 +12,7 @@ class StructuredOutputOptions(enum.Enum): REGEX = enum.auto() GRAMMAR = enum.auto() CHOICE = enum.auto() + STRUCTURAL_TAG = enum.auto() StructuredOutputKey = tuple[StructuredOutputOptions, str] @@ -87,3 +88,9 @@ class StructuredOutputBackend(ABC): max_num_seqs (int): The maximum number of sequences for which to allocate the bitmask. """ + + @abstractmethod + def destroy(self): + """ + Backend-specific cleanup. + """ diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 83f2c6436ed2cf25c17885ee34671ea66e8807a4..ecaeb6e4ee8064a00c334feccb0a1ac8e0d7fed1 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -1,19 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 +import json from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch import vllm.envs from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar, StructuredOutputOptions) +from vllm.v1.structured_output.utils import (choice_as_grammar, + convert_lark_to_ebnf, + grammar_is_likely_lark) if TYPE_CHECKING: import xgrammar as xgr @@ -27,15 +32,21 @@ class XgrammarBackend(StructuredOutputBackend): def __init__(self, vllm_config: 
VllmConfig): self.vllm_config = vllm_config - self.disable_any_whitespace = ( - "disable-any-whitespace" - in vllm_config.decoding_config.guided_decoding_backend) tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) # type: ignore[arg-type] - tokenizer_group.ping() + + self.disable_any_whitespace = False + backend_options = GuidedDecodingParams( + backend=vllm_config.decoding_config.guided_decoding_backend + ).backend_options() + for option in backend_options: + if option == "disable-any-whitespace": + self.disable_any_whitespace = True + else: + raise ValueError( + f"Unsupported option for the xgrammar backend: {option}") tokenizer = tokenizer_group.get_lora_tokenizer(None) self.vocab_size = vllm_config.model_config.get_vocab_size() @@ -97,6 +108,16 @@ class XgrammarBackend(StructuredOutputBackend): ctx = self.compiler.compile_grammar(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: ctx = self.compiler.compile_regex(grammar_spec) + elif request_type == StructuredOutputOptions.STRUCTURAL_TAG: + s_tag = json.loads(grammar_spec) + tags = [ + xgr.StructuralTagItem( + begin=s["begin"], + schema=json.dumps(s["schema"]), + end=s["end"], + ) for s in s_tag["structures"] + ] + ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"]) else: logger.error( "Validation should have already occurred. Please file an issue." 
@@ -113,6 +134,9 @@ class XgrammarBackend(StructuredOutputBackend): def allocate_token_bitmask(self, max_num_seqs: int): return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size) + def destroy(self): + del self.compiler + @dataclass class XgrammarGrammar(StructuredOutputGrammar): @@ -156,3 +180,120 @@ class XgrammarGrammar(StructuredOutputGrammar): def reset(self): self.num_processed_tokens = 0 self.matcher.reset() + + +def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict[str, Any]) -> bool: + if not isinstance(obj, dict): + return False + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and ("multipleOf" in obj): + return True + + # Check for array unsupported keywords + if obj.get("type") == "array" and any( + key in obj + for key in ("uniqueItems", "contains", "minContains", + "maxContains", "minItems", "maxItems")): + return True + + # Unsupported keywords for strings + if obj.get("type") == "string" and "format" in obj: + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any( + key in obj for key in ("minProperties", "maxProperties", + "propertyNames", "patternProperties")): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: + """Validate that the request is supported by structured output. + + Raises ValueError if the request is not supported. 
+ """ + if sampling_params.guided_decoding is None: + return + + gd_params = sampling_params.guided_decoding + + if gd_params.regex: + try: + xgr.Grammar.from_regex(gd_params.regex) + except Exception as err: + raise ValueError("Failed to transform regex into a grammar: " + f"{err}") from err + + if gd_params.choice: + choice_grammar = choice_as_grammar(gd_params.choice) + try: + xgr.Grammar.from_ebnf(choice_grammar) + except Exception as err: + raise ValueError("Failed to transform choices into a grammar: " + "{err}") from err + gd_params.choice = None + gd_params.grammar = choice_grammar + return + + if gd_params.json: + if isinstance(gd_params.json, str): + try: + schema = json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + schema = gd_params.json + + if has_xgrammar_unsupported_json_features(schema): + raise ValueError("The provided JSON schema contains features not " + "supported by xgrammar.") + return + + if gd_params.grammar: + if grammar_is_likely_lark(gd_params.grammar): + # xgrammar supports EBNF grammars only + try: + gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to EBNF. ") from e + + # Test parsing EBNF grammar, possibly already converted from Lark + try: + # parse the grammar, but we aren't compiling it. 
+ xgr.Grammar.from_ebnf(gd_params.grammar) + except Exception as e: + raise ValueError("Invalid grammar specification.") from e + return + + if gd_params.structural_tag: + try: + s_tag = json.loads(gd_params.structural_tag) + tags = [ + xgr.StructuralTagItem( + begin=s["begin"], + schema=json.dumps(s["schema"]), + end=s["end"], + ) for s in s_tag["structures"] + ] + xgr.Grammar.from_structural_tag(tags, s_tag["triggers"]) + except Exception as e: + raise ValueError("Invalid structural tag specification.") from e diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index 9e54b8bf028db8140e932b2c9566e172a5b7a0d9..6ef472eb896c612c73b0ce30d81d28e71eea204c 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -78,5 +78,7 @@ def get_structured_output_key( return (StructuredOutputOptions.CHOICE, json_str) elif params.grammar is not None: return (StructuredOutputOptions.GRAMMAR, params.grammar) + elif params.structural_tag is not None: + return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag) else: raise ValueError("No valid structured output parameter found") diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 56eed95944e2f6d162ee591452988b536f5c30be..f33f4972e1032297f49ca71a37259c8eb60e7e53 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -2,67 +2,7 @@ from __future__ import annotations -import json import re -from typing import TYPE_CHECKING, Any - -from vllm.sampling_params import SamplingParams -from vllm.utils import LazyLoader - -if TYPE_CHECKING: - import xgrammar as xgr -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") - - -def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: - """Check if JSON schema contains features unsupported by xgrammar.""" - - def check_object(obj: dict[str, Any]) -> bool: - if not isinstance(obj, dict): - return False - - # Check for 
pattern restrictions - if "pattern" in obj: - return True - - # Check for numeric ranges - if obj.get("type") in ("integer", "number") and any( - key in obj - for key in ("minimum", "maximum", "exclusiveMinimum", - "exclusiveMaximum", "multipleOf")): - return True - - # Check for array unsupported keywords - if obj.get("type") == "array" and any( - key in obj - for key in ("uniqueItems", "contains", "minContains", - "maxContains", "minItems", "maxItems")): - return True - - # Unsupported keywords for strings - if obj.get("type") == "string" and "format" in obj: - return True - - # Unsupported keywords for objects - if obj.get("type") == "object" and any( - key in obj for key in ("minProperties", "maxProperties", - "propertyNames", "patternProperties")): - return True - - # Recursively check all nested objects and arrays - for value in obj.values(): - if isinstance(value, dict): - if check_object(value): - return True - elif isinstance(value, list): - for item in value: - if isinstance(item, dict) and check_object(item): - return True - - return False - - return check_object(schema) def grammar_is_likely_lark(grammar_str: str) -> bool: @@ -232,63 +172,3 @@ def choice_as_grammar(choice: list[str]) -> str: escaped_choices = (escape_ebnf_string(c) for c in choice) grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) return grammar - - -def validate_structured_output_request_xgrammar( - sampling_params: SamplingParams) -> None: - """Validate that the request is supported by structured output. - - Raises ValueError if the request is not supported. 
- """ - if sampling_params.guided_decoding is None: - return - - gd_params = sampling_params.guided_decoding - - if gd_params.regex: - try: - xgr.Grammar.from_regex(gd_params.regex) - except Exception as err: - raise ValueError("Failed to transform regex into a grammar: " - f"{err}") from err - - if gd_params.choice: - choice_grammar = choice_as_grammar(gd_params.choice) - try: - xgr.Grammar.from_ebnf(choice_grammar) - except Exception as err: - raise ValueError("Failed to transform choices into a grammar: " - "{err}") from err - gd_params.choice = None - gd_params.grammar = choice_grammar - return - - if gd_params.json: - if isinstance(gd_params.json, str): - try: - schema = json.loads(gd_params.json) - except json.JSONDecodeError as e: - raise ValueError("Invalid JSON grammar specification.") from e - else: - schema = gd_params.json - - if has_xgrammar_unsupported_json_features(schema): - raise ValueError("The provided JSON schema contains features not " - "supported by xgrammar.") - return - - if gd_params.grammar: - if grammar_is_likely_lark(gd_params.grammar): - # xgrammar supports EBNF grammars only - try: - gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) - except ValueError as e: - raise ValueError( - "Failed to convert the grammar from Lark to EBNF. ") from e - - # Test parsing EBNF grammar, possibly already converted from Lark - try: - # parse the grammar, but we aren't compiling it. 
- xgr.Grammar.from_ebnf(gd_params.grammar) - except Exception as e: - raise ValueError("Invalid grammar specification.") from e diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index d00ff815d7dc410cb86e58a0d306ce0be78be9e0..9c238c3aad8e59b57dc3ba89ce6feef2342b3c80 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -12,6 +12,8 @@ import torch from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) from vllm.utils import get_mp_context, kill_process_tree if TYPE_CHECKING: @@ -134,8 +136,8 @@ def shutdown(proc: Process, input_path: str, output_path: str): proc.terminate() proc.join(5) - if proc.is_alive(): - kill_process_tree(proc.pid) + if proc.is_alive() and (pid := proc.pid) is not None: + kill_process_tree(pid) # Remove zmq ipc socket files. ipc_sockets = [output_path, input_path] @@ -200,4 +202,48 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, Returns the sliced target tensor. 
""" - return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) \ No newline at end of file + return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) + + +def report_usage_stats( + vllm_config, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None: + """Report usage statistics if enabled.""" + + if not is_usage_stats_enabled(): + return + + from vllm.model_executor.model_loader import get_architecture_class_name + + usage_message.report_usage( + get_architecture_class_name(vllm_config.model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(vllm_config.model_config.dtype), + "tensor_parallel_size": + vllm_config.parallel_config.tensor_parallel_size, + "block_size": + vllm_config.cache_config.block_size, + "gpu_memory_utilization": + vllm_config.cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + vllm_config.model_config.quantization, + "kv_cache_dtype": + str(vllm_config.cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(vllm_config.lora_config), + "enable_prompt_adapter": + bool(vllm_config.prompt_adapter_config), + "enable_prefix_caching": + vllm_config.cache_config.enable_prefix_caching, + "enforce_eager": + vllm_config.model_config.enforce_eager, + "disable_custom_all_reduce": + vllm_config.parallel_config.disable_custom_all_reduce, + }) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a64cb97e0123f860875c5149f85aa188a05f080d..c00424dfea73b46d0775d885779d510ec6c8e440 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -24,7 +24,6 @@ class CachedRequestState: req_id: str prompt_token_ids: list[int] - prompt: Optional[str] mm_inputs: list[MultiModalKwargs] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 
70e8bd75ec94e981e084804ee09494acccf710ba..e3d8b94fe9d7eccb5abfc11d271082d28d3efda9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,11 +12,13 @@ import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import (CompilationLevel, VllmConfig, + get_layers_from_vllm_config) +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.forward_context import set_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY @@ -36,6 +38,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -151,6 +154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.max_num_encoder_input_tokens = encoder_compute_budget self.encoder_cache_size = encoder_cache_size + # Sampler + self.sampler = Sampler() + # Lazy initialization # self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] @@ -159,14 +165,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Set up speculative decoding. 
self.use_spec_decode = False + self.use_aux_hidden_state_outputs = False if self.speculative_config: self.use_spec_decode = True if get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) - elif self.speculative_config.method == "eagle": + elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device) # type: ignore + if self.speculative_config.method == "eagle3": + self.use_aux_hidden_state_outputs = True else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -239,10 +248,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): device=self.device) # OPTIMIZATION: Cache the tensors rather than creating them every step. + # Keep in int64 to avoid overflow with long context self.arange_np = np.arange(max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), - dtype=np.int32) + dtype=np.int64) # NOTE(woosuk): These tensors are "stateless", i.e., they are literally # a faster version of creating a new tensor every time. Thus, we should # not make any assumptions about the values in these tensors. 
@@ -337,7 +347,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - prompt=new_req_data.prompt, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, @@ -353,6 +362,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): image_grid_thw = [] video_grid_thw = [] second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False for mm_input in self.requests[req_id].mm_inputs: if mm_input.get("image_grid_thw") is not None: image_grid_thw.extend( @@ -363,6 +374,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): if mm_input.get("second_per_grid_ts") is not None: second_per_grid_ts.extend( mm_input["second_per_grid_ts"]) + if mm_input.get("audio_feature_lengths") is not None: + audio_feature_lengths.extend( + mm_input["audio_feature_lengths"]) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True hf_config = self.model_config.hf_config @@ -374,6 +390,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) req_ids_to_add.append(req_id) @@ -443,7 +461,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. - removed_req_indices = sorted(removed_req_indices, reverse=True) + removed_req_indices.sort(reverse=True) for req_id in req_ids_to_add: req_state = self.requests[req_id] if removed_req_indices: @@ -458,7 +476,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): if removed_req_indices: self.input_batch.condense(removed_req_indices) - if batch_changed: + # Some attention backends (namely MLA) may want to separate requests + # based on if the attention computation will be compute-bound or + # memory-bound. 
This gives them a hook to do that. + batch_reordered = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) + + if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() def _prepare_inputs( @@ -471,14 +495,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - # Some attention backends (namely MLA) may want to separate requests - # based on if the attention computation will be compute-bound or - # memory-bound. This gives them a hook to do that. - modified_batch = self.attn_metadata_builder.reorder_batch( - self.input_batch, scheduler_output) - if modified_batch: - self.input_batch.refresh_sampling_metadata() - # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit(num_reqs) @@ -540,9 +556,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): # because M (max_model_len) is not necessarily divisible by block_size. block_table_indices = (req_indices * self.max_num_blocks_per_req + positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. block_table_cpu = self.input_batch.block_table.get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_offsets = positions_np % self.block_size @@ -690,7 +703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # common_prefix_len should be a multiple of the block size. 
common_prefix_len = (common_prefix_len // self.block_size * self.block_size) - use_cascade = self.attn_backend.use_cascade_attention( + use_cascade = self.attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, @@ -992,18 +1005,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + get_kv_transfer_group().bind_connector_metadata( + scheduler_output.kv_connector_metadata) + self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT - if self.is_multimodal_model: - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - # Prepare the decoder inputs. attn_metadata, logits_indices, spec_decode_metadata = ( self._prepare_inputs(scheduler_output)) @@ -1016,9 +1027,26 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens) else: # Eager mode. - num_input_tokens = num_scheduled_tokens + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.vllm_config.compilation_config.pass_config. 
\ + enable_sequence_parallelism and tp_size > 1: + from vllm.utils import round_up + num_input_tokens = round_up(num_scheduled_tokens, tp_size) + else: + num_input_tokens = num_scheduled_tokens attn_metadata.num_input_tokens = num_input_tokens + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) @@ -1061,12 +1089,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = output + else: + hidden_states = output + if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. return hidden_states @@ -1082,7 +1116,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: - sampler_output = self.model.sample( + sampler_output = self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) @@ -1092,7 +1126,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # logits tensor. This means any in-place operations on bonus_logits # won't affect the original logits tensor. 
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.model.sample( + sampler_output = self.sampler( logits=bonus_logits, sampling_metadata=sampling_metadata, ) @@ -1164,7 +1198,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) - elif self.speculative_config.method == "eagle": + elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] @@ -1192,7 +1226,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): # not include padding. target_token_ids = self.input_ids[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens] + if self.use_aux_hidden_state_outputs: + target_hidden_states = [ + h[:num_scheduled_tokens] for h in aux_hidden_states + ] + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc else: @@ -1213,10 +1252,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] - target_hidden_states = hidden_states[token_indices] + if self.use_aux_hidden_state_outputs: + target_hidden_states = [ + h[token_indices] for h in aux_hidden_states + ] + else: + target_hidden_states = hidden_states[token_indices] target_slot_mapping = attn_metadata.slot_mapping[token_indices] - draft_token_ids, draft_probs = self.drafter.propose( + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat(target_hidden_states, dim=-1) + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -1227,9 +1273,10 @@ class 
GPUModelRunner(LoRAModelRunnerMixin): sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() - # TODO(woosuk): Cache draft_probs and use it for rejection sampling - # in the next step. - del draft_probs + + # Clear KVConnector state after all KVs are generated. + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -1254,7 +1301,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): draft_token_ids.append([]) continue - # Skip requests that require top-p, top-k, etc. + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. req_id = self.input_batch.req_ids[i] if not is_spec_decode_supported(req_id, self.input_batch): draft_token_ids.append([]) @@ -1263,6 +1311,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Add sampled_token_ids to token_ids_cpu. start_idx = self.input_batch.num_tokens_no_spec[i] end_idx = start_idx + num_sampled_ids + if end_idx >= self.max_model_len: + # Skip requests that have already reached the max model length. + draft_token_ids.append([]) + continue + self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids drafter_output = self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx]) @@ -1286,6 +1339,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) + if self.use_aux_hidden_state_outputs: + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GiB and %.6f seconds", @@ -1362,8 +1418,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] # Compute prompt logprobs. 
- logprobs = self.model.sampler.compute_logprobs(logits) - token_ids, logprobs, ranks = self.model.sampler.gather_logprobs( + logprobs = self.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.sampler.gather_logprobs( logprobs, num_prompt_logprobs, tgt_token_ids) # Transfer GPU->CPU async. @@ -1438,12 +1494,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - hidden_states = model( + outputs = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] @@ -1481,8 +1541,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): bad_words_token_ids={}, ) try: - sampler_output = self.model.sample( - logits=logits, sampling_metadata=dummy_metadata) + sampler_output = self.sampler(logits=logits, + sampling_metadata=dummy_metadata) except RuntimeError as e: if 'out of memory' in str(e): raise RuntimeError( @@ -1681,17 +1741,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): format. Layers that do not need KV cache are not included. 
""" - forward_ctx = self.vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - if isinstance(attn_module, FusedMoE): - continue - - # TODO: Support other attention modules, e.g., sliding window, - # cross-attention - assert isinstance(attn_module, Attention) + for layer_name, attn_module in layers.items(): + # TODO: Support other attention modules, e.g., cross-attention if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 2972e0ffb3baae4046e834312f291fcfa278affc..68c4e94fcd73e3f18807963e2b2299429a57f2c6 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -9,11 +9,12 @@ import torch.distributed import torch.nn as nn import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -22,6 +23,7 @@ from vllm.platforms import current_platform from vllm.utils import GiB_bytes from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -53,6 +55,9 @@ class Worker(WorkerBase): from 
vllm.utils import init_cached_hf_modules init_cached_hf_modules() + # Buffers saved before sleep + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + # Torch profiler. Enabled and configured through env vars: # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: @@ -72,6 +77,15 @@ class Worker(WorkerBase): def sleep(self, level: int = 1) -> None: free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + + # Save the buffers before level 2 sleep + if level == 2: + model = self.model_runner.model + self._sleep_saved_buffers = { + name: buffer.cpu().clone() + for name, buffer in model.named_buffers() + } + allocator = CuMemAllocator.get_instance() allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) free_bytes_after_sleep, total = torch.cuda.mem_get_info() @@ -87,6 +101,14 @@ class Worker(WorkerBase): allocator = CuMemAllocator.get_instance() allocator.wake_up(tags) + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.model_runner.model + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + def init_device(self): if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until @@ -110,7 +132,7 @@ class Worker(WorkerBase): raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, + init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, self.local_rank) # Set random seed. @@ -120,6 +142,10 @@ class Worker(WorkerBase): self.model_runner: GPUModelRunner = GPUModelRunner( self.vllm_config, self.device) + if self.rank == 0: + # If usage stat is enabled, collect relevant info. 
+ report_usage_stats(self.vllm_config) + # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. def load_model(self) -> None: @@ -285,12 +311,13 @@ class Worker(WorkerBase): def init_worker_distributed_environment( - parallel_config: ParallelConfig, + vllm_config: VllmConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, @@ -299,6 +326,8 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(vllm_config) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index a8a19e0e6206c9f212df0b9a936d1f0507579239..3cbab840e9693784cad601a58d6cc17974d6bfc0 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -28,20 +28,16 @@ class LoRAModelRunnerMixin: scheduler_config: SchedulerConfig, lora_config: LoRAConfig, device: str) -> nn.Module: - assert supports_lora( - model), f"{model.__class__.__name__} does not support LoRA yet." + if not supports_lora(model): + raise ValueError( + f"{model.__class__.__name__} does not support LoRA yet.") if supports_multimodal(model): logger.warning("Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model.") - # It's necessary to distinguish between the max_position_embeddings - # of VLMs and LLMs. 
- if hasattr(model.config, "max_position_embeddings"): - max_pos_embeddings = model.config.max_position_embeddings - else: - max_pos_embeddings = ( - model.config.text_config.max_position_embeddings) + # Use get_text_config() in case of multimodal models + text_config = model_config.hf_config.get_text_config() # Add LoRA Manager to the Model Runner self.lora_manager = LRUCacheWorkerLoRAManager( @@ -52,7 +48,7 @@ class LoRAModelRunnerMixin: device, model.embedding_modules, model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, + max_position_embeddings=text_config.max_position_embeddings, ) return self.lora_manager.create_lora_manager(model) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 69251d8bbb31f8f2557e0aace39d5e57554bcf8d..67f8af29db0eb725e872b39d58aa9689f531c74a 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import bisect +import gc import time from typing import TYPE_CHECKING, Optional, cast from unittest.mock import patch @@ -16,20 +17,22 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, + PlaceholderRange) from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from 
vllm.v1.attention.backends.pallas import (PallasAttentionBackend, PallasMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec, SlidingWindowSpec) +from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheSpec, + SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata @@ -37,8 +40,7 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders) +from .utils import sanity_check_mm_encoder_outputs if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -53,6 +55,41 @@ INVALID_TOKEN_ID = -1 MIN_NUM_SEQS = 8 +######################################################### +# Ways to avoid recompilation +######################################################### +# +# The model executor has two primary components: +# 1. preparing the model and sampler inputs +# 2. executing the model and sampler. +# The core idea is to avoid any TPU computation during input preparation. For +# better compilation tracking and increased flexibility, the model execution and +# sampler are divided into several distinct components. +# +# Below are the detailed steps: +# +# Step 1 +# It is recommended to avoid TPU operations when preparing the model and sampler +# inputs. CPU tensors can be prepared and transferred to the XLA device using +# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids +# compilation. +# +# Step 2 +# The TPU execution should be decomposed into subgraphs (4 at the moment): +# 1. the main model +# 2. 
selecting hidden states for each request +# 3. sampler +# 4. encoder. +# Each subgraph should be decorated in a torch.compile. This is used to make +# sure that we have the same subgraph topology in both dummy_run and +# xecute_model. The results from these subgraphs should either be passed to +# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for +# subsequent processing on the CPU. +# +# Step 3 +# The dummy_run should be comprehensive, ensuring all potential input shapes and +# branch predictions are included as subgraph inputs to facilitate +# pre-compilation. class TPUModelRunner: def __init__( @@ -93,10 +130,16 @@ class TPUModelRunner: self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens # InputBatch needs to work with sampling tensors greater than padding # to avoid dynamic shapes. Also, avoid suboptimal alignment. self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) + self.num_tokens_paddings = _get_token_paddings( + min_token_size=16, + max_token_size=scheduler_config.max_num_batched_tokens, + padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + # In case `max_num_tokens < max(num_tokens_paddings)` use the actual + # padded max value to pre-allocate data structures and pre-compile. + self.max_num_tokens = self.num_tokens_paddings[-1] # Model-related. 
self.num_attn_layers = model_config.get_num_layers_by_block_type( @@ -106,6 +149,7 @@ class TPUModelRunner: self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() + self.vocab_size = model_config.get_vocab_size() # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY @@ -136,7 +180,7 @@ class TPUModelRunner: max_num_blocks_per_req=self.max_num_blocks_per_req, device=self.device, pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), + vocab_size=self.vocab_size, ) # Cached torch/numpy tensor @@ -157,7 +201,7 @@ class TPUModelRunner: device="cpu") self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.block_table_cpu = torch.zeros( - (self.max_num_tokens, self.max_num_blocks_per_req), + (self.max_num_reqs, self.max_num_blocks_per_req), dtype=self.input_batch.block_table.get_cpu_tensor().dtype, device="cpu") @@ -175,14 +219,56 @@ class TPUModelRunner: # Range tensor with values [0 .. self.max_num_tokens - 1]. 
# Used to initialize positions / context_lens / seq_lens - self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32) - self.num_tokens_paddings = _get_token_paddings( - min_token_size=16, - max_token_size=self.max_num_tokens, - padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + # Keep in int64 to avoid overflow with long context + self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64) self.num_reqs_paddings = _get_req_paddings( min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) + # tensors for structured decoding + self.grammar_bitmask_cpu = torch.zeros( + (self.max_num_reqs, cdiv(self.vocab_size, 32)), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.require_structured_out_cpu = torch.zeros( + (self.max_num_reqs, 1), + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory) + self.structured_decode_arange = torch.arange( + 0, 32, device="cpu", pin_memory=self.pin_memory) + + # Get maximum number of mm items per modality (batch size). + self.max_num_mm_items_by_modality = dict() + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): + max_tokens_by_modality_dict = ( + MULTIMODAL_REGISTRY. + get_max_tokens_per_item_by_nonzero_modality(self.model_config)) + for modality, max_tokens in max_tokens_by_modality_dict.items(): + # Check how many items of this modality can be supported by + # the encoder budget. + encoder_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + max_num_mm_items_encoder_budget = cdiv(encoder_budget, + max_tokens) + + # Check how many items of this modality can be supported by + # the decoder budget. + max_mm_items_per_req = self.mm_registry.\ + get_mm_limits_per_prompt(self.model_config)[modality] + + # NOTE: We do not consider max_num_batched_tokens on purpose + # because the multimodal embeddings can be generated in advance + # and chunked prefilled. 
+ max_num_mm_items_decoder_budget = self.max_num_reqs * \ + max_mm_items_per_req + + max_num_mm_items = min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget) + self.max_num_mm_items_by_modality[modality] = max_num_mm_items + def _update_num_xla_graphs(self, case_str): check_comp = self.check_recompilation and not self.enforce_eager if not check_comp: @@ -270,7 +356,6 @@ class TPUModelRunner: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - prompt=new_req_data.prompt, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, @@ -344,11 +429,10 @@ class TPUModelRunner: format. Layers that do not need KV cache are not included. """ - forward_ctx = self.vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - assert isinstance(attn_module, Attention) + for layer_name, attn_module in layers.items(): if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( @@ -569,29 +653,36 @@ class TPUModelRunner: # 2. A list or tuple (length: num_items) of tensors, each of shape # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. 
+ xm.mark_step() curr_group_outputs = self.model.get_multimodal_embeddings( **batched_mm_inputs) + xm.mark_step() sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=len(grouped_mm_inputs), ) - for output in curr_group_outputs: - encoder_outputs.append(output) + if isinstance(curr_group_outputs, torch.Tensor): + encoder_outputs.append(curr_group_outputs) + else: + assert isinstance(curr_group_outputs, (list, tuple)) + for output in curr_group_outputs: + encoder_outputs.append(output) # Cache the encoder outputs. + # NOTE (NickLucche) here we diverge from logic in other runners, as we + # assume to only have whole mm items to process. Hence we avoid the + # intrinsic dynamism that `scatter_mm_placeholders` introduces. for (req_id, input_id, pos_info), output in zip( req_ids_pos, encoder_outputs, ): if req_id not in self.encoder_cache: self.encoder_cache[req_id] = {} - - self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) + assert pos_info.is_embed is None, "Expected all positions to be"\ + " contiguous and embeddings." + self.encoder_cache[req_id][input_id] = output def _gather_mm_embeddings( self, @@ -604,6 +695,10 @@ class TPUModelRunner: req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens mm_positions = req_state.mm_positions + # TODO unroll loop and assume/enforce --disable_chunked_mm_input + # NOTE (NickLucche) here we diverge from logic in other runners, as + # we assume to only have whole mm items to process. Hence we avoid + # the intrinsic dynamism that `gather_mm_placeholders` introduces. for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -620,25 +715,33 @@ class TPUModelRunner: # in the decoder's KV cache. 
continue - start_idx = max(num_computed_tokens - start_pos, 0) - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) - assert start_idx < end_idx assert req_id in self.encoder_cache assert i in self.encoder_cache[req_id] + assert pos_info.is_embed is None, "Expected all positions to"\ + " be contiguous and embeddings." encoder_output = self.encoder_cache[req_id][i] - - if (is_embed := pos_info.is_embed) is not None: - is_embed = is_embed[start_idx:end_idx] - - mm_embeds_item = gather_mm_placeholders( - encoder_output[start_idx:end_idx], - is_embed=is_embed, - ) - mm_embeds.append(mm_embeds_item) + mm_embeds.append(encoder_output) return mm_embeds + def _get_model_inputs(self, input_ids: torch.Tensor, + mm_embeds: list[torch.Tensor]): + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + return None, inputs_embeds + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + return input_ids, None + @torch.no_grad() def execute_model( self, @@ -657,27 +760,13 @@ class TPUModelRunner: mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - + xm.mark_step() # Prepare inputs attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs( scheduler_output) - if self.is_multimodal_model: - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. 
- if mm_embeds: - inputs_embeds = self.model.get_input_embeddings( - self.input_ids, mm_embeds) - else: - inputs_embeds = self.model.get_input_embeddings(self.input_ids) - input_ids = None - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids - inputs_embeds = None + input_ids, inputs_embeds = self._get_model_inputs( + self.input_ids, mm_embeds) + xm.mark_step() num_reqs = self.input_batch.num_reqs # Run the decoder with set_forward_context(attn_metadata, self.vllm_config): @@ -688,9 +777,16 @@ class TPUModelRunner: ) hidden_states = self.select_hidden_states(hidden_states, logits_indices) + logits = self.compute_logits(hidden_states) tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ from_input_batch(self.input_batch, padded_num_reqs, self.device) - selected_token_ids = self.sample_from_hidden(hidden_states, + if scheduler_output.grammar_bitmask is not None: + require_struct_decoding, grammar_bitmask_padded, arange = \ + self.prepare_structured_decoding_input(logits, scheduler_output) + logits = self.structured_decode(require_struct_decoding, + grammar_bitmask_padded, logits, + arange) + selected_token_ids = self.sample_from_logits(logits, tpu_sampling_metadata) # Remove padding on cpu and keep dynamic op outside of xla graph. selected_token_ids = selected_token_ids.cpu()[:num_reqs] @@ -853,16 +949,77 @@ class TPUModelRunner: inputs_embeds=inputs_embeds) self._hidden_states_dtype = out.dtype + def _precompile_mm_encoder(self) -> None: + # Pre-compile MM encoder for all supported data modalities. 
+ hf_config = self.vllm_config.model_config.hf_config + for mode, max_items_by_mode in \ + self.max_num_mm_items_by_modality.items(): + logger.info( + "Compiling Multimodal %s Encoder with different input" + " shapes.", mode) + start = time.perf_counter() + # No padding for MM encoder just yet. + for num_items in range(1, max_items_by_mode + 1): + logger.info(" -- mode: %s items: %d", mode, num_items) + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + mode, num_items) + # Run multimodal encoder. + xm.mark_step() + mm_embeds = self.model.\ + get_multimodal_embeddings(**batched_dummy_mm_inputs) + xm.mark_step() + num_patches = mm_embeds[0].shape[0] + items_size = num_patches * num_items + + # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm + # embeddings are present. We assume `--disable-mm-chunked`, + # hence only whole items can be scheduled. This implies we just + # need to compile when `num_items` fit the (padded) `input_ids` + for num_tokens in self.num_tokens_paddings: + if num_tokens >= items_size: + # XLA Workaround: if torch.zeros(..device) is used, XLA + # compiles a scalar+expansion op, which won't match + # the graph generated at runtime. CPU->TPU must be used + placeholders_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device="cpu") + # Align placeholders and actual num mm_embeddings. + placeholders_ids[:items_size] = \ + hf_config.image_token_index + + placeholders_ids = placeholders_ids.to(self.device) + # Assign outputs or the graph will be cut short. + a, b = self._get_model_inputs(placeholders_ids, + [mm_embeds]) + assert a is None + xm.mark_step() + + # Pre-compile `get_input_embeddings` when mm_embeddings are not + # present. Chunk is only made of text, no mm_placeholders. 
+ for num_tokens in self.num_tokens_paddings: + placeholders_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device="cpu") + placeholders_ids = placeholders_ids.to(self.device) + a, b = self._get_model_inputs(placeholders_ids, []) + assert a is None + xm.mark_step() + + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal %s Encoder compilation finished in in %.2f " + "[secs].", mode, end - start) + def _precompile_backbone(self) -> None: logger.info("Compiling the model with different input shapes.") - start = time.perf_counter() for num_tokens in self.num_tokens_paddings: logger.info(" -- num_tokens: %d", num_tokens) self._dummy_run(num_tokens) xm.wait_device_ops() end = time.perf_counter() - logger.info("Compilation finished in in %.2f [secs].", end - start) + logger.info("Compilation finished in %.2f [secs].", end - start) self._update_num_xla_graphs("model backbone") def _precompile_select_hidden_states(self) -> None: @@ -883,22 +1040,67 @@ class TPUModelRunner: device=self.device) torch._dynamo.mark_dynamic(indices, 0) self.select_hidden_states(dummy_hidden, indices) - logger.info(" -- num_tokens: %d", num_tokens) + logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, + num_reqs) + # Requests can't be more than tokens. But do compile for the + # next bigger value in case num_tokens uses bucketed padding. 
+ if num_reqs >= min(num_tokens, self.max_num_reqs): + break xm.wait_device_ops() end = time.perf_counter() - logger.info("Compilation finished in in %.2f [secs].", end - start) + logger.info("Compilation finished in %.2f [secs].", end - start) self._update_num_xla_graphs("select_hidden_states") - def _precompile_sample_from_hidden(self) -> None: - logger.info("Compiling sampling with different input shapes.") + def _precompile_compute_logits(self) -> None: + logger.info("Compiling compute_logits with different input shapes.") start = time.perf_counter() hsize = self.model_config.get_hidden_size() for num_reqs in self.num_reqs_paddings: dummy_hidden = torch.zeros((num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype) - # The first dimension of dummy_hidden cannot be mark_dynamic because - # some operations in the sampler require it to be static. + torch._dynamo.mark_dynamic(dummy_hidden, 0) + self.compute_logits(dummy_hidden) + logger.info(" -- num_seqs: %d", num_reqs) + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("compute_logits") + + def _precompile_structured_decoding(self) -> None: + logger.info( + "Compiling structured_decoding with different input shapes.") + start = time.perf_counter() + for num_reqs in self.num_reqs_paddings: + dummy_logits = torch.zeros((num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype) + dummy_require_struct_decoding = \ + self.require_structured_out_cpu[:num_reqs].to(self.device) + dummy_grammar_bitmask = \ + self.grammar_bitmask_cpu[:num_reqs].to(self.device) + # The first dimension of the above 3 dummy tensors cannot be + # mark_dynamic because some operations in structured_decode require + # them to be static. 
+ arange = self.structured_decode_arange.to(self.device) + self.structured_decode(dummy_require_struct_decoding, + dummy_grammar_bitmask, dummy_logits, arange) + logger.info(" -- num_seqs: %d", num_reqs) + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("structured_decoding") + + def _precompile_sample_from_logits(self) -> None: + logger.info( + "Compiling sample_from_logits with different input shapes.") + start = time.perf_counter() + for num_reqs in self.num_reqs_paddings: + dummy_logits = torch.zeros((num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype) + # The first dimension of dummy_logits cannot be mark_dynamic + # because some operations in the sampler require it to be static. for all_greedy in [False, True]: generate_params_if_all_greedy = not all_greedy sampling_metadata = ( @@ -909,21 +1111,82 @@ class TPUModelRunner: generate_params_if_all_greedy, )) sampling_metadata.all_greedy = all_greedy - self.sample_from_hidden(dummy_hidden, sampling_metadata) + self.sample_from_logits(dummy_logits, sampling_metadata) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() - logger.info("Compilation finished in in %.2f [secs].", end - start) - self._update_num_xla_graphs("sampling") + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("sample_from_logits") def capture_model(self) -> None: """ Precompile all the subgraphs with possible input shapes. """ - # TODO: precompile encoder + self._precompile_mm_encoder() self._precompile_backbone() self._precompile_select_hidden_states() - self._precompile_sample_from_hidden() + self._precompile_compute_logits() + self._precompile_structured_decoding() + self._precompile_sample_from_logits() + + def profile_run( + self, + num_tokens: int, + ) -> None: + # Profile with multimodal encoder & encoder cache. 
+ # TODO: handle encoder-decoder models once we support them. + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): + + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + dummy_data_modality, max_num_mm_items = max( + self.max_num_mm_items_by_modality.items(), key=lambda t: t[1]) + + encoder_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + logger.info( + "Encoder cache will be initialized with a budget of %d tokens," + " and profiled with %s %s items of the maximum feature size.", + encoder_budget, max_num_mm_items, dummy_data_modality) + + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_data_modality, max_num_mm_items) + + # Run multimodal encoder. + # Isolate encoder graph from post-processing to minimize + # impact of recompilation until it's fixed. + start = time.perf_counter() + xm.mark_step() + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + xm.mark_step() + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal Encoder profiling finished in in %.2f [secs].", + end - start) + + assert len(dummy_encoder_outputs) == max_num_mm_items, ( + "Expected dimension 0 of encoder outputs to match the number " + f"of multimodal data items: {max_num_mm_items}, got " + f"{len(dummy_encoder_outputs)=} instead. This is most likely " + "due to the 'get_multimodal_embeddings' method of the model " + "not implemented correctly.") + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + + # Trigger compilation for general shape. 
+ self._dummy_run(num_tokens) + + xm.mark_step() + xm.wait_device_ops() + self.encoder_cache.clear() + gc.collect() def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -945,7 +1208,7 @@ class TPUModelRunner: tensor_config = kv_cache_config.tensors[layer_name] assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes - if isinstance(kv_cache_spec, FullAttentionSpec): + if isinstance(kv_cache_spec, AttentionSpec): kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) @@ -980,16 +1243,14 @@ class TPUModelRunner: return hidden_states[indices_do_sample] @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def sample_from_hidden( - self, - sample_hidden_states: torch.Tensor, - sampling_metadata: TPUSupportedSamplingMetadata, - ) -> torch.Tensor: - """ - Sample with xla-friendly function. This function is to be traced - separately from `forward` for lighter compilation overhead. 
- """ - logits = self.model.compute_logits(sample_hidden_states, None) + def compute_logits(self, + sample_hidden_states: torch.Tensor) -> torch.Tensor: + return self.model.compute_logits(sample_hidden_states, None) + + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def sample_from_logits( + self, logits: torch.Tensor, + sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor: if sampling_metadata.all_greedy: out_tokens = torch.argmax(logits, dim=-1, keepdim=True) else: @@ -997,12 +1258,101 @@ class TPUModelRunner: sampling_metadata).sampled_token_ids return out_tokens + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def structured_decode(self, require_struct_decoding: torch.Tensor, + grammar_bitmask: torch.Tensor, logits: torch.Tensor, + arange: torch.Tensor) -> torch.Tensor: + return torch.where( + require_struct_decoding, + self.apply_grammar_bitmask(logits, grammar_bitmask, arange), + logits) + + def apply_grammar_bitmask(self, logits: torch.Tensor, + grammar_bitmask: torch.Tensor, + arange: torch.Tensor): + assert (logits.shape[0] == grammar_bitmask.shape[0]) + logits_cloned = logits.clone() + for i in range(logits.shape[0]): + unpacked_bitmask = (torch.bitwise_right_shift( + grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0 + unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size] + logits_cloned[i] = logits_cloned[i].masked_fill( + unpacked_bitmask, -float("inf")) + return logits_cloned + def get_multimodal_embeddings(self, *args, **kwargs): return self.model.get_multimodal_embeddings(*args, **kwargs) def get_input_embeddings(self, *args, **kwargs): return self.model.get_input_embeddings(*args, **kwargs) + def prepare_structured_decoding_input( + self, logits: torch.Tensor, scheduler_output: "SchedulerOutput" + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + grammar_bitmask = scheduler_output.grammar_bitmask + assert grammar_bitmask is not None + num_reqs, _ = logits.shape + + # Reset 
pre-allocated tensors + self.grammar_bitmask_cpu.zero_() + self.require_structured_out_cpu.zero_() + + # We receive the structured output bitmask from the scheduler, but the + # indices of the requests in the batch may not match the indices of + # the bitmask since the scheduler doesn't know how the tpu runner is + # ordering the requests in the batch. We need to match the order of + # bitmask with the order of requests + struct_out_indices: list[int] = [] + mask_indices: list[int] = [] + for req_id in self.input_batch.req_ids: + mask_index = scheduler_output.structured_output_request_ids.get( + req_id) + if mask_index is None: + continue + batch_index = self.input_batch.req_id_to_index[req_id] + struct_out_indices.append(batch_index) + mask_indices.append(mask_index) + self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( + grammar_bitmask[mask_indices]) + # It's not guaranteed that all requests in this batch require + # structured output, so create a bool tensor to represent + # the requests that need structured output. + struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) + self.require_structured_out_cpu[struct_out_indices] = True + return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ + self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ + self.structured_decode_arange.to(logits.device) + + def _get_mm_dummy_batch(self, modality: str, + batch_size: int) -> BatchedTensorInputs: + # Dummy data for pre-compiling multimodal models. + dummy_request_data = self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=self.max_num_tokens, + ) + dummy_mm_data = dummy_request_data.multi_modal_data + + # Dummy data definition in V0 may contain multiple multimodal items + # (e.g, multiple images) for a single request, therefore here we + # always replicate first item by max_num_mm_items times since in V1 + # they are scheduled to be processed separately. 
+ assert isinstance(dummy_mm_data, MultiModalKwargs), ( + "Expected dummy multimodal data to be of type " + f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. " + "This is most likely due to the model not having a merged " + "processor.") + + # When models have a merged processor, their dummy data is + # already batched `MultiModalKwargs`, therefore we take the first + # `MultiModalKwargsItem` from the desired modality to profile on. + dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) + dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) + + batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * + batch_size) + return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs, + device=self.device) + def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: logger.info("Preparing request paddings:") @@ -1040,11 +1390,12 @@ def _get_token_paddings(min_token_size: int, max_token_size: int, if padding_gap == 0: logger.info("Using exponential token paddings:") - while num <= max_token_size: + while True: logger.info(" %d", num) paddings.append(num) + if num >= max_token_size: + break num *= 2 - else: logger.info("Using incremental token paddings:") while num <= padding_gap: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 73c43969b87b5c9fe45c1b191e01dc13aefbe7e6..de676541effa5ead8b5173bdb0e2d9ac3a39f0de 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -21,7 +21,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.utils import bind_kv_cache +from vllm.v1.utils import bind_kv_cache, report_usage_stats from vllm.v1.worker.tpu_model_runner import TPUModelRunner logger = init_logger(__name__) @@ -133,6 +133,10 @@ class TPUWorker: # Init ModelRunner here, so that we have access to self.device. 
self.model_runner = TPUModelRunner(self.vllm_config, self.device) + if rank == 0: + # If usage stat is enabled, collect relevant info. + report_usage_stats(self.vllm_config) + def determine_available_memory(self) -> int: kv_caches: dict[str, torch.Tensor] = {} kv_cache_spec = self.model_runner.get_kv_cache_spec() @@ -156,8 +160,8 @@ class TPUWorker: self.vllm_config.compilation_config.static_forward_context, runner_kv_caches) - self.model_runner._dummy_run( - self.scheduler_config.max_num_batched_tokens) + # `max_num_tokens >= max_num_batched_tokens` due to padding. + self.model_runner.profile_run(self.model_runner.max_num_tokens) # Synchronize before measuring the memory usage. xm.wait_device_ops() diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index b83826920f7873e515f8f4c840e2a0620b4df937..0b687f45810793dacf7349d3c0d334711f759d53 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -72,19 +72,32 @@ class CacheEngine: device: str, ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" - kv_cache_shape = self.attn_backend.get_kv_cache_shape( + kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] + try: + kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( + ) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape))) + + # The allocation respects the backend-defined stride order to ensure + # the semantic remains consistent for each backend. We first obtain the + # generic kv cache shape and then permute it according to the stride + # order which could result in a non-contiguous tensor. 
+ kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i] + for i in kv_cache_stride_order) for _ in range(self.num_attention_layers): # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. - layer_kv_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - pin_memory=pin_memory, - device=device) + layer_kv_cache = torch.zeros( + kv_cache_allocation_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device).permute(*kv_cache_stride_order) # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases # when entry_shape is higher than 1D diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index ac7c93e48395df11d81345650e981e82ed8e98ac..c2120c035175a1f169ffca04244843198bd0ea0c 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -316,7 +316,7 @@ class CPUEncoderDecoderModelRunner( return [] # Sample the next token. 
- output = self.model.sample( + output = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 9f4b18869bdfa21a92f732aa68605ecc86160265..710ca1a13b0c5be16e15c972429f6ec73a7f750f 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -19,11 +19,11 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_lora, supports_multimodal -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs, MultiModalPlaceholderMap) +from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs, + MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) from vllm.worker.model_runner_base import ( @@ -154,7 +154,6 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): self.sliding_window = self.runner.sliding_window self.block_size = self.runner.block_size self.device = self.runner.device - self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper self.enable_lora = self.runner.lora_config is not None if self.runner.attn_backend is not None: # spec decode (e.g. 
Medusa) does not have atten backend @@ -359,22 +358,14 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): computed_len = seq_data.get_num_computed_tokens() seq_len = self.input_data.seq_lens[-1] - # NOTE: mm_data only includes the subset of multi-modal items that + # NOTE: mm_kwargs only includes the subset of multi-modal items that # intersect with the current prefill positions. - mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( seq_group_metadata, range(computed_len, seq_len)) - if not mm_data: + if not mm_kwargs: return - if self.runner.mm_registry.has_processor(self.runner.model_config): - mm_kwargs = mm_data - else: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - seq_group_metadata.mm_processor_kwargs, - ) - # special processing for mrope position deltas. if self.runner.model_config.uses_mrope: assert not self.chunked_prefill, \ @@ -382,11 +373,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) - assert image_grid_thw is not None or video_grid_thw is not None, ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw'.") + audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", + None) + assert ( + image_grid_thw is not None or video_grid_thw is not None + or audio_feature_lengths is not None), ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw' or " + "'audio_feature_lengths'.") second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) hf_config = self.runner.model_config.hf_config token_ids = seq_data.get_token_ids() @@ -398,6 +395,8 @@ class 
ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, context_len=computed_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) seq_data.mrope_position_delta = mrope_position_delta @@ -472,16 +471,11 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): use_mla=self.model_config.use_mla, ) if needs_attn_backend else None - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.multi_modal_input_mapper = self.mm_registry \ - .create_input_mapper(self.model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) - # Lazy initialization. self.model: nn.Module # Set after init_Model # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.sampler = get_sampler() if hasattr(self, "_builder_cls"): # multi-step model runner does not have `_builder_cls` @@ -499,13 +493,8 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): logger.warning("Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model.") - # It's necessary to distinguish between the max_position_embeddings - # of VLMs and LLMs. 
- if hasattr(self.model.config, "max_position_embeddings"): - max_pos_embeddings = self.model.config.max_position_embeddings - else: - max_pos_embeddings = ( - self.model.config.text_config.max_position_embeddings) + # Use get_text_config() in case of multimodal models + text_config = self.model_config.hf_config.get_text_config() self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, @@ -515,7 +504,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): self.device, self.model.embedding_modules, self.model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, + max_position_embeddings=text_config.max_position_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) @@ -537,11 +526,6 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): return self.builder.build() # type: ignore - # sampler property will be used by spec_decode_worker - @property - def sampler(self): - return self.model.sampler - @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() @@ -669,7 +653,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): return [] # Sample the next token. 
- output = self.model.sample( + output = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 72ff9d66a689852c260fbe83a555db60db9d12ea..4df192a8727c30f9d37f2a2f3afd5ef3722d12af 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -100,6 +100,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, + input_registry=input_registry, + mm_registry=mm_registry, ) # Crash for unsupported encoder/scenarios @@ -205,7 +207,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): model_input.async_callback() # Sample the next token. - output: SamplerOutput = self.model.sample( + output: SamplerOutput = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 7a346b34cef59b01ddfa5016a1e49cddd1636243..e25864349e2804eec90fde3e27cccd204e6b11a6 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -11,11 +11,9 @@ import functools import gc import itertools import math -import operator import os import time from array import array -from dataclasses import dataclass, field from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -24,8 +22,9 @@ import habana_frameworks.torch as htorch import habana_frameworks.torch.internal.bridge_config as bc import torch import torch.nn as nn +import vllm_hpu_extension.environment as environment +from vllm_hpu_extension.bucketing.common import get_bucketing_context from vllm_hpu_extension.ops import LoraMask as LoraMask -from vllm_hpu_extension.ops import batch2block, block2batch from vllm_hpu_extension.profiler import 
(HabanaHighLevelProfiler, HabanaMemoryProfiler, format_bytes) @@ -41,13 +40,12 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SequenceGroupToSample -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs) +from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs from vllm.sampling_params import SamplingParams from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceData, SequenceGroupMetadata, @@ -74,24 +72,7 @@ _PAD_BLOCK_ID = 0 LORA_WARMUP_RANK = 8 - -class Singleton(type): - _instances: Dict[type, object] = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) - return cls._instances[cls] - - -@dataclass -class HPUBucketingGlobalState(metaclass=Singleton): - prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) - decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) - prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False) - decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False) - prompt_buckets: List[Tuple[int, int]] = field(init=False) - decode_buckets: List[Tuple[int, int]] = field(init=False) +DUMMY_TOKEN_ID = -1 def subtuple(obj: object, @@ -113,134 +94,10 @@ def subtuple(obj: object, return _TYPE_CACHE[typename](**values) -def read_bucket_settings(phase: str, dim: str, **defaults): - """Read bucketing configuration from env variables. 
- - phase is either 'prompt' or 'decode' - dim is either 'bs', 'seq' or 'block' - param is either 'min', 'step' or 'max' - example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 - """ - params = ['min', 'step', 'max'] - env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] - default_values = [defaults[p] for p in params] - values = [ - int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) - ] - for e, v, d in zip(env_vars, values, default_values): - logger.info('%s=%s (default:%s)', e, v, d) - return values - - -def warmup_range(config: Tuple[int, int, int]): - """Generate a warmup range. - - Start from bmin and multiply by 2 until you reach bstep. - Then, increase the values in the range by the value of bstep until you - reach bmax. - - Example: - bmin = 2, bstep = 32, bmax = 64 - => ramp_up = (2, 4, 8, 16) - => stable = (32, 64) - => return ramp_up + stable => (2, 4, 8, 16, 32, 64) - """ - bmin, bstep, bmax = config - assert bmin <= bmax, ("Min. batch size cannot be greater than max. " - "batch size. 
If you want to skip warmup, " - "set VLLM_SKIP_WARMUP=true") - base = itertools.repeat(2) - ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin) - ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \ - ramp_up_acc) - stable = range(bstep, bmax + 1, bstep) - buckets = list(ramp_up_tw) + list(stable) - return list(filter(lambda bucket: bucket >= bmin, buckets)) - - -def generate_prompt_buckets(bs_bucket_config, - seq_bucket_config, - max_num_batched_tokens=None): - buckets = list( - itertools.product(warmup_range(bs_bucket_config), - warmup_range(seq_bucket_config))) - if len(buckets) == 0: - msg = ("No buckets could be captured with following config " - f"(min, step, max_warmup): " - f"bs:{bs_bucket_config}, " - f"seq:{seq_bucket_config}") - raise ValueError(msg) - - filtered_buckets = buckets - if max_num_batched_tokens is not None: - # Remove buckets exceeding batch token budget - filtered_buckets = list( - filter( - lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, - buckets)) - - if len(filtered_buckets) == 0: - # we can handle this if we ignore max_num_batched_tokens - min_bucket_bs, min_bucket_seq = min(buckets, - key=lambda b: (b[0] * b[1])) - min_reqd_budget = min_bucket_bs * min_bucket_seq - msg = ( - "The current bucketing configuration " - f"(min, step, max_warmup): " - f"bs:{bs_bucket_config}, " - f"seq:{seq_bucket_config} cannot be used with specified " - f"max_num_batched_tokens ({max_num_batched_tokens}), as the " - f"smallest bucket ({min_reqd_budget}) would exceed token " - "budget. 
Please increase max_num_batched_tokens or decrease " - "bucket minimum Ignoring max_num_batched_tokens at risk of " - "out-of-memory errors.") - logger.error(msg) - return list( - sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), [] - - captured_buckets = list( - sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) - omitted_buckets = list( - sorted([x for x in buckets if x not in filtered_buckets])) - return captured_buckets, omitted_buckets - - -def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, - max_blocks): - buckets = [] - bs_buckets = warmup_range(bs_bucket_config) - block_buckets = warmup_range(blocks_bucket_config) - bmin, bstep, bmax = blocks_bucket_config - last_bucket = round_up(max_blocks, bstep) - for bs in bs_buckets: - for blocks in block_buckets: - if blocks < bs: - continue - if blocks > last_bucket: - break - buckets.append((bs, blocks)) - return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) - - -def next_pow2(value: int, base: int): - res = base - while value > 1: - value = (value + 1) // 2 - res *= 2 - return res - - def round_up(value: int, k: int): return (value + k - 1) // k * k -def find_bucket(value: int, config: Tuple[int, int, int]): - bmin, bstep, _ = config - next_step = round_up(value, bstep) - next_pow = next_pow2(value, bmin) - return max(bmin, min(next_step, next_pow)) - - def align_workers(value, op): group = get_world_group().cpu_group world_size = torch.distributed.get_world_size() @@ -314,6 +171,7 @@ class HpuModelAdapter: def __init__(self, model, vllm_config): self.model = model + self.sampler = get_sampler() self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '0').lower() in ['1', 'true'] self.vllm_config = vllm_config @@ -403,16 +261,6 @@ class HpuModelAdapter: attn_bias=attn_bias) return metadata - def _set_block_scales(self, metadata, device): - block_mapping = metadata.block_mapping - ones = torch.ones((block_mapping.size(0), ), - device=device, - 
dtype=block_mapping.dtype) - sums = batch2block(block2batch(ones, block_mapping), block_mapping) - block_scales = torch.reciprocal(torch.maximum(ones, sums)) - metadata = metadata._replace(block_scales=block_scales) - return metadata - def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): if attn_metadata.is_prompt: @@ -423,7 +271,6 @@ class HpuModelAdapter: meta = attn_metadata attn_metadata = self._set_block_mapping(meta, batch_size, device, dtype) - attn_metadata = self._set_block_scales(attn_metadata, device) return attn_metadata def forward(self, *args, **kwargs): @@ -452,7 +299,7 @@ class HpuModelAdapter: return self.model.compute_logits(*args, **kwargs) def sample(self, *args, **kwargs): - return self.model.sample(*args, **kwargs) + return self.sampler(*args, **kwargs) class PreparePromptMetadata(NamedTuple): @@ -622,6 +469,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): return_hidden_states: bool = False, ): ModelRunnerBase.__init__(self, vllm_config=vllm_config) + environment.set_model_config(self.model_config) self.is_driver_worker = is_driver_worker self.return_hidden_states = return_hidden_states @@ -661,13 +509,21 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None - self.bucketing_global_state = HPUBucketingGlobalState() - self._setup_buckets() + HPUBucketingContext = get_bucketing_context() + self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, + self.max_num_prefill_seqs, + self.block_size, + self.max_num_batched_tokens, + False, self.max_model_len) + self.graphed_buckets: Set[Any] = set() self._set_gc_threshold() self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH # For multi-step scheduling self.cached_step_outputs: List[torch.Tensor] = [] + # For delayed sampling + self.cached_step_inputs: List[ + ModelInputForHPUWithSamplingMetadata] = [] 
def _set_gc_threshold(self) -> None: # Read https://docs.python.org/3/library/gc.html#gc.set_threshold @@ -688,10 +544,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ] gc.set_threshold(*requested_gc_thrs) - # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ - .create_input_mapper(self.model_config) - self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true' @@ -718,14 +570,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): "Bias support in LoRA is not enabled in HPU yet." assert not self.lora_config.fully_sharded_loras, \ "Fully sharded LoRAs is not enabled in HPU yet." - # It's necessary to distinguish between the - # max_position_embeddings of VLMs and LLMs. - if hasattr(self.model.config, "max_position_embeddings"): - max_pos_embeddings = ( - self.model.config.max_position_embeddings) - else: - max_pos_embeddings = ( - self.model.config.text_config.max_position_embeddings) + + # Use get_text_config() in case of multimodal models + text_config = self.model_config.hf_config.get_text_config() self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, @@ -735,7 +582,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): self.device, self.model.embedding_modules, self.model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, + max_position_embeddings=text_config. 
+ max_position_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) @@ -771,6 +619,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): msg = f"Loading model weights took in total {m.get_summary_string()}" logger.info(msg) + def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): + real_batch_size = len(seq_group_metadata_list) + batch_size_padded = self.bucketing_ctx.get_padded_batch_size( + real_batch_size, is_prompt) + batch_size_padding = batch_size_padded - real_batch_size + + seq_group_metadata_list = seq_group_metadata_list.copy() + + if batch_size_padding > 0: + dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( + 0, 0, is_prompt) + seq_group_metadata_list.extend(dummy_seq_group_metadata + for _ in range(batch_size_padding)) + return seq_group_metadata_list, real_batch_size, batch_size_padded + + def _maybe_wrap_in_hpu_graph(self, *args, **kwargs): + return htorch.hpu.wrap_in_hpu_graph( + HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True + ) if htorch.utils.internal.is_lazy() else HpuModelAdapter( + *args, **kwargs) + def get_model(self) -> nn.Module: return self.model @@ -784,46 +653,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens - def _setup_buckets(self) -> None: - align_bs = lambda x: min(self.max_num_seqs, x) - #FIXME: The default values should be max_model_len - max_prompt_seq = 1024 - max_decode_seq = 2048 - self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings( - 'prompt', - 'bs', - min=1, - step=align_bs(32), - max=self.max_num_prefill_seqs) - self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings( - 'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs) - self.bucketing_global_state.prompt_seq_bucket_cfg = \ - read_bucket_settings( - 'prompt', - 'seq', - min=self.block_size, - step=self.block_size, - max=max_prompt_seq) - 
self.bucketing_global_state.decode_block_bucket_cfg = \ - read_bucket_settings( - 'decode', - 'block', - min=self.block_size, - step=self.block_size, - max=max(self.block_size, - self.max_num_seqs * max_decode_seq // self.block_size)) - self.graphed_buckets: Set[Any] = set() - - msg = ("Prompt bucket config (min, step, max_warmup) " - f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, " - f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}") - logger.info(msg) - - msg = ("Decode bucket config (min, step, max_warmup) " - f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, " - f"block:{self.bucketing_global_state.decode_block_bucket_cfg}") - logger.info(msg) - def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -897,9 +726,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): # is always the first token in the sequence. input_positions.append(list(range(context_len, seq_len))) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper(mm_data) + mm_kwargs = seq_group_metadata.multi_modal_data + if mm_kwargs: multi_modal_kwargs_list.append(mm_kwargs) if seq_group_metadata.block_tables is None: @@ -939,8 +767,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): assert max_query_len > 0 max_prompt_len = max( - find_bucket(max(seq_lens), - self.bucketing_global_state.prompt_seq_bucket_cfg), + self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len), self.block_size) lora_ids: List[int] = [] @@ -989,7 +816,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): block_usage=None, block_indices=block_indices, block_offsets=block_offsets, - block_scales=None, block_groups=None, attn_bias=None, seq_lens_tensor=seq_lens_tensor, @@ -1116,9 +942,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): padding_fn = None if self.use_contiguous_pa: block_bucket_size = max(max(block_list) + 1, len(block_list)) - block_bucket_size = 
find_bucket( - block_bucket_size, - self.bucketing_global_state.decode_block_bucket_cfg) + block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size) indices: List[Any] indices = [None] * block_bucket_size for i, bid in enumerate(block_list): @@ -1126,9 +951,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): padding_fn = lambda tensor, pad_value: gather_list( tensor, indices, pad_value) else: - block_bucket_size = find_bucket( - len(block_list), - self.bucketing_global_state.decode_block_bucket_cfg) + block_bucket_size = \ + self.bucketing_ctx.get_padded_decode_num_blocks( + len(block_list)) padding_fn = lambda tensor, pad_value: pad_list( tensor, block_bucket_size, pad_value) @@ -1159,7 +984,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): block_usage=block_usage, block_indices=block_indices, block_offsets=block_offsets, - block_scales=None, block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, @@ -1202,17 +1026,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): base_event_name = 'prompt' if is_prompt else 'decode' self.profiler.start('internal', base_event_name) - real_batch_size = len(seq_group_metadata_list) - bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \ - if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg - batch_size_padded = find_bucket(real_batch_size, bucket_cfg) - batch_size_padding = batch_size_padded - real_batch_size - seq_group_metadata_list = seq_group_metadata_list.copy() - if batch_size_padding > 0: - dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( - 0, 0, is_prompt) - seq_group_metadata_list.extend(dummy_seq_group_metadata - for _ in range(batch_size_padding)) + seq_group_metadata_list, real_batch_size, batch_size_padded = ( + self._add_dummy_seq(seq_group_metadata_list, is_prompt)) prefill_reqs = [] decode_reqs = [] @@ -1374,7 +1189,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): 
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', - 'block_offsets', 'block_scales', 'block_groups' + 'block_offsets', 'block_groups' ]) return attention_metadata @@ -1412,16 +1227,18 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, [kv_caches]) - max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] - max_batch_size = min(self.max_num_batched_tokens // max_seq_len, - self.scheduler_config.max_num_seqs) - self.warmup_scenario(max_batch_size, max_seq_len, True, False, True) + _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() + max_batch_size = min(self.max_num_seqs, + self.max_num_batched_tokens // max_seq_len) + self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, + False, True) return def warmup_scenario(self, batch_size, seq_len, is_prompt, + kv_caches, is_pt_profiler_run=False, is_lora_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) @@ -1557,16 +1374,17 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): f"free_mem:{free_mem}") logger.info(msg) - def warmup_all_buckets(self, buckets, is_prompt): + def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - self.warmup_scenario(batch_size, seq_len, is_prompt) + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) def warmup_graphs(self, strategy, buckets, is_prompt, + kv_caches, available_mem, starting_mem=0, total_batch_seq=0.001): @@ -1598,7 +1416,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): self.graphed_buckets.add(graphed_bucket) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with 
HabanaMemoryProfiler() as mem_prof: - self.warmup_scenario(batch_size, seq_len, is_prompt) + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) available_mem -= used_mem @@ -1622,50 +1440,21 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: + max_blocks = kv_caches[0][0].size(0) + self.bucketing_ctx.generate_decode_buckets(max_blocks) if profile := os.environ.get('VLLM_PT_PROFILE', None): phase, bs, seq_len, graph = profile.split('_') is_prompt = phase == 'prompt' graphs = graph == 't' if graphs: self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) - self.warmup_scenario(int(bs), int(seq_len), is_prompt, True) + self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, + True) raise AssertionError("Finished profiling") - if self.skip_warmup: - logger.info("Skipping warmup...") - return - self.profiler.start('internal', 'warmup') - max_blocks = kv_caches[0][0].size(0) - - self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \ - generate_prompt_buckets( - self.bucketing_global_state.prompt_bs_bucket_cfg, - self.bucketing_global_state.prompt_seq_bucket_cfg, - self.max_num_batched_tokens) - - msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} " - f"prompt buckets [bs, seq]: \ - {list(sorted(self.bucketing_global_state.prompt_buckets))}") - logger.info(msg) - - msg = (f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})") - logger.info(msg) - - msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" - logger.debug(msg) - - self.bucketing_global_state.decode_buckets = generate_decode_buckets( - self.bucketing_global_state.decode_bs_bucket_cfg, - self.bucketing_global_state.decode_block_bucket_cfg, max_blocks) 
- logger.info("Generated %d decode buckets [bs, total_blocks]: %s", - len(self.bucketing_global_state.decode_buckets), - list(sorted(self.bucketing_global_state.decode_buckets))) - if not htorch.utils.internal.is_lazy() and not self.enforce_eager: cache_size_limit = 1 + 3 * ( - len(self.bucketing_global_state.prompt_buckets) + - len(self.bucketing_global_state.decode_buckets)) + len(self.bucketing_ctx.prompt_buckets) + + len(self.bucketing_ctx.decode_buckets)) torch._dynamo.config.cache_size_limit = max( cache_size_limit, torch._dynamo.config.cache_size_limit) # Multiply by 8 to follow the original default ratio between @@ -1673,7 +1462,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): torch._dynamo.config.accumulated_cache_size_limit = max( cache_size_limit * 8, torch._dynamo.config.accumulated_cache_size_limit) - + if self.skip_warmup: + logger.info("Skipping warmup...") + return + self.profiler.start('internal', 'warmup') start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() @@ -1692,10 +1484,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): 'Please update Gaudi Software Suite.') with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): - self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, - True) - self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, - False) + print("aa") + self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True, + kv_caches) + print("bb") + self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False, + kv_caches) if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, \ @@ -1725,12 +1519,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): 'max_bs') mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( - prompt_strategy, self.bucketing_global_state.prompt_buckets, - True, prompt_available_memory) + 
prompt_strategy, self.bucketing_ctx.prompt_buckets, + True, kv_caches, prompt_available_memory) mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( - decode_strategy, self.bucketing_global_state.decode_buckets, - False, decode_available_memory) + decode_strategy, self.bucketing_ctx.decode_buckets, + False, kv_caches, decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space @@ -1739,8 +1533,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): and not prompt_captured_all and decode_captured_all): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( - prompt_strategy, - self.bucketing_global_state.prompt_buckets, True, + prompt_strategy, self.bucketing_ctx.prompt_buckets, + True, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq)) @@ -1751,17 +1545,15 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): and not decode_captured_all \ and prompt_captured_all: mem_post_decode, _, _ = self.warmup_graphs( - decode_strategy, - self.bucketing_global_state.decode_buckets, False, + decode_strategy, self.bucketing_ctx.decode_buckets, + False, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_decode, decode_batch_seq) self.log_graph_warmup_summary( - self.bucketing_global_state.prompt_buckets, True, - mem_post_prompt) + self.bucketing_ctx.prompt_buckets, True, mem_post_prompt) self.log_graph_warmup_summary( - self.bucketing_global_state.decode_buckets, False, - mem_post_decode) + self.bucketing_ctx.decode_buckets, False, mem_post_decode) end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage() @@ -2020,6 +1812,21 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): return lora_mask, lora_logits_mask + def _get_seq_ids(self, model_input): + return ([ + sg.seq_ids[0] for sg in 
model_input.sampling_metadata.seq_groups + ]) + + def _pad_to_max_num_seqs(self, tensor, value): + padding_needed = self.max_num_seqs - tensor.size(0) + if padding_needed: + padding = torch.full((padding_needed, *tensor.shape[1:]), + value, + device=tensor.device, + dtype=tensor.dtype) + tensor = torch.cat([tensor, padding]) + return tensor + @torch.inference_mode() def execute_model( self, @@ -2030,6 +1837,37 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): warmup_mode=False, seqs=None, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING + use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode + assert not (use_delayed_sampling and num_steps != 1), \ + 'Delayed sampling is not compatible with MSS!' + assert model_input.input_tokens is not None + if use_delayed_sampling and not model_input.is_prompt and \ + self.is_driver_worker: + num_cached = len(self.cached_step_outputs) + assert num_cached > 0 + cur_seq_ids = self._get_seq_ids(model_input) + cur_seq_id_pos = { + sid: idx + for idx, sid in enumerate(cur_seq_ids) if sid >= 0 + } + htorch.core.mark_step() + for i in range(num_cached): + prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i]) + target_indices = [ + cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids + ] + padding = self.cached_step_outputs[i].size(0) - len( + target_indices) + target_indices.extend([-1] * padding) + target_indices = torch.tensor( + target_indices, + device=model_input.input_tokens.device, + dtype=model_input.input_tokens.dtype) + model_input.input_tokens.index_copy_( + 0, target_indices, self.cached_step_outputs[i]) + htorch.core.mark_step() + if not model_input.is_first_multi_step: if not model_input.is_last_step: # not first or last multi-step @@ -2045,7 +1883,21 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): assert model_input.lora_mapping is not None 
self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) - input_tokens = model_input.input_tokens + # Rank!=0 workers has is_prompt==None + if use_delayed_sampling and not model_input.is_prompt and \ + model_input.input_tokens.size(1) == 1: + if self.is_driver_worker: + model_kwargs_broadcast_data = { + "input_tokens": model_input.input_tokens + } + broadcast_tensor_dict(model_kwargs_broadcast_data, src=0) + input_tokens = model_input.input_tokens + + else: + model_kwargs_broadcast_data = broadcast_tensor_dict(src=0) + input_tokens = model_kwargs_broadcast_data["input_tokens"] + else: + input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata sampling_metadata = model_input.sampling_metadata @@ -2092,11 +1944,11 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = 'model_executable' - if num_steps > 1: + if num_steps > 1 or use_delayed_sampling: # in case of multi-step scheduling # we only want to pythonize in the last step sampling_metadata.skip_sampler_cpu_output = True - self.model.model.sampler.include_gpu_probs_tensor = True + self.model.sampler.include_gpu_probs_tensor = True cache_orig_output_tokens_len: List[Dict] = [] def try_revert_dummy_output_tokens(): @@ -2152,9 +2004,9 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): if not self.is_driver_worker: continue - if model_input.async_callback is not None: - model_input.async_callback() - # Sample the next token. 
+ if use_delayed_sampling: + fake_output = self._delayed_sampler_outputs(model_input) + with self.profiler.record_event( 'internal', ('sample_' f'{"prompt" if is_prompt else "decode"}_' @@ -2166,9 +2018,16 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): ) if num_steps > 1: output = output.sampled_token_ids - self.cached_step_outputs.append( - output.detach().clone()) + self.cached_step_outputs.append(output) + if use_delayed_sampling and self.is_driver_worker: + self._patch_prev_output() + output = self._pad_to_max_num_seqs( + output.sampled_token_ids, DUMMY_TOKEN_ID) + self.cached_step_outputs.append(output) + self.cached_step_inputs.append(model_input) htorch.core.mark_step() + if model_input.async_callback is not None: + model_input.async_callback() if i < num_steps - 1: if i == 0: if model_input.async_callback is not None: @@ -2241,11 +2100,30 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) if num_steps == 1: + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + assert model_input.sampling_metadata is not None + if model_input.is_prompt: + output.prefill_hidden_states = hidden_states + output.hidden_states = hidden_states + if use_delayed_sampling: + if self.is_driver_worker: + return [fake_output] + else: + return [] + return [output] if self.is_driver_worker else [] else: return [] return output if type(output) is list else [output] + def _delayed_sampler_outputs(self, model_input): + next_token_ids = [[DUMMY_TOKEN_ID]] * len( + model_input.sampling_metadata.seq_groups) + sampler_output = self._make_decode_output( + next_token_ids, model_input.sampling_metadata.seq_groups) + return sampler_output + def _decode_sampler_outputs(self, model_input): use_async_out_proc = model_input.async_callback is not None sampler_outputs = [] @@ -2312,3 +2190,32 @@ class 
HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): def __del__(self): self.shutdown_inc() + + def _patch_prev_output(self): + assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \ + f'''Inputs and outputs are out of sync! + {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}''' + if len(self.cached_step_inputs) == 0: + return + model_input = self.cached_step_inputs.pop(0) + delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze( + -1).tolist() + ctx = model_input.async_callback.keywords["ctx"] # type: ignore + # If there's no output to patch with, which is usually the case when + # we're starting a new request after all requests are completed. + if len(ctx.output_queue) == 0: + return + assert len( + ctx.output_queue) == 1, 'There should be exactly 1 output waiting!' + output_data = ctx.output_queue[0] + assert len(output_data.outputs) == 1 + for fake_out, real_out in zip(output_data.outputs[0], delayed_output): + fake_out.samples[0].output_token = real_out + for sg, real_out in zip(output_data.seq_group_metadata_list, + delayed_output): + assert len(sg.seq_data) == 1 + seq_data = list(sg.seq_data.values())[0] + # This is a hack. 
Assigning output_token_ids triggers + # a cache recomputation and we only need to update the last token + seq_data.output_token_ids_array[-1] = real_out + seq_data._cached_all_token_ids[-1] = real_out diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index ccb175d88fd3c38dafe2bde64e485ce47a098958..8d7d5d7adc1058b86fb62f8ce3f41bcdd5f37a14 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -245,6 +245,7 @@ class HPUWorker(LocalOrDistributedWorkerBase): cache_block_size) num_hpu_blocks = max(num_hpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks if self.model_runner.lora_manager: self.model_runner.remove_all_loras() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4c8ebe32bf24266f3bbc492d8924e965c64a306a..de124b4eca5463447253d2d5585bae23e9eae1dc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -24,7 +24,8 @@ from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_kv_transfer_group, get_pp_group +from vllm.distributed import get_pp_group +from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, graph_capture) from vllm.forward_context import get_forward_context, set_forward_context @@ -35,7 +36,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader import 
get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import supports_lora, supports_multimodal @@ -457,7 +458,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.enable_lora = self.runner.lora_config is not None self.enable_prompt_adapter = (self.runner.prompt_adapter_config is not None) - self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper # Attention metadata inputs. if self.attn_backend is not None: @@ -678,23 +678,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, seq_group_metadata: SequenceGroupMetadata): """If multi-modal data is given, add it to the input.""" - # NOTE: mm_data only includes the subset of multi-modal items that + # NOTE: mm_kwargs only includes the subset of multi-modal items that # intersect with the current prefill positions. positions = inter_data.input_positions[0] - mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( seq_group_metadata, range(positions[0], positions[0] + len(positions))) - if not mm_data: + if not mm_kwargs: return - if self.runner.mm_registry.has_processor(self.runner.model_config): - mm_kwargs = mm_data - else: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - seq_group_metadata.mm_processor_kwargs, - ) - inter_data.multi_modal_kwargs = mm_kwargs inter_data.multi_modal_placeholder_maps = placeholder_maps @@ -702,11 +694,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): if self.runner.model_config.uses_mrope: image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) - assert image_grid_thw is not None or video_grid_thw is not None, ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 
'video_grid_thw'.") + audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", + None) + assert ( + image_grid_thw is not None or video_grid_thw is not None + or audio_feature_lengths is not None), ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw' or " + "'audio_feature_lengths'.") second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) hf_config = self.runner.model_config.hf_config inter_data.mrope_input_positions = [None] * inter_data.n_seqs @@ -724,6 +722,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): second_per_grid_ts=second_per_grid_ts, context_len=inter_data.context_lens[seq_idx], seq_len=inter_data.seq_lens[seq_idx], + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) seq_data.mrope_position_delta = mrope_position_delta @@ -1080,15 +1080,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): # Multi-modal data support self.input_registry = input_registry self.mm_registry = mm_registry - self.multi_modal_input_mapper = mm_registry \ - .create_input_mapper(model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) # Lazy initialization self.model: nn.Module # Set after load_model # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None + self.sampler = get_sampler() set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -1128,14 +1126,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): logger.warning( "Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model.") - # It's necessary to distinguish between the - # max_position_embeddings of VLMs and LLMs. 
- if hasattr(self.model.config, "max_position_embeddings"): - max_pos_embeddings = ( - self.model.config.max_position_embeddings) - else: - max_pos_embeddings = ( - self.model.config.text_config.max_position_embeddings) + + # Use get_text_config() in case of multimodal models + text_config = self.model_config.hf_config.get_text_config() self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, @@ -1145,7 +1138,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): self.device, self.model.embedding_modules, self.model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, + max_position_embeddings=text_config. + max_position_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) time_after_load = time.perf_counter() @@ -1329,8 +1323,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): dummy_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) + seq_len, + self.mm_registry) seq = SequenceGroupMetadata( request_id=str(group_id), @@ -1832,7 +1826,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): model_input.async_callback() # Sample the next token. 
- output: SamplerOutput = self.model.sample( + output: SamplerOutput = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 7ddf382079c624c2b3642e6d87880790a15c5343..a6f5ec825635bbe025f5e70f4452f1c2db20d591 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -488,8 +488,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): device="cpu", pin_memory=True) - self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( - True) + self._base_model_runner.sampler.include_gpu_probs_tensor = True if frozen_model_input.sampling_metadata: frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( True) diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index f2093fc42ad16d26b3a194a6b39bc463efc9637a..e046ebc449deeb8dc4afee06140ccde428f92acc 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -15,8 +15,7 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs) +from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -69,11 +68,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): self.device = self.device_config.device self.pin_memory = is_pin_memory_available() - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.multi_modal_input_mapper = self.mm_registry \ - 
.create_input_mapper(self.model_config) - # Lazy initialization. self.model: nn.Module # initialize after load_model. @@ -149,16 +143,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): assert len(block_table) == 1 input_block_ids.append(block_table[0]) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - if self.mm_registry.has_processor(self.model_config): - mm_kwargs = mm_data - else: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - seq_group_metadata.mm_processor_kwargs, - ) - + mm_kwargs = seq_group_metadata.multi_modal_data + if mm_kwargs: multi_modal_kwargs_list.append(mm_kwargs) max_seq_len = max(seq_lens) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index aacd239f1807739ca4ae65eb829f43cb572f4605..f4499e3894d0d8654e52b38c5f39636675d4596e 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -163,8 +163,8 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): usable_memory_size = int(total_memory_size * self.cache_config.gpu_memory_utilization) tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) - dtype_btyes = get_dtype_size(self.cache_dtype) - block_size_bytes = (dtype_btyes * self.cache_config.block_size * + dtype_bytes = get_dtype_size(self.cache_dtype) + block_size_bytes = (dtype_bytes * self.cache_config.block_size * num_layers * 2 * head_size * num_kv_heads) num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. 
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 76af09e85a2dcaac73ad49a45d961aef68e2e60f..bcb70237ce88d5eb522250da756cc340ec3fa4d4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -10,10 +10,10 @@ import torch.distributed import vllm.envs as envs from vllm.config import VllmConfig from vllm.device_allocator.cumem import CuMemAllocator -from vllm.distributed import (ensure_kv_transfer_initialized, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed @@ -95,6 +95,9 @@ class Worker(LocalOrDistributedWorkerBase): self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} + # Buffers saved before sleep + self._sleep_saved_buffers: Dict[str, torch.Tensor] = {} + # Torch profiler. 
Enabled and configured through env vars: # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: @@ -124,6 +127,15 @@ class Worker(LocalOrDistributedWorkerBase): def sleep(self, level: int = 1) -> None: free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + + # Save the buffers before level 2 sleep + if level == 2: + model = self.model_runner.model + self._sleep_saved_buffers = { + name: buffer.cpu().clone() + for name, buffer in model.named_buffers() + } + allocator = CuMemAllocator.get_instance() allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) free_bytes_after_sleep, total = torch.cuda.mem_get_info() @@ -139,6 +151,14 @@ class Worker(LocalOrDistributedWorkerBase): allocator = CuMemAllocator.get_instance() allocator.wake_up(tags=tags) + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.model_runner.model + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 9d49b4385dcaa98124e48b7f3b9db395aab5b78e..7042b575aa78750354e3f5681da59cee3980ecba 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -18,7 +18,7 @@ from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadataCache -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, 
MultiModalPlaceholderMap, @@ -188,20 +188,11 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): input_positions.extend(list(positions_range)) if seq_group_metadata.multi_modal_data: - # NOTE: mm_data only includes the subset of multi-modal items + # NOTE: mm_kwargs only includes the subset of multi-modal items # that intersect with the current prefill positions. - mm_data, placeholder_maps = MultiModalPlaceholderMap \ + mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - if self.runner.mm_registry.has_processor( - self.runner.model_config): - mm_kwargs = mm_data - else: - mm_kwargs = self.runner.multi_modal_input_mapper( - mm_data, - seq_group_metadata.mm_processor_kwargs, - ) - multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): @@ -404,12 +395,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): # Multi-modal data support self.input_registry = input_registry self.mm_registry = mm_registry - self.multi_modal_input_mapper = mm_registry \ - .create_input_mapper(model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) # Lazy initialization. self.model: nn.Module # Set after init_Model + self.sampler = get_sampler() self.sampling_metadata_cache: SamplingMetadataCache = \ SamplingMetadataCache() \ @@ -596,7 +585,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): model_input.async_callback() # Sample the next token. - output: SamplerOutput = self.model.sample( + output: SamplerOutput = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, )