From e6ae4b1be1c3dca1c25d7a12058dbb1fd900caa2 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Mon, 16 Mar 2026 17:05:51 -0400 Subject: [PATCH 001/223] [compile] Enable mega aot artifact for torch 2.12+. (#37198) Signed-off-by: zhxchen17 --- vllm/compilation/caching.py | 12 ++++-------- vllm/envs.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 00fb95921..2b667344f 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -307,13 +307,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] num_submods = len(submod_names) num_artifacts = standalone_compile_artifacts.num_artifacts() - logger.info( - "reconstructing serializable fn from standalone compile " - "artifacts. num_artifacts=%d num_submods=%d", - num_artifacts, - num_submods, - ) - with functorch_ctx: fn = reconstruct_serializable_fn_from_mega_artifact( state=state, @@ -324,7 +317,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ) logger.info( - "reconstructed serializable fn from standalone compile artifacts" + "reconstructed serializable fn from standalone compile " + "artifacts. 
num_artifacts=%d num_submods=%d", + num_artifacts, + num_submods, ) return fn diff --git a/vllm/envs.py b/vllm/envs.py index caa2fb38a..d6240df36 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -296,6 +296,16 @@ def use_aot_compile() -> bool: ) +def use_mega_aot_artifact(): + from vllm.utils.torch_utils import is_torch_equal_or_newer + + default_value = ( + "1" if is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile() else "0" + ) + + return os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT", default_value) == "1" + + def env_with_choices( env_name: str, default: str | None, @@ -616,10 +626,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # Enable loading compiled models directly from cached standalone compile artifacts # without re-splitting graph modules. This reduces overhead during model # loading by using reconstruct_serializable_fn_from_mega_artifact. - "VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get( - "VLLM_USE_MEGA_AOT_ARTIFACT", "0" - ) - == "1", + "VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact, # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), -- GitLab From c0f011918da543f1323833c8ee2bfcac99e0452a Mon Sep 17 00:00:00 2001 From: Krish Gupta Date: Tue, 17 Mar 2026 02:41:33 +0530 Subject: [PATCH 002/223] [Bugfix] opcheck false mutation error in rms_norm_per_block_quant (#36688) (#36779) Signed-off-by: Krish Gupta --- ...fused_layernorm_dynamic_per_token_quant.cu | 9 +++++++++ .../core/test_fused_quant_layernorm.py | 19 ++++++++++--------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index e178f2526..723ca8142 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ 
b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -286,6 +286,15 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, "Outer scale stride must be 1 when scales are not transposed"); } + int64_t hidden_size = input.size(-1); + TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0, + "hidden_size must be a positive multiple of group_size"); + int64_t num_tokens = input.numel() / hidden_size; + int64_t num_groups = hidden_size / group_size; + TORCH_CHECK(scales.numel() >= num_tokens * num_groups, + "scales buffer too small: need ", num_tokens * num_groups, + " elements, got ", scales.numel()); + rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size, var_epsilon, scale_ub, residual, is_scale_transposed); diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index fe06605af..f9c01f4f1 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -280,21 +280,22 @@ def test_rms_norm( assert torch.allclose(ref_residual, ops_residual) output = torch.empty(x.shape, dtype=quant_dtype, device=x.device) - scales = torch.empty( - (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 - ) - if group_size is None: + scales = torch.empty( + (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 + ) opcheck( torch.ops._C.rms_norm_dynamic_per_token_quant, (output, x, layer.weight, scales, 1e-5, scale_ub, residual), ) else: - # TODO(luka/eliza) opcheck is broken? - # Somehow the cloned args are getting mutated in-place, - # which causes the opcheck to fail. 
- # https://github.com/vllm-project/vllm/issues/36688 - return + assert hidden_size % group_size[1] == 0 + num_groups = hidden_size // group_size[1] + scales = torch.empty( + (num_groups, num_tokens), + device=x.device, + dtype=torch.float32, + ).transpose(0, 1) opcheck( torch.ops._C.rms_norm_per_block_quant, ( -- GitLab From fd4d96302a2999a8d773b1b331951d232e3f5e05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elvir=20Crn=C4=8Devi=C4=87?= Date: Mon, 16 Mar 2026 23:03:54 +0100 Subject: [PATCH 003/223] Fix eplb nvfp4 experts hook (#37217) Signed-off-by: Elvir Crncevic Signed-off-by: Elvir Crncevic Co-authored-by: Tyler Michael Smith Co-authored-by: Claude Opus 4.6 --- .../layers/fused_moe/cutlass_moe.py | 7 ++++++ .../fused_moe/experts/trtllm_nvfp4_moe.py | 23 +++++++++++++++---- .../fused_moe/flashinfer_cutedsl_moe.py | 4 ++++ .../fused_moe/flashinfer_cutlass_moe.py | 5 ++++ vllm/model_executor/layers/fused_moe/layer.py | 18 +++++++++------ .../layers/fused_moe/modular_kernel.py | 3 +++ .../layers/fused_moe/oracle/nvfp4.py | 10 ++++---- .../compressed_tensors_moe.py | 1 + .../layers/quantization/modelopt.py | 1 + .../quantization/utils/flashinfer_fp4_moe.py | 10 -------- 10 files changed, 57 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 51a97e0a2..534cab1b8 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -659,6 +659,13 @@ def run_cutlass_moe_fp4( class CutlassExpertsFp4(mk.FusedMoEExpertsModular): """CUTLASS FP4 fused MoE expert implementation.""" + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fuse activation scales into w_scale_2 in-place so that + # g1/g2_alphas (which reference the same tensor) stay in sync + # when EPLB rearranges the parameter. 
+ layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale) + layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale) + @property def expects_unquantized_inputs(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index 174c581b3..87b1eb9fd 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -56,10 +56,25 @@ class TrtLlmNvFp4ExpertsBase: # g1_scale_c = a13_scale * w13_scale_2 / a2_scale self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale else: - self.g1_scale_c = ( - torch.ones_like(self.quant_config.a1_gscale) - * self.quant_config.a2_gscale - ) + self.g1_scale_c = self.quant_config.a2_gscale.clone() + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale) + layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale) + # Recompute g1_scale_c since g1_alphas was just fused in-place. + # Register as a layer parameter so EPLB rearranges it alongside + # other expert weights. 
+ assert self.quant_config.g1_alphas is not None + assert self.quant_config.a2_gscale is not None + if self.moe_config.is_act_and_mul: + g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale + else: + g1_scale_c = self.quant_config.a2_gscale.clone() + layer.register_parameter( + "g1_scale_c", + torch.nn.Parameter(g1_scale_c, requires_grad=False), + ) + self.g1_scale_c = layer.g1_scale_c @staticmethod def _supports_current_device() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index fb8a18ef3..5805a4dd5 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -49,6 +49,10 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): ) self.out_dtype = moe_config.in_dtype + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale) + layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale) + @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.BatchedExperts diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index e58d52eee..91f7a83f6 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -61,6 +61,11 @@ def is_valid_flashinfer_cutlass_fused_moe( class FlashInferExperts(mk.FusedMoEExpertsModular): + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if self.quant_config.use_nvfp4_w4a4: + layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale) + layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale) + def __init__( self, moe_config: mk.FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/layer.py 
b/vllm/model_executor/layers/fused_moe/layer.py index 7135cbbd2..75283b9bb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1421,19 +1421,23 @@ class FusedMoE(CustomOp): weights = list(self.named_parameters()) weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights] + # `w13_input_scale` and `w2_input_scale` are global per-tensor + # activation scales shared across all experts (e.g. NVFP4). + # They are broadcast views (stride 0) from .expand() and are + # not actual expert weights, so exclude them from EPLB. + NON_EXPERT_WEIGHTS = { + "e_score_correction_bias", + "w13_input_scale", + "w2_input_scale", + } + assert all( weight.is_contiguous() for name, weight in weights if not (name.startswith("_shared_experts.") or name.startswith("_gate.")) + and name not in NON_EXPERT_WEIGHTS ) - # Filter out the non-expert weights. - # `e_score_correction_bias` is a bias for each logical expert, - # with shape (num_logical_experts,), not an expert weight. 
- NON_EXPERT_WEIGHTS = { - "e_score_correction_bias", - } - return [ weight.view(self.local_num_experts, -1) for name, weight in weights diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 7100c87c9..a6b498834 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -489,6 +489,9 @@ class FusedMoEExperts(ABC): self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # noqa: B027 + pass + @staticmethod def is_monolithic() -> bool: raise NotImplementedError("Implemented by subclasses.") diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index b06cf49cf..8a224cb39 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -374,11 +374,13 @@ def make_nvfp4_moe_quant_config( w2_scale=w2_scale, ) - g1_alphas = a13_scale * w13_scale_2 - g2_alphas = a2_scale * w2_scale_2 + # Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas. + # The expert's process_weights_after_loading will fuse activation + # scales in-place. Since the quant config references the same tensor + # as the registered parameter, EPLB rearrangement stays in sync. 
return nvfp4_moe_quant_config( - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, + g1_alphas=w13_scale_2, + g2_alphas=w2_scale_2, a1_gscale=(1.0 / a13_scale), a2_gscale=(1.0 / a2_scale), w1_scale=w13_scale, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f35a4c0b9..29115fbbc 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -570,6 +570,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): shared_experts=layer.shared_experts, routing_tables=layer._maybe_init_expert_routing_tables(), ) + self.moe_kernel.fused_experts.process_weights_after_loading(layer) def maybe_make_prepare_finalize( self, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 977612313..640580da6 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1394,6 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): shared_experts=layer.shared_experts, routing_tables=layer._maybe_init_expert_routing_tables(), ) + self.moe_kernel.fused_experts.process_weights_after_loading(layer) def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: return make_nvfp4_moe_quant_config( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 42677a592..66300ceae 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -267,16 +267,6 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( num_experts=w13.size(0), is_gated_activation=is_gated, ) - - # We do 
not need to make this a parameter, because - # it is not used during the weight (re)-loading process. - if is_gated: - layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale - else: - layer.g1_scale_c = torch.ones_like(a13_scale) / a2_scale - layer.a1_gscale = 1.0 / a13_scale - layer.g1_alphas = a13_scale * w13_scale_2 - layer.g2_alphas = a2_scale * w2_scale_2 else: # Swizzle the block scales for other FI NVFP4 MoE kernels. w13_scale = swizzle_blockscale(w13_scale) -- GitLab From e5b807607c8493155e6eccd665772d4c19b2114e Mon Sep 17 00:00:00 2001 From: EdalatiAli Date: Mon, 16 Mar 2026 18:07:39 -0400 Subject: [PATCH 004/223] [Quant][Feature] Support online MXFP8 quantization for MoE and dense models (#35448) Signed-off-by: EdalatiAli --- tests/models/quantization/test_mxfp8.py | 104 +++++ .../fused_moe/experts/trtllm_fp8_moe.py | 111 ++++-- .../layers/fused_moe/oracle/fp8.py | 17 +- .../layers/fused_moe/oracle/mxfp8.py | 89 +++-- vllm/model_executor/layers/fused_moe/utils.py | 2 +- .../layers/quantization/__init__.py | 3 + .../layers/quantization/modelopt.py | 9 +- .../layers/quantization/mxfp8.py | 354 ++++++++++++++++++ .../quantization/utils/flashinfer_utils.py | 104 ++++- .../layers/quantization/utils/quant_utils.py | 6 + 10 files changed, 745 insertions(+), 54 deletions(-) create mode 100644 tests/models/quantization/test_mxfp8.py create mode 100644 vllm/model_executor/layers/quantization/mxfp8.py diff --git a/tests/models/quantization/test_mxfp8.py b/tests/models/quantization/test_mxfp8.py new file mode 100644 index 000000000..2cb0f2008 --- /dev/null +++ b/tests/models/quantization/test_mxfp8.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""E2E tests for online MXFP8 quantization. + +Loads a BF16 model with ``--quantization mxfp8`` (online quantization) and +compares log-probabilities against the same model served in BF16 without +quantization. 
This exercises the full pipeline: config parsing, +``Mxfp8OnlineLinearMethod``, ``Mxfp8OnlineMoEMethod``, weight loading, +online quantization / shuffling, and inference through ``apply_monolithic``. + +Layer skipping (``modules_to_not_convert``) is configured in the model's +``config.json`` under ``quantization_config`` and is not tested here. + +``example_prompts`` is a pytest fixture (from conftest.py) that loads 8 +diverse prompts from ``tests/prompts/example.txt``. +""" + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +from ..utils import check_logprobs_close + +# A small MoE model that fits on a single GPU and has both linear + MoE layers. +MOE_MODEL = "Qwen/Qwen3-30B-A3B" +# A small dense model (no MoE) to validate the linear-only path. +DENSE_MODEL = "Qwen/Qwen3-0.6B" + +MAX_MODEL_LEN = 1024 +MAX_TOKENS = 4 +NUM_LOG_PROBS = 8 + + +@pytest.mark.skipif( + not is_quant_method_supported("mxfp8"), + reason="mxfp8 is not supported on this GPU type (requires sm_100+).", +) +@pytest.mark.quant_model +@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"]) +def test_mxfp8_logprobs( + vllm_runner, + example_prompts, + model: str, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Compare BF16 baseline logprobs against online MXFP8-quantized model. + + Runs the same model twice -- once in BF16 (baseline) and once with + online MXFP8 quantization -- then checks that the top log-probabilities + are close. Only 4 tokens are generated to keep the test fast while + still catching numerical divergence. 
+ """ + with monkeypatch.context() as m: + m.setenv("TOKENIZERS_PARALLELISM", "true") + + with vllm_runner( + model, + max_model_len=MAX_MODEL_LEN, + enforce_eager=True, + ) as vllm_model: + baseline_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, MAX_TOKENS, NUM_LOG_PROBS + ) + + with vllm_runner( + model, + max_model_len=MAX_MODEL_LEN, + enforce_eager=True, + quantization="mxfp8", + ) as vllm_model: + test_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, MAX_TOKENS, NUM_LOG_PROBS + ) + + check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=test_outputs, + name_0="bf16", + name_1="mxfp8", + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("mxfp8"), + reason="mxfp8 is not supported on this GPU type (requires sm_100+).", +) +@pytest.mark.quant_model +@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"]) +def test_mxfp8_generation(vllm_runner, model: str) -> None: + """Smoke test: verify online MXFP8 model generates coherent text.""" + prompt = "1 2 3 4 5" + with vllm_runner( + model, + enforce_eager=True, + quantization="mxfp8", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + output = vllm_model.generate_greedy([prompt], max_tokens=5) + + generated = output[0][1] + assert len(generated) > len(prompt), ( + f"MXFP8 model produced no new tokens. 
Output: {generated!r}" + ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 1c86702e9..74096ef6e 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8Dynamic128Sym, kFp8Static128BlockSym, kFp8StaticTensorSym, + kMxfp8Dynamic, + kMxfp8Static, ) from vllm.platforms import current_platform @@ -67,11 +69,54 @@ class TrtLlmFp8ExpertsBase: """Does not support non-gated MoE (i.e. Nanotron-3-Nano).""" return True + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Supports Fp8 per-tensor, Fp8 block, and MXFP8.""" + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + (kMxfp8Static, kMxfp8Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + @staticmethod def _supports_activation(activation: MoEActivation) -> bool: """Supports only SiLU and RELU^2 non-gated activation.""" return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Monolithic kernels need to express router support.""" + # NOTE(dbari): TopK routing could also be enabled, but need to validate models + # NOTE(dbari): Default is not implemented and should not be enabled until it is + if (weight_key, activation_key) in [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kMxfp8Static, kMxfp8Dynamic), + ]: + # NOTE(rob): potentially allow others here. This is a conservative list. 
+ return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): + # NOTE(dbari): as above, potentially allow others here. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Llama4, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + else: + raise ValueError("Unsupported quantization scheme.") + @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: """Monolithic kernel so only use with naive DP/EP and TP.""" @@ -113,9 +158,10 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - """Supports Fp8 block.""" + """Supports Fp8 block and MXFP8.""" SUPPORTED_W_A = [ (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kMxfp8Static, kMxfp8Dynamic), ] return (weight_key, activation_key) in SUPPORTED_W_A @@ -159,6 +205,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): apply_router_weight_on_input: bool, ): import flashinfer + from flashinfer.fused_moe import Fp8QuantizationType # Pack topk_ids and topk_weights into single tensor # Format: (expert_id << 16) | (weight_bf16.view(int16)) @@ -175,6 +222,16 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): assert a1q_scale is not None + is_mxfp8 = self.quant_config.block_shape == [1, 32] + if is_mxfp8: + fp8_quant_type = Fp8QuantizationType.MxFp8 + use_shuffled_weight = True + hidden_states_scale = a1q_scale + else: + fp8_quant_type = Fp8QuantizationType.DeepSeekFp8 + use_shuffled_weight = False + hidden_states_scale = a1q_scale.t().contiguous() + # `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the # output tensor in-place so we need to manually copy the result to the # output tensor @@ -183,7 
+240,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): topk_ids=packed_topk_ids, routing_bias=None, hidden_states=hidden_states, - hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr] + hidden_states_scale=hidden_states_scale, gemm1_weights=w1, gemm1_weights_scale=self.quant_config.w1_scale, gemm2_weights=w2, @@ -197,8 +254,9 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): local_num_experts=self.local_num_experts, routed_scaling_factor=None, routing_method_type=1, - use_shuffled_weight=False, + use_shuffled_weight=use_shuffled_weight, weight_layout=0, + fp8_quantization_type=fp8_quant_type, # output=output, ) output.copy_(result) @@ -240,10 +298,11 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - """Supports Fp8 per-tensor and Fp8 block.""" + """Supports Fp8 per-tensor, Fp8 block, and MXFP8.""" SUPPORTED_W_A = [ (kFp8Static128BlockSym, kFp8Dynamic128Sym), (kFp8StaticTensorSym, kFp8StaticTensorSym), + (kMxfp8Static, kMxfp8Dynamic), ] return (weight_key, activation_key) in SUPPORTED_W_A @@ -256,7 +315,10 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit """Monolithic kernels need to express router support.""" # NOTE(dbari): TopK routing could also be enabled, but need to validate models # NOTE(dbari): Default is not implemented and should not be enabled until it is - if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym): + if (weight_key, activation_key) in [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kMxfp8Static, kMxfp8Dynamic), + ]: # NOTE(rob): potentially allow others here. This is a conservative list. 
return routing_method in [ RoutingMethodType.DeepSeekV3, @@ -274,7 +336,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit else: raise ValueError("Unsupported quantization scheme.") - def _apply_per_block( + def _apply_block_scale( self, hidden_states: torch.Tensor, w1: torch.Tensor, @@ -291,32 +353,38 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit routed_scaling_factor: float | None = None, topk_group: int | None = None, ) -> torch.Tensor: - # Delay import for non-CUDA. import flashinfer + from flashinfer.fused_moe import Fp8QuantizationType assert not apply_router_weight_on_input assert activation == MoEActivation.SILU - - if self.routing_method_type == RoutingMethodType.DeepSeekV3: - router_logits = router_logits.to(torch.float32) - assert self.topk <= global_num_experts assert self.topk <= 10 assert global_num_experts % 4 == 0 - assert self.quant_config.block_shape == [128, 128] - # Routing kernel expects #experts <= #threads 512 + assert self.quant_config.block_shape in [[128, 128], [1, 32]] + # Kernel expects #experts <= #threads 512 assert global_num_experts <= 512 - - # Kernel requires transposed hidden state scales # TODO: fuse into the quant kernel. 
assert a1q_scale is not None - a1q_scale_t = a1q_scale.t().contiguous() + + if self.routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + + is_mxfp8 = self.quant_config.block_shape == [1, 32] + if is_mxfp8: + fp8_quant_type = Fp8QuantizationType.MxFp8 + use_shuffled_weight = True + hidden_states_scale = a1q_scale + else: + fp8_quant_type = Fp8QuantizationType.DeepSeekFp8 + use_shuffled_weight = False + hidden_states_scale = a1q_scale.t().contiguous() return flashinfer.fused_moe.trtllm_fp8_block_scale_moe( routing_logits=router_logits, routing_bias=e_score_correction_bias, hidden_states=hidden_states, - hidden_states_scale=a1q_scale_t, + hidden_states_scale=hidden_states_scale, gemm1_weights=w1, gemm1_weights_scale=self.quant_config.w1_scale, gemm2_weights=w2, @@ -330,7 +398,8 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit local_num_experts=self.local_num_experts, routed_scaling_factor=routed_scaling_factor, routing_method_type=self.routing_method_type, - use_shuffled_weight=False, + use_shuffled_weight=use_shuffled_weight, + fp8_quantization_type=fp8_quant_type, ) def _apply_per_tensor( @@ -409,7 +478,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit topk_group: int | None = None, ) -> torch.Tensor: if self.quant_config.block_shape is not None: - return self._apply_per_block( + return self._apply_block_scale( hidden_states, w1, w2, @@ -441,6 +510,6 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit ) else: raise NotImplementedError( - "Only per-block and per-tensor quantization are supported in " - f"{self.__class__.__name__}." + "Only per-block, per-tensor, and MXFP8 quantization are " + f"supported in {self.__class__.__name__}." 
) diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 48ca03f66..a63c02663 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -444,7 +444,7 @@ def convert_to_fp8_moe_kernel_format( Fp8MoeBackend.FLASHINFER_CUTLASS, Fp8MoeBackend.FLASHINFER_TRTLLM, ]: - w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi( + w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_fi( layer=layer, w13=w13, w2=w2, @@ -512,6 +512,21 @@ def make_fp8_moe_quant_config( g1_alphas=(w1_scale * a1_scale).squeeze(), g2_alphas=(w2_scale * a2_scale).squeeze(), ) + # MXFP8 uses "mxfp8" quant_dtype so the prepare step dispatches to + # _mxfp8_e4m3_quantize rather than standard FP8 block quantization. + # Non-swizzled layout is required since the TRTLLM kernel expects + # scales in (num_tokens, hidden_dim // 32) format. + if block_shape == [1, 32]: + return FusedMoEQuantConfig.make( + "mxfp8", + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + is_nvfp4_scale_swizzled=False, + ) + # All other backends use normal config. 
return fp8_w8a8_moe_quant_config( w1_scale=w1_scale, diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py index 49406ba93..ed3af4b5a 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py @@ -1,44 +1,87 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from enum import Enum +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + Fp8MoeBackend, + backend_to_kernel_cls, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kMxfp8Dynamic, + kMxfp8Static, +) logger = init_logger(__name__) +_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset( + { + Fp8MoeBackend.FLASHINFER_TRTLLM, + } +) -class MxFp8MoeBackend(Enum): - FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" +_BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = { + "flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM, +} + + +def _select_kernel_cls( + backend: Fp8MoeBackend, + config: FusedMoEConfig, +) -> type[mk.FusedMoEExperts]: + """Select the first supported expert class for the MXFP8 config.""" + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if config.moe_parallel_config.use_batched_activation_format + else mk.FusedMoEActivationFormat.Standard + ) + last_reason: str | None = None + for cls in backend_to_kernel_cls(backend): + supported, reason = cls.is_supported_config( + cls, + config, + kMxfp8Static, + kMxfp8Dynamic, + activation_format, + ) + if supported: + return cls + last_reason = reason + raise ValueError( + f"No supported MXFP8 expert class for {backend.value}: {last_reason}" + ) def select_mxfp8_moe_backend( config: FusedMoEConfig, -) -> MxFp8MoeBackend: +) -> tuple[Fp8MoeBackend, 
type[mk.FusedMoEExperts]]: + """Select the MXFP8 MoE backend and the best expert class. + + Returns: + A tuple of (fp8_backend, experts_cls). + """ if config.is_lora_enabled: raise NotImplementedError("LoRA is not supported for MXFP8 MoE.") - AVAILABLE_BACKENDS = [ - MxFp8MoeBackend.FLASHINFER_TRTLLM, - ] - runner_backend = config.moe_backend if runner_backend != "auto": - mapping = { - "flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM, - } - if backend := mapping.get(runner_backend): - logger.info_once( - "Using '%s' MxFp8 MoE backend (user-requested).", - backend.value, + backend = _BACKEND_NAME_MAP.get(runner_backend) + if backend is None: + raise ValueError( + f"moe_backend='{runner_backend}' is not supported for " + f"MXFP8 MoE. Expected one of " + f"{list(_BACKEND_NAME_MAP.keys())}." ) - return backend - raise ValueError( - f"moe_backend='{runner_backend}' is not supported for MXFP8 MoE. " - f"Expected one of {list(mapping.keys())}." + logger.info_once( + "Using '%s' MxFp8 MoE backend (user-requested).", + backend.value, ) + return backend, _select_kernel_cls(backend, config) + + # Auto-select: pick the first supported backend. + for backend in _SUPPORTED_BACKENDS: + logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value) + return backend, _select_kernel_cls(backend, config) - # Auto-select: only one backend available for now. 
- backend = AVAILABLE_BACKENDS[0] - logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value) - return backend + raise ValueError("No MXFP8 MoE backends available.") diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 019e408c1..4adb7f1cf 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -199,7 +199,7 @@ def _mxfp8_e4m3_quantize( ) -> tuple[torch.Tensor, torch.Tensor]: assert A_scale is None assert not per_act_token_quant - assert block_shape is None + assert block_shape is None or block_shape == [1, 32] return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 2fb54e775..e08a6456a 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -31,6 +31,7 @@ QuantizationMethods = Literal[ "torchao", "inc", "mxfp4", + "mxfp8", "petit_nvfp4", "cpu_awq", ] @@ -129,6 +130,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ) from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config + from .mxfp8 import Mxfp8Config from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config from .torchao import TorchAOConfig @@ -156,6 +158,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "auto-round": INCConfig, "inc": INCConfig, "mxfp4": Mxfp4Config, + "mxfp8": Mxfp8Config, "petit_nvfp4": PetitNvFp4Config, "cpu_awq": CPUAWQConfig, } diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 640580da6..78644f74d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -25,13 +25,13 @@ from vllm.model_executor.layers.fused_moe.layer import ( 
FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import ( - MxFp8MoeBackend, select_mxfp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( @@ -1712,8 +1712,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): self.quant_config = quant_config assert self.quant_config.is_checkpoint_mxfp8_serialized - # Select MXFP8 MoE backend - self.mxfp8_backend = select_mxfp8_moe_backend(self.moe) + self.mxfp8_backend, _ = select_mxfp8_moe_backend(self.moe) def create_weights( self, @@ -1943,7 +1942,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): @property def is_monolithic(self) -> bool: - return self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM + return self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM def apply_monolithic( self, @@ -1956,7 +1955,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): Fp8QuantizationType, ) - assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM + assert self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM if layer.enable_eplb: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/mxfp8.py b/vllm/model_executor/layers/quantization/mxfp8.py new file mode 100644 index 000000000..5b4564bea --- /dev/null +++ b/vllm/model_executor/layers/quantization/mxfp8.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Online MXFP8 (microscaling FP8, block-32) quantization config and methods.""" + +from typing import Any + +import torch +from torch.nn import Module + +from vllm.logger import init_logger +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEMethodBase, +) +from 
vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod +from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import ( + select_mxfp8_moe_backend, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.fp8 import ( + Fp8Config, + Fp8KVCacheMethod, + Fp8OnlineLinearMethod, + Fp8OnlineMoEMethod, + _copy_missing_attrs, +) +from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + MXFP8_BLOCK_SIZE, + Mxfp8LinearBackend, + Mxfp8LinearOp, + mxfp8_e4m3_quantize, + swizzle_mxfp8_scale, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped, +) +from vllm.model_executor.model_loader.weight_utils import ( + initialize_single_dummy_weight, +) +from vllm.model_executor.parameter import ModelWeightParameter +from vllm.model_executor.utils import replace_parameter, set_weight_attrs +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class Mxfp8Config(Fp8Config): + """Config class for online MXFP8 MoE quantization.""" + + def __init__( + self, + activation_scheme: str = "dynamic", + ignored_layers: list[str] | None = None, + ) -> None: + if activation_scheme != "dynamic": + raise ValueError("mxfp8 only supports dynamic activation scheme.") + super().__init__( + is_checkpoint_fp8_serialized=False, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=None, + ) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "mxfp8" + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config": + activation_scheme = cls.get_from_keys_or( + config, ["activation_scheme"], "dynamic" + ) + 
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + if not ignored_layers: + ignored_layers = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None + ) + return cls( + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase | None": + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + skip_with_substr=True, + ): + return UnquantizedLinearMethod() + return Mxfp8OnlineLinearMethod(self) + elif isinstance(layer, FusedMoE): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + skip_with_substr=True, + ): + return UnquantizedFusedMoEMethod(layer.moe_config) + return Mxfp8OnlineMoEMethod(self, layer) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + +class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): + """Online MXFP8 linear method. + Loads bf16/fp16 checkpoints and quantizes weights to MXFP8 (microscaling + FP8 with block-32 scales) during weight loading. + + Args: + quant_config: The MXFP8 quantization config. + """ + + uses_meta_device: bool = True + + def __init__(self, quant_config: "Mxfp8Config"): + self.quant_config = quant_config + self.out_dtype = torch.get_default_dtype() + self.mxfp8_linear = Mxfp8LinearOp(self._select_backend()) + logger.info_once( + "Using %s backend for MXFP8 GEMM", self.mxfp8_linear.backend.value + ) + + @staticmethod + def _select_backend() -> Mxfp8LinearBackend: + try: + from vllm.utils import flashinfer as fi + + _ = fi.mm_mxfp8 + return Mxfp8LinearBackend.FLASHINFER_CUTLASS + except Exception: + logger.warning( + "FlashInfer mm_mxfp8 not available, " + "falling back to MXFP8 emulation backend." 
+ ) + return Mxfp8LinearBackend.EMULATION + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if input_size_per_partition % MXFP8_BLOCK_SIZE != 0: + raise ValueError( + f"MXFP8 requires input_size_per_partition " + f"({input_size_per_partition}) to be divisible by " + f"{MXFP8_BLOCK_SIZE}." + ) + + super().create_weights( + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ) + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + if layer.weight.device == torch.device("meta"): + weight = ModelWeightParameter( + data=torch.empty_like(layer.weight, device=layer._load_device), + input_dim=1, + output_dim=0, + weight_loader=layer.weight.weight_loader, + ) + _copy_missing_attrs(layer.weight, weight) + layer.register_parameter("weight", weight) + initialize_single_dummy_weight(layer.weight) + + weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous()) + + if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS: + N, K = layer.weight.shape[0], layer.weight.shape[1] + weight_scale = swizzle_mxfp8_scale(weight_scale, N, K) + + layer.input_scale = None + replace_parameter(layer, "weight", weight_fp8.data) + replace_parameter(layer, "weight_scale", weight_scale.data) + + layer._already_called_process_weights_after_loading = True + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.mxfp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + bias=bias, + ) + + +class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod): + """MoE method for online MXFP8 (block) 
quantization.""" + + uses_meta_device: bool = True + + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + FusedMoEMethodBase.__init__(self, layer.moe_config) + self.quant_config = quant_config + assert not quant_config.is_checkpoint_fp8_serialized + assert quant_config.activation_scheme == "dynamic" + + self.weight_block_size = [1, MXFP8_BLOCK_SIZE] + self.block_quant = True + self.weight_scale_name = "weight_scale" + + self.fp8_backend, self.experts_cls = select_mxfp8_moe_backend(config=self.moe) + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if ( + hidden_size % MXFP8_BLOCK_SIZE != 0 + or intermediate_size_per_partition % MXFP8_BLOCK_SIZE != 0 + ): + raise ValueError( + "Online MXFP8 MoE requires hidden/intermediate sizes divisible " + f"by {MXFP8_BLOCK_SIZE}." + ) + + super().create_weights( + layer=layer, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size_per_partition=intermediate_size_per_partition, + params_dtype=params_dtype, + **extra_weight_attrs, + ) + + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // MXFP8_BLOCK_SIZE, + dtype=torch.uint8, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // MXFP8_BLOCK_SIZE, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + layer.weight_block_size = [1, MXFP8_BLOCK_SIZE] + + def _quantize_mxfp8_moe_weight( + self, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Batch quantization: bf16/fp16 weights -> 
MXFP8 (fp8 + uint8 scales).""" + num_batches = weight.size(0) + w_quant = [] + w_scales = [] + for i in range(num_batches): + mx_fp8_quant, mx_fp8_scale = mxfp8_e4m3_quantize( + weight[i], is_sf_swizzled_layout=False + ) + w_quant.append(mx_fp8_quant) + w_scales.append(mx_fp8_scale) + + return torch.stack(w_quant), torch.stack(w_scales) + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + if layer.w13_weight.device == torch.device("meta"): + w13_weight = torch.nn.Parameter( + torch.empty_like(layer.w13_weight, device=layer._load_device), + requires_grad=False, + ) + set_weight_attrs( + w13_weight, {"weight_loader": layer.w13_weight.weight_loader} + ) + _copy_missing_attrs(layer.w13_weight, w13_weight) + layer.register_parameter("w13_weight", w13_weight) + initialize_single_dummy_weight(layer.w13_weight) + if layer.w2_weight.device == torch.device("meta"): + w2_weight = torch.nn.Parameter( + torch.empty_like(layer.w2_weight, device=layer._load_device), + requires_grad=False, + ) + set_weight_attrs( + w2_weight, {"weight_loader": layer.w2_weight.weight_loader} + ) + _copy_missing_attrs(layer.w2_weight, w2_weight) + layer.register_parameter("w2_weight", w2_weight) + initialize_single_dummy_weight(layer.w2_weight) + + fp8_dtype = current_platform.fp8_dtype() + w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) + w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + + w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight) + w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight) + + self._setup_kernel( + layer, + w13, + w2, + w13_scale, + w2_scale, + layer.w13_input_scale, + layer.w2_input_scale, + ) + + layer._already_called_process_weights_after_loading = True diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py 
b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 322b3a6e8..271bcf168 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -305,6 +305,81 @@ def align_fp8_moe_weights_for_fi( return padded_w13, padded_w2, padded_intermediate +def _shuffle_mxfp8_moe_weights( + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + is_gated: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Preprocess MXFP8 weights and scales for the FlashInfer TRT-LLM kernel. + + Following flashinfer/tests/moe/test_trtllm_gen_fused_moe.py: + 1. reorder_rows_for_gated_act_gemm (interleave gate/up rows) + 2. shuffle_matrix_a (weight data layout shuffle) + 3. shuffle_matrix_sf_a (scale factor layout shuffle) + """ + from flashinfer import ( + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + ) + + epilogue_tile_m = 128 + num_experts = w13.shape[0] + intermediate_size = w13.shape[1] // 2 + hidden_size = w13.shape[2] + + w13_interleaved: list[torch.Tensor] = [] + w13_scale_interleaved: list[torch.Tensor] = [] + for i in range(num_experts): + if is_gated: + w13_interleaved.append( + reorder_rows_for_gated_act_gemm( + w13[i].reshape(2 * intermediate_size, -1) + ) + ) + w13_scale_interleaved.append( + reorder_rows_for_gated_act_gemm( + w13_scale[i].reshape(2 * intermediate_size, -1) + ) + ) + else: + w13_interleaved.append(w13[i]) + w13_scale_interleaved.append(w13_scale[i]) + + w13_shuffled: list[torch.Tensor] = [] + w2_shuffled: list[torch.Tensor] = [] + w13_scale_shuffled: list[torch.Tensor] = [] + w2_scale_shuffled: list[torch.Tensor] = [] + for i in range(num_experts): + w13_shuffled.append( + shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m) + ) + w2_shuffled.append(shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)) + w13_scale_shuffled.append( + 
shuffle_matrix_sf_a( + w13_scale_interleaved[i] + .view(torch.uint8) + .reshape(2 * intermediate_size, -1), + epilogue_tile_m, + ) + ) + w2_scale_shuffled.append( + shuffle_matrix_sf_a( + w2_scale[i].view(torch.uint8).reshape(hidden_size, -1), + epilogue_tile_m, + ) + ) + + w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn) + w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn) + w13_scale_out = torch.stack(w13_scale_shuffled).reshape(w13_scale.shape) + w2_scale_out = torch.stack(w2_scale_shuffled).reshape(w2_scale.shape) + + return w13_out, w2_out, w13_scale_out, w2_scale_out + + def prepare_fp8_moe_layer_for_fi( layer: torch.nn.Module, w13: torch.Tensor, @@ -314,7 +389,7 @@ def prepare_fp8_moe_layer_for_fi( w2_scale: torch.Tensor, w2_input_scale: torch.Tensor | None, is_trtllm: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Convert Fp8 MoE weights to flashinfer kernel format @@ -329,10 +404,33 @@ def prepare_fp8_moe_layer_for_fi( block_quant = ( hasattr(layer, "weight_block_size") and layer.weight_block_size is not None ) + is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8 + is_gated = layer.activation.is_gated + + # MXFP8 TRT-LLM requires W31 swap + reorder + shuffle. + if is_mxfp8 and is_trtllm: + # FlashInfer TRT-LLM SwiGLU expects [up; gate] but vLLM stores + # [gate; up]. Swap both weights and scales before interleaving. + if layer.moe_config.is_act_and_mul: + w13 = swap_w13_to_w31(w13) + # Scales may be 2D [E, flat] from _quantize_mxfp8_moe_weight; + # reshape to 3D so swap_w13_to_w31 can flip the two halves, + # then flatten back. 
+ if w13_scale.ndim == 2: + num_rows = w13.shape[1] # 2 * intermediate_size + w13_scale = w13_scale.reshape(w13_scale.shape[0], num_rows, -1) + w13_scale = swap_w13_to_w31(w13_scale) + w13_scale = w13_scale.reshape(w13_scale.shape[0], -1) + else: + w13_scale = swap_w13_to_w31(w13_scale) + + w13, w2, w13_scale, w2_scale = _shuffle_mxfp8_moe_weights( + w13, w2, w13_scale, w2_scale, is_gated + ) + return w13, w2, w13_scale, w2_scale # Some FI MoE kernels require internal alignment of 16 # for the gate-up proj. Pad the weights to respect this. - is_gated = layer.activation.is_gated if not block_quant: min_alignment = 16 if is_gated else 128 w13, w2, new_intermediate = align_fp8_moe_weights_for_fi( @@ -369,4 +467,4 @@ def prepare_fp8_moe_layer_for_fi( w13_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE) w2_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE) - return w13, w2, w13_scale + return w13, w2, w13_scale, w2_scale diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 12a1799d1..1170a2d3a 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -149,6 +149,12 @@ kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128)) kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True) +kMxfp8StaticScale = ScaleDesc(torch.uint8, True, GroupShape(1, 32)) +kMxfp8Static = QuantKey(FP8_DTYPE, kMxfp8StaticScale, symmetric=True) + +kMxfp8DynamicScale = ScaleDesc(torch.uint8, False, GroupShape(1, 32)) +kMxfp8Dynamic = QuantKey(FP8_DTYPE, kMxfp8DynamicScale, symmetric=True) + kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64)) kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True) -- GitLab From a3a51d20e7d040542118f04f5089c57a27bc7aca Mon Sep 17 00:00:00 2001 From: Wei 
Zhao <51183510+wzhao18@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:22:40 -0400 Subject: [PATCH 005/223] [Benchmark] Improvements to attention benchmark script (#37115) Signed-off-by: wzhao18 --- benchmarks/attention_benchmarks/benchmark.py | 70 ++++++-- benchmarks/attention_benchmarks/common.py | 5 + .../configs/mla_mixed_batch.yaml | 6 +- .../configs/mla_sparse_decode.yaml | 58 ++++++ benchmarks/attention_benchmarks/mla_runner.py | 165 ++++++++++++++---- benchmarks/attention_benchmarks/runner.py | 75 ++++++-- 6 files changed, 311 insertions(+), 68 deletions(-) create mode 100644 benchmarks/attention_benchmarks/configs/mla_sparse_decode.yaml diff --git a/benchmarks/attention_benchmarks/benchmark.py b/benchmarks/attention_benchmarks/benchmark.py index 0329d1102..a8b1c5478 100644 --- a/benchmarks/attention_benchmarks/benchmark.py +++ b/benchmarks/attention_benchmarks/benchmark.py @@ -47,6 +47,8 @@ from common import ( is_mla_backend, ) +from vllm.v1.worker.workspace import init_workspace_manager + def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: """Run standard attention benchmark (Flash/Triton/FlashInfer).""" @@ -462,7 +464,7 @@ def main(): parser.add_argument( "--batch-specs", nargs="+", - default=["q2k", "8q1s1k"], + default=None, help="Batch specifications using extended grammar", ) @@ -478,6 +480,21 @@ def main(): parser.add_argument("--repeats", type=int, default=1, help="Repetitions") parser.add_argument("--warmup-iters", type=int, default=3, help="Warmup iterations") parser.add_argument("--profile-memory", action="store_true", help="Profile memory") + parser.add_argument( + "--kv-cache-dtype", + default="auto", + choices=["auto", "fp8"], + help="KV cache dtype: auto or fp8", + ) + parser.add_argument( + "--cuda-graphs", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Launch kernels with CUDA graphs to eliminate CPU overhead" + "in measurements (default: True)" + ), + ) # Parameter sweep (use YAML 
config for advanced sweeps) parser.add_argument( @@ -536,21 +553,24 @@ def main(): # Batch specs and sizes # Support both explicit batch_specs and generated batch_spec_ranges - if "batch_spec_ranges" in yaml_config: - # Generate batch specs from ranges - generated_specs = generate_batch_specs_from_ranges( - yaml_config["batch_spec_ranges"] - ) - # Combine with any explicit batch_specs - if "batch_specs" in yaml_config: - args.batch_specs = yaml_config["batch_specs"] + generated_specs - else: - args.batch_specs = generated_specs - console.print( - f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]" - ) - elif "batch_specs" in yaml_config: - args.batch_specs = yaml_config["batch_specs"] + # CLI --batch-specs takes precedence over YAML when provided. + cli_batch_specs_provided = args.batch_specs is not None + if not cli_batch_specs_provided: + if "batch_spec_ranges" in yaml_config: + # Generate batch specs from ranges + generated_specs = generate_batch_specs_from_ranges( + yaml_config["batch_spec_ranges"] + ) + # Combine with any explicit batch_specs + if "batch_specs" in yaml_config: + args.batch_specs = yaml_config["batch_specs"] + generated_specs + else: + args.batch_specs = generated_specs + console.print( + f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]" + ) + elif "batch_specs" in yaml_config: + args.batch_specs = yaml_config["batch_specs"] if "batch_sizes" in yaml_config: args.batch_sizes = yaml_config["batch_sizes"] @@ -575,6 +595,10 @@ def main(): args.warmup_iters = yaml_config["warmup_iters"] if "profile_memory" in yaml_config: args.profile_memory = yaml_config["profile_memory"] + if "kv_cache_dtype" in yaml_config: + args.kv_cache_dtype = yaml_config["kv_cache_dtype"] + if "cuda_graphs" in yaml_config: + args.cuda_graphs = yaml_config["cuda_graphs"] # Parameter sweep configuration if "parameter_sweep" in yaml_config: @@ -629,12 +653,18 @@ def main(): # Determine backends backends = args.backends or ([args.backend] if 
args.backend else ["flash"]) prefill_backends = getattr(args, "prefill_backends", None) + if not args.batch_specs: + args.batch_specs = ["q2k", "8q1s1k"] console.print(f"Backends: {', '.join(backends)}") if prefill_backends: console.print(f"Prefill backends: {', '.join(prefill_backends)}") console.print(f"Batch specs: {', '.join(args.batch_specs)}") + console.print(f"KV cache dtype: {args.kv_cache_dtype}") + console.print(f"CUDA graphs: {args.cuda_graphs}") console.print() + init_workspace_manager(args.device) + # Run benchmarks all_results = [] @@ -687,6 +717,8 @@ def main(): repeats=args.repeats, warmup_iters=args.warmup_iters, profile_memory=args.profile_memory, + kv_cache_dtype=args.kv_cache_dtype, + use_cuda_graphs=args.cuda_graphs, ) # Add decode pipeline config @@ -839,6 +871,8 @@ def main(): "repeats": args.repeats, "warmup_iters": args.warmup_iters, "profile_memory": args.profile_memory, + "kv_cache_dtype": args.kv_cache_dtype, + "use_cuda_graphs": args.cuda_graphs, } all_results = run_model_parameter_sweep( backends, @@ -861,6 +895,8 @@ def main(): "repeats": args.repeats, "warmup_iters": args.warmup_iters, "profile_memory": args.profile_memory, + "kv_cache_dtype": args.kv_cache_dtype, + "use_cuda_graphs": args.cuda_graphs, } all_results = run_parameter_sweep( backends, args.batch_specs, base_config_args, args.parameter_sweep, console @@ -891,6 +927,8 @@ def main(): repeats=args.repeats, warmup_iters=args.warmup_iters, profile_memory=args.profile_memory, + kv_cache_dtype=args.kv_cache_dtype, + use_cuda_graphs=args.cuda_graphs, ) result = run_benchmark(config) diff --git a/benchmarks/attention_benchmarks/common.py b/benchmarks/attention_benchmarks/common.py index 208d6273c..74d9e2397 100644 --- a/benchmarks/attention_benchmarks/common.py +++ b/benchmarks/attention_benchmarks/common.py @@ -213,6 +213,9 @@ class BenchmarkConfig: profile_memory: bool = False use_cuda_graphs: bool = False + # "auto" or "fp8" + kv_cache_dtype: str = "auto" + # MLA-specific 
prefill_backend: str | None = None kv_lora_rank: int | None = None @@ -369,6 +372,7 @@ class ResultsFormatter: "backend", "batch_spec", "num_layers", + "kv_cache_dtype", "mean_time", "std_time", "throughput", @@ -382,6 +386,7 @@ class ResultsFormatter: "backend": r.config.backend, "batch_spec": r.config.batch_spec, "num_layers": r.config.num_layers, + "kv_cache_dtype": r.config.kv_cache_dtype, "mean_time": r.mean_time, "std_time": r.std_time, "throughput": r.throughput_tokens_per_sec or 0, diff --git a/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml index b555d90cb..c342e9fb8 100644 --- a/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml +++ b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml @@ -30,9 +30,9 @@ batch_specs: - "2q16k_32q1s4k" # 2 very large prefill + 32 decode # Context extension + decode - - "2q1kkv2k_16q1s1k" # 2 extend + 16 decode - - "4q2kkv4k_32q1s2k" # 4 extend + 32 decode - - "2q1kkv8k_32q1s2k" # 2 large extend + 32 decode + - "2q1ks2k_16q1s1k" # 2 extend + 16 decode + - "4q2ks4k_32q1s2k" # 4 extend + 32 decode + - "2q1ks8k_32q1s2k" # 2 large extend + 32 decode # Explicitly chunked prefill - "q8k" # 8k prefill with chunking hint diff --git a/benchmarks/attention_benchmarks/configs/mla_sparse_decode.yaml b/benchmarks/attention_benchmarks/configs/mla_sparse_decode.yaml new file mode 100644 index 000000000..689c9f3c3 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_sparse_decode.yaml @@ -0,0 +1,58 @@ +# MLA decode-only benchmark configuration + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 # Base value, can be swept for TP simulation + num_kv_heads: 1 # MLA uses single latent KV + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128 + +# Model parameter sweep: simulate tensor parallelism by varying num_q_heads +# 
TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads +model_parameter_sweep: + param_name: "num_q_heads" + values: [128, 64, 32, 16] + label_format: "{backend}_{value}h" + +batch_specs: + # Small batches, varying sequence lengths + - "16q1s512" # 16 requests, 512 KV cache + - "16q1s1k" # 16 requests, 1k KV cache + - "16q1s2k" # 16 requests, 2k KV cache + - "16q1s4k" # 16 requests, 4k KV cache + + # Medium batches + - "32q1s1k" # 32 requests, 1k KV cache + - "32q1s2k" # 32 requests, 2k KV cache + - "32q1s4k" # 32 requests, 4k KV cache + - "32q1s8k" # 32 requests, 8k KV cache + + # Large batches + - "64q1s1k" # 64 requests, 1k KV cache + - "64q1s2k" # 64 requests, 2k KV cache + - "64q1s4k" # 64 requests, 4k KV cache + - "64q1s8k" # 64 requests, 8k KV cache + + # Very large batches + - "128q1s1k" # 128 requests, 1k KV cache + - "128q1s2k" # 128 requests, 2k KV cache + - "128q1s4k" # 128 requests, 4k KV cache + - "128q1s8k" # 128 requests, 8k KV cache + + # Long context + - "32q1s16k" # 32 requests, 16k KV cache + - "32q1s32k" # 32 requests, 32k KV cache + +backends: + - FLASHMLA_SPARSE + - FLASHINFER_MLA_SPARSE + +device: "cuda:0" +repeats: 100 +warmup_iters: 10 +profile_memory: true diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py index 0d612e374..f8bc7b4a1 100644 --- a/benchmarks/attention_benchmarks/mla_runner.py +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -60,9 +60,11 @@ def create_minimal_vllm_config( model_name: str = "deepseek-v3", block_size: int = 128, max_num_seqs: int = 256, + max_num_batched_tokens: int = 8192, mla_dims: dict | None = None, index_topk: int | None = None, prefill_backend: str | None = None, + kv_cache_dtype: str = "auto", ) -> VllmConfig: """ Create minimal VllmConfig for MLA benchmarks. 
@@ -149,13 +151,13 @@ def create_minimal_vllm_config( cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, - cache_dtype="auto", + cache_dtype=kv_cache_dtype, enable_prefix_caching=False, ) scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, - max_num_batched_tokens=8192, + max_num_batched_tokens=max(max_num_batched_tokens, max_num_seqs), max_model_len=32768, is_encoder_decoder=False, enable_chunked_prefill=True, @@ -535,6 +537,7 @@ def _create_backend_impl( device: torch.device, max_num_tokens: int = 8192, index_topk: int | None = None, + kv_cache_dtype: str = "auto", ): """ Create backend implementation instance. @@ -583,7 +586,7 @@ def _create_backend_impl( "num_kv_heads": mla_dims["num_kv_heads"], "alibi_slopes": None, "sliding_window": None, - "kv_cache_dtype": "auto", + "kv_cache_dtype": kv_cache_dtype, "logits_soft_cap": None, "attn_type": "decoder", "kv_sharing_target_layer_name": None, @@ -701,6 +704,7 @@ def _run_single_benchmark( mla_dims: dict, device: torch.device, indexer=None, + kv_cache_dtype: str | None = None, ) -> BenchmarkResult: """ Run a single benchmark iteration. @@ -734,49 +738,124 @@ def _run_single_benchmark( ) # Create KV cache - kv_cache = torch.zeros( - num_blocks, - block_size, - mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"], - device=device, - dtype=torch.bfloat16, - ) + if kv_cache_dtype is None: + kv_cache_dtype = getattr(config, "kv_cache_dtype", "auto") + head_size = mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"] + if kv_cache_dtype == "fp8_ds_mla": + # FlashMLA sparse custom format: 656 bytes per token, stored as uint8. + # Layout: kv_lora_rank fp8 bytes + 4 float32 tile scales + # + 2*rope_dim bf16 bytes + # = 512 + 16 + 128 = 656 bytes for DeepSeek dims. 
+ kv_cache = torch.zeros( + num_blocks, + block_size, + 656, + device=device, + dtype=torch.uint8, + ) + elif kv_cache_dtype == "fp8": + from vllm.platforms import current_platform - # Create input tensors for both decode and prefill modes - decode_inputs, prefill_inputs = _create_input_tensors( - total_q, - mla_dims, - backend_cfg["query_format"], - device, - torch.bfloat16, - ) + kv_cache = torch.zeros( + num_blocks, + block_size, + head_size, + device=device, + dtype=torch.uint8, + ).view(current_platform.fp8_dtype()) + else: + kv_cache = torch.zeros( + num_blocks, + block_size, + head_size, + device=device, + dtype=torch.bfloat16, + ) # Fill indexer with random indices for sparse backends is_sparse = backend_cfg.get("is_sparse", False) if is_sparse and indexer is not None: indexer.fill_random_indices(total_q, max_kv_len) - # Determine which forward method to use based on metadata - if metadata.decode is not None: - forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer) - elif metadata.prefill is not None: - forward_fn = lambda: impl.forward_mha( - prefill_inputs["q"], - prefill_inputs["k_c_normed"], - prefill_inputs["k_pe"], - kv_cache, - metadata, - prefill_inputs["k_scale"], - prefill_inputs["output"], - ) - else: + # Determine which forward methods to use based on metadata. + # Sparse MLA backends always use forward_mqa + has_decode = is_sparse or getattr(metadata, "decode", None) is not None + has_prefill = not is_sparse and getattr(metadata, "prefill", None) is not None + if not has_decode and not has_prefill: raise RuntimeError("Metadata has neither decode nor prefill metadata") + num_decode = ( + metadata.num_decode_tokens + if (has_decode and has_prefill) + else total_q + if has_decode + else 0 + ) + num_prefill = total_q - num_decode + + # Some backends require fp8 queries when using fp8 KV cache.
+ is_fp8_kvcache = kv_cache_dtype.startswith("fp8") + quantize_query = is_fp8_kvcache and getattr( + impl, "supports_quant_query_input", False + ) + + # quantize_query forces concat format + query_fmt = "concat" if quantize_query else backend_cfg["query_format"] + + # Create decode query tensors + if has_decode: + decode_inputs, _ = _create_input_tensors( + num_decode, mla_dims, query_fmt, device, torch.bfloat16 + ) + # Cast decode query to fp8 if the backend supports it + if quantize_query: + from vllm.platforms import current_platform + + if isinstance(decode_inputs, tuple): + decode_inputs = torch.cat(list(decode_inputs), dim=-1) + decode_inputs = decode_inputs.to(current_platform.fp8_dtype()) + + # Create prefill input tensors + if has_prefill: + _, prefill_inputs = _create_input_tensors( + num_prefill, mla_dims, query_fmt, device, torch.bfloat16 + ) + + # Build forward function + def forward_fn(): + results = [] + if has_decode: + results.append(impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)) + if has_prefill: + results.append( + impl.forward_mha( + prefill_inputs["q"], + prefill_inputs["k_c_normed"], + prefill_inputs["k_pe"], + kv_cache, + metadata, + prefill_inputs["k_scale"], + prefill_inputs["output"], + ) + ) + return results[0] if len(results) == 1 else tuple(results) + # Warmup for _ in range(config.warmup_iters): forward_fn() torch.accelerator.synchronize() + # Optionally capture a CUDA graph after warmup. + # Graph replay eliminates CPU launch overhead so timings reflect pure + # kernel time. 
+ if config.use_cuda_graphs: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + forward_fn() + benchmark_fn = graph.replay + else: + benchmark_fn = forward_fn + # Benchmark times = [] for _ in range(config.repeats): @@ -785,7 +864,7 @@ def _run_single_benchmark( start.record() for _ in range(config.num_layers): - forward_fn() + benchmark_fn() end.record() torch.accelerator.synchronize() @@ -852,13 +931,30 @@ def _run_mla_benchmark_batched( # Determine if this is a sparse backend is_sparse = backend_cfg.get("is_sparse", False) + # Extract kv_cache_dtype from the first config + kv_cache_dtype = getattr(first_config, "kv_cache_dtype", "auto") + + # FlashMLA sparse only supports "fp8_ds_mla" internally (not generic "fp8"). + # Remap here so the user can pass --kv-cache-dtype fp8 regardless of backend. + if backend.upper() == "FLASHMLA_SPARSE" and kv_cache_dtype == "fp8": + kv_cache_dtype = "fp8_ds_mla" + + # Compute max total_q across all configs so the metadata builder buffer + # and scheduler config are large enough for all batch specs. 
+ max_total_q = max( + sum(r.q_len for r in parse_batch_spec(cfg.batch_spec)) + for cfg, *_ in configs_with_params + ) + # Create and set vLLM config for MLA (reused across all benchmarks) vllm_config = create_minimal_vllm_config( model_name="deepseek-v3", # Used only for model path block_size=block_size, + max_num_batched_tokens=max_total_q, mla_dims=mla_dims, # Use custom dims from config or default index_topk=index_topk if is_sparse else None, prefill_backend=prefill_backend, + kv_cache_dtype=kv_cache_dtype, ) results = [] @@ -883,7 +979,9 @@ def _run_mla_benchmark_batched( mla_dims, vllm_config, device, + max_num_tokens=max_total_q, index_topk=index_topk if is_sparse else None, + kv_cache_dtype=kv_cache_dtype, ) # Verify the actual prefill backend matches what was requested @@ -942,6 +1040,7 @@ def _run_mla_benchmark_batched( mla_dims, device, indexer=indexer, + kv_cache_dtype=kv_cache_dtype, ) results.append(result) diff --git a/benchmarks/attention_benchmarks/runner.py b/benchmarks/attention_benchmarks/runner.py index 6af56e0e9..aa636cd9c 100644 --- a/benchmarks/attention_benchmarks/runner.py +++ b/benchmarks/attention_benchmarks/runner.py @@ -140,7 +140,7 @@ def _create_vllm_config( cache_config = CacheConfig( block_size=config.block_size, - cache_dtype="auto", + cache_dtype=config.kv_cache_dtype, ) cache_config.num_gpu_blocks = max_num_blocks cache_config.num_cpu_blocks = 0 @@ -215,7 +215,7 @@ def _create_backend_impl( num_kv_heads=config.num_kv_heads, alibi_slopes=None, sliding_window=None, - kv_cache_dtype="auto", + kv_cache_dtype=config.kv_cache_dtype, ) kv_cache_spec = FullAttentionSpec( @@ -288,12 +288,22 @@ def _create_input_tensors( total_q: int, device: torch.device, dtype: torch.dtype, + quantize_query: bool = False, ) -> tuple: - """Create Q, K, V input tensors for all layers.""" + """Create Q, K, V input tensors for all layers. 
+ + When quantize_query is True, queries are cast to fp8 to match backends + that require query/key/value dtype consistency. + """ + q_dtype = dtype + if quantize_query: + from vllm.platforms import current_platform + + q_dtype = current_platform.fp8_dtype() q_list = [ torch.randn( total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype - ) + ).to(q_dtype) for _ in range(config.num_layers) ] k_list = [ @@ -344,10 +354,17 @@ def _create_kv_cache( # Compute inverse permutation to get back to logical view inv_order = [stride_order.index(i) for i in range(len(stride_order))] + # Use fp8 dtype for cache when requested. + cache_dtype = dtype + if config.kv_cache_dtype == "fp8": + from vllm.platforms import current_platform + + cache_dtype = current_platform.fp8_dtype() + cache_list = [] for _ in range(config.num_layers): # Allocate in physical layout order (contiguous in memory) - cache = torch.zeros(*physical_shape, device=device, dtype=dtype) + cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype) # Permute to logical view cache = cache.permute(*inv_order) cache_list.append(cache) @@ -392,6 +409,37 @@ def _run_single_benchmark( ) torch.accelerator.synchronize() + # Optionally capture a CUDA graph after warmup. + # Graph replay eliminates CPU launch overhead so timings reflect pure + # kernel time. 
+ if config.use_cuda_graphs: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for i in range(config.num_layers): + impl.forward( + layer, + q_list[i], + k_list[i], + v_list[i], + cache_list[i], + attn_metadata, + output=out, + ) + benchmark_fn = graph.replay + else: + + def benchmark_fn(): + for i in range(config.num_layers): + impl.forward( + layer, + q_list[i], + k_list[i], + v_list[i], + cache_list[i], + attn_metadata, + output=out, + ) + # Benchmark times = [] for _ in range(config.repeats): @@ -399,16 +447,7 @@ def _run_single_benchmark( end = torch.cuda.Event(enable_timing=True) start.record() - for i in range(config.num_layers): - impl.forward( - layer, - q_list[i], - k_list[i], - v_list[i], - cache_list[i], - attn_metadata, - output=out, - ) + benchmark_fn() end.record() torch.accelerator.synchronize() @@ -502,8 +541,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: common_attn_metadata=common_metadata, ) + # Only quantize queries when the impl supports it + quantize_query = config.kv_cache_dtype.startswith("fp8") and getattr( + impl, "supports_quant_query_input", False + ) q_list, k_list, v_list = _create_input_tensors( - config, total_q, device, dtype + config, total_q, device, dtype, quantize_query=quantize_query ) cache_list = _create_kv_cache( -- GitLab From 31a458c0913e2c498da004e16ba2ac922bcebe96 Mon Sep 17 00:00:00 2001 From: Yuchen Fama Date: Mon, 16 Mar 2026 18:27:42 -0400 Subject: [PATCH 006/223] [Doc] Clarify schema enforcement behavior for tool_choice modes (#37064) Signed-off-by: yfama --- docs/features/tool_calling.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index b590b33e9..cea117541 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -107,6 +107,27 @@ vLLM supports the `tool_choice='none'` option in the chat completion API. When t !!! 
note When tools are specified in the request, vLLM includes tool definitions in the prompt by default, regardless of the `tool_choice` setting. To exclude tool definitions when `tool_choice='none'`, use the `--exclude-tools-when-tool-choice-none` option. +## Constrained Decoding Behavior + +Whether vLLM enforces the tool parameter schema during generation depends on the `tool_choice` mode: + +| `tool_choice` value | Schema-constrained decoding | Behavior | +| --- | --- | --- | +| Named function | Yes (via structured outputs backend) | Arguments are guaranteed to be valid JSON conforming to the function's parameter schema. | +| `"required"` | Yes (via structured outputs backend) | Same as named function. The model must produce at least one tool call. | +| `"auto"` | No | The model generates freely. A tool-call parser extracts tool calls from the raw text. Arguments may be malformed or not match the schema. | +| `"none"` | N/A | No tool calls are produced. | + +When schema conformance matters, prefer `tool_choice="required"` or named function calling over `"auto"`. + +### Strict Mode (`strict` parameter) + +The [OpenAI API](https://platform.openai.com/docs/guides/function-calling#strict-mode) supports a `strict` field on function definitions. When set to `true`, OpenAI uses constrained decoding to guarantee that tool-call arguments match the function schema, even in `tool_choice="auto"` mode. + +vLLM **does not implement** `strict` mode today. The `strict` field is accepted in requests (to avoid breaking clients that set it), but it has no effect on decoding behavior. In auto mode, argument validity depends entirely on the model's output quality and the parser's extraction logic. + +Tracking issues: [#15526](https://github.com/vllm-project/vllm/issues/15526), [#16313](https://github.com/vllm-project/vllm/issues/16313). 
+ ## Automatic Function Calling To enable this feature, you should set the following flags: @@ -124,6 +145,9 @@ from HuggingFace; and you can find an example of this in a `tokenizer_config.jso If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! +!!! note + With `tool_choice="auto"`, tool-call arguments are extracted from the model's raw text output by the selected parser. No schema-level constraint is applied during decoding, so arguments may occasionally be malformed or violate the function's parameter schema. See [Constrained Decoding Behavior](#constrained-decoding-behavior) for details. + ### Hermes Models (`hermes`) All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported. -- GitLab From 4f9b14c21cd4eb4b56c972b3280be41d341056d1 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Mon, 16 Mar 2026 17:40:23 -0500 Subject: [PATCH 007/223] [CI] Stabilize multinode DP internal LB completion tests (#36356) Signed-off-by: Andreas Karatzas --- tests/v1/distributed/test_internal_lb_dp.py | 183 ++++++++++---------- 1 file changed, 89 insertions(+), 94 deletions(-) diff --git a/tests/v1/distributed/test_internal_lb_dp.py b/tests/v1/distributed/test_internal_lb_dp.py index 8f7459e95..efd9fc607 100644 --- a/tests/v1/distributed/test_internal_lb_dp.py +++ b/tests/v1/distributed/test_internal_lb_dp.py @@ -12,7 +12,7 @@ import pytest import pytest_asyncio import requests -from tests.utils import RemoteOpenAIServer +from tests.utils import ROCM_ENV_OVERRIDES, RemoteOpenAIServer from tests.v1.utils import check_request_balancing from vllm.platforms import current_platform @@ -27,6 +27,84 @@ TP_SIZE = int(os.getenv("TP_SIZE", "1")) NUM_NODES = 2 +async def _make_completion_request( + client: openai.AsyncOpenAI, + model_name: str, +) -> openai.types.Completion: + """Make a single completion request and validate the response. 
+ + Uses temperature=1.0 to ensure diverse outputs across concurrent + requests for realistic load balancer testing. + """ + completion = await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=1.0, + ) + + assert completion.id is not None, ( + f"Expected non-None completion id. usage={completion.usage!r}" + ) + assert completion.choices is not None and len(completion.choices) == 1, ( + f"Expected 1 choice, got " + f"{len(completion.choices) if completion.choices else 'None'}" + ) + + choice = completion.choices[0] + # With temperature=1.0, the model may emit a stop token immediately, + # producing empty text with finish_reason='stop'. This is valid + # model behavior - the test's purpose is load balancing, not output + # quality. + assert choice.finish_reason in ("length", "stop"), ( + f"Expected finish_reason 'length' or 'stop', " + f"got {choice.finish_reason!r}. text={choice.text!r}" + ) + if choice.finish_reason == "length": + assert len(choice.text) >= 1, ( + f"Expected non-empty text with finish_reason='length', got {choice.text!r}" + ) + + assert completion.usage.prompt_tokens > 0, ( + f"Expected positive prompt_tokens, got {completion.usage.prompt_tokens}" + ) + assert completion.usage.total_tokens > 0, ( + f"Expected positive total_tokens, got {completion.usage.total_tokens}" + ) + return completion + + +async def _run_request_bursts( + client: openai.AsyncOpenAI, + model_name: str, + num_requests: int = 200, + num_bursts: int = 2, +): + """Send multiple bursts of completion requests and validate all succeed.""" + for burst in range(num_bursts): + all_tasks = [] + for _ in range(num_requests): + all_tasks.append( + asyncio.create_task(_make_completion_request(client, model_name)) + ) + await asyncio.sleep(0.01) + + results = await asyncio.gather(*all_tasks, return_exceptions=True) + assert len(results) == num_requests, ( + f"Burst {burst}: expected {num_requests} results, got {len(results)}" + ) + + 
for result in results: + if isinstance(result, BaseException): + raise result + + assert all(completion is not None for completion in results), ( + f"Burst {burst}: some completions were None" + ) + + await asyncio.sleep(0.5) + + class MultinodeInternalLBServerManager: """Manages multi-node data parallel vLLM server instances for internal load balancer testing using --headless mode.""" @@ -108,6 +186,7 @@ class MultinodeInternalLBServerManager: auto_port=False, env_dict={ "VLLM_SERVER_DEV_MODE": "1", + **ROCM_ENV_OVERRIDES, current_platform.device_control_env_var: ",".join( str(current_platform.device_id_to_physical_device_id(i)) for i in range(r, r + gpus_per_node) @@ -229,6 +308,7 @@ class APIOnlyServerManager: auto_port=False, env_dict={ "VLLM_SERVER_DEV_MODE": "1", + **ROCM_ENV_OVERRIDES, # No GPUs needed for API-only server }, ) @@ -249,10 +329,11 @@ class APIOnlyServerManager: engines_server_args, auto_port=False, env_dict={ + **ROCM_ENV_OVERRIDES, current_platform.device_control_env_var: ",".join( str(current_platform.device_id_to_physical_device_id(i)) for i in range(self.dp_size * self.tp_size) - ) + ), }, ) server.__enter__() @@ -395,58 +476,15 @@ async def test_multinode_dp_completion( servers: list[tuple[RemoteOpenAIServer, list[str]]], model_name: str, ) -> None: - async def make_request(): - completion = await client.completions.create( - model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0 - ) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - - choice = completion.choices[0] - # The exact number of tokens can vary slightly with temperature=1.0, - # so we check for a reasonable minimum length. - assert len(choice.text) >= 1 - # Finish reason might not always be 'length' if the model finishes early - # or due to other reasons, especially with high temperature. - # So, we'll accept 'length' or 'stop'. 
- assert choice.finish_reason in ("length", "stop") - - # Token counts can also vary, so we check they are positive. - assert completion.usage.completion_tokens > 0 - assert completion.usage.prompt_tokens > 0 - assert completion.usage.total_tokens > 0 - return completion - # Test single request - result = await make_request() + result = await _make_completion_request(client, model_name) assert result is not None print("Multi-node internal LB handled single completion request successfully") await asyncio.sleep(0.5) - # Send multiple requests - internal LB should distribute across DP ranks - num_requests = 200 - all_tasks = [] - for _ in range(num_requests): - all_tasks.append(asyncio.create_task(make_request())) - await asyncio.sleep(0.01) - - results = await asyncio.gather(*all_tasks) - assert len(results) == num_requests - assert all(completion is not None for completion in results) - - await asyncio.sleep(0.5) - - # Second burst of requests - all_tasks = [] - for _ in range(num_requests): - all_tasks.append(asyncio.create_task(make_request())) - await asyncio.sleep(0.01) - - results = await asyncio.gather(*all_tasks) - assert len(results) == num_requests - assert all(completion is not None for completion in results) + # Send multiple bursts - internal LB should distribute across DP ranks + await _run_request_bursts(client, model_name) _, server_args = servers[0] api_server_count = ( @@ -570,59 +608,16 @@ async def test_api_only_multinode_dp_completion( ) -> None: """Test API-only server with all engines on separate headless server.""" - async def make_request(): - completion = await api_only_client.completions.create( - model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0 - ) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - - choice = completion.choices[0] - # The exact number of tokens can vary slightly with temperature=1.0, - # so we check for a reasonable minimum length. 
- assert len(choice.text) >= 1 - # Finish reason might not always be 'length' if the model finishes - # early or due to other reasons, especially with high temperature. - # So, we'll accept 'length' or 'stop'. - assert choice.finish_reason in ("length", "stop") - - # Token counts can also vary, so we check they are positive. - assert completion.usage.completion_tokens > 0 - assert completion.usage.prompt_tokens > 0 - assert completion.usage.total_tokens > 0 - return completion - # Test single request - result = await make_request() + result = await _make_completion_request(api_only_client, model_name) assert result is not None print("API-only server handled single completion request successfully") await asyncio.sleep(0.5) - # Send multiple requests - should be distributed across engines on + # Send multiple bursts - should be distributed across engines on # headless server - num_requests = 200 - all_tasks = [] - for _ in range(num_requests): - all_tasks.append(asyncio.create_task(make_request())) - await asyncio.sleep(0.01) - - results = await asyncio.gather(*all_tasks) - assert len(results) == num_requests - assert all(completion is not None for completion in results) - - await asyncio.sleep(0.5) - - # Second burst of requests - all_tasks = [] - for _ in range(num_requests): - all_tasks.append(asyncio.create_task(make_request())) - await asyncio.sleep(0.01) - - results = await asyncio.gather(*all_tasks) - assert len(results) == num_requests - assert all(completion is not None for completion in results) + await _run_request_bursts(api_only_client, model_name) api_server, api_server_args = api_only_servers[0] api_server_count = ( -- GitLab From 7961486a9b749b1b60d8b6fd5fb7d61596a9b041 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 16 Mar 2026 23:41:00 +0100 Subject: [PATCH 008/223] Fix EagleMistralLarge3Model initialization (#37232) Signed-off-by: juliendenize --- 
vllm/model_executor/models/mistral_large_3_eagle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/mistral_large_3_eagle.py b/vllm/model_executor/models/mistral_large_3_eagle.py index 4567f24fd..3fcc048f9 100644 --- a/vllm/model_executor/models/mistral_large_3_eagle.py +++ b/vllm/model_executor/models/mistral_large_3_eagle.py @@ -74,6 +74,7 @@ class EagleMistralLarge3Model(DeepseekV2Model): prefix=maybe_prefix(prefix, "fc"), ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.aux_hidden_state_layers: tuple[int, ...] = () self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) -- GitLab From 3e6a1e1686958dcd7eff1438bc5418b8d56daa30 Mon Sep 17 00:00:00 2001 From: Terry Gao <32590313+tianrengao@users.noreply.github.com> Date: Mon, 16 Mar 2026 15:51:46 -0700 Subject: [PATCH 009/223] [Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389) Signed-off-by: tianrengao --- csrc/ops.h | 12 +- csrc/quantization/fp4/nvfp4_quant_entry.cu | 37 +++++- csrc/quantization/fp4/nvfp4_utils.cuh | 13 +++ csrc/torch_bindings.cpp | 19 +++- .../distributed/test_fusion_all_reduce.py | 2 +- .../kernels/quantization/test_nvfp4_quant.py | 46 ++++++++ vllm/_custom_ops.py | 106 ++++++++++++++---- .../passes/fusion/act_quant_fusion.py | 4 +- .../passes/fusion/allreduce_rms_fusion.py | 10 +- .../passes/fusion/attn_quant_fusion.py | 4 +- .../passes/fusion/matcher_utils.py | 2 +- .../passes/fusion/rms_quant_fusion.py | 2 +- 12 files changed, 213 insertions(+), 44 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 921d6484d..299650be7 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -295,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, std::vector cutlass_sparse_compress(torch::Tensor const& a); -void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, - torch::Tensor& output_scale, - torch::Tensor const& 
input_scale, - bool is_sf_swizzled_layout); +std::tuple scaled_fp4_quant_func( + torch::Tensor const& input, torch::Tensor const& input_scale, + bool is_sf_swizzled_layout); + +void scaled_fp4_quant_out(torch::Tensor const& input, + torch::Tensor const& input_scale, + bool is_sf_swizzled_layout, torch::Tensor& output, + torch::Tensor& output_scale); void scaled_fp4_experts_quant( torch::Tensor& output, torch::Tensor& output_scale, diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index 650b9da8a..8b5a1fd22 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -16,6 +16,8 @@ #include +#include "nvfp4_utils.cuh" + #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, @@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( torch::Tensor const& output_scale_offset_by_experts); #endif -void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, - torch::Tensor& output_sf, torch::Tensor const& input_sf, - bool is_sf_swizzled_layout) { +void scaled_fp4_quant_out(torch::Tensor const& input, + torch::Tensor const& input_sf, + bool is_sf_swizzled_layout, torch::Tensor& output, + torch::Tensor& output_sf) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf, @@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); } +std::tuple scaled_fp4_quant_func( + torch::Tensor const& input, torch::Tensor const& input_sf, + bool is_sf_swizzled_layout) { + int64_t n = input.size(-1); + int64_t m = input.numel() / n; + auto device = input.device(); + + // Two fp4 values packed into a uint8 + auto output = 
torch::empty( + {m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8)); + + torch::Tensor output_sf; + if (is_sf_swizzled_layout) { + auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n); + output_sf = torch::empty( + {sf_m, sf_n}, + torch::TensorOptions().device(device).dtype(torch::kInt32)); + } else { + output_sf = torch::empty( + {m, n / CVT_FP4_SF_VEC_SIZE}, + torch::TensorOptions().device(device).dtype(torch::kUInt8)); + } + + scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output, + output_sf); + return {output, output_sf}; +} + void scaled_fp4_experts_quant( torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index c1df1860c..0c04f0108 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -18,6 +18,7 @@ #include #include +#include #include "../../cuda_vec_utils.cuh" @@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) { return round_up(m, ROW_TILE); } +// Compute the shape of the swizzled SF output tensor. +// Returns (rounded_m, rounded_n / 4) where: +// rounded_m = round_up(m, 128) +// rounded_n = round_up(n / CVT_FP4_SF_VEC_SIZE, 4) +inline std::pair computeSwizzledSFShape(int64_t m, + int64_t n) { + int64_t rounded_m = round_up(m, static_cast(128)); + int64_t scale_n = n / CVT_FP4_SF_VEC_SIZE; + int64_t rounded_n = round_up(scale_n, static_cast(4)); + return {rounded_m, rounded_n / 4}; +} + // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { uint32_t val; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d98e987d9..aadc9fe33 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -564,10 +564,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute NVFP4 block quantized tensor. 
ops.def( - "scaled_fp4_quant(Tensor! output, Tensor input," - " Tensor! output_scale, Tensor input_scale, bool " - "is_sf_swizzled_layout) -> ()"); - ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + "scaled_fp4_quant(Tensor input," + " Tensor input_scale, bool " + "is_sf_swizzled_layout) -> (Tensor, Tensor)"); + ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func); + + // Out variant + // TODO: Add {at::Tag::out_variant} tag and update all call sites + // to use the functional variant once vLLM upgrades PyTorch. + // See pytorch/pytorch#176117. + ops.def( + "scaled_fp4_quant.out(Tensor input," + " Tensor input_scale, bool " + "is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) " + "-> ()"); + ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out); // Compute NVFP4 experts quantization. ops.def( diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index fe50081e5..92e7402c0 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -179,7 +179,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.scaled_fp4_quant.default, + torch.ops._C.scaled_fp4_quant.out, ] diff --git a/tests/kernels/quantization/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py index 1d2f9d413..e2db59758 100644 --- a/tests/kernels/quantization/test_nvfp4_quant.py +++ b/tests/kernels/quantization/test_nvfp4_quant.py @@ -159,6 +159,52 @@ def test_quantize_to_fp4( torch.testing.assert_close(scale_ans, scale_ref) +@pytest.mark.parametrize( + "shape", + [(32, 4096), (128, 4096), (1, 64), (127, 1024), (256, 16384)], +) +@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) +@torch.inference_mode() +def 
test_python_util_matches_cpp_allocation( + shape: tuple[int, int], + is_sf_swizzled_layout: bool, +) -> None: + """ + Verify that the Python utility (create_fp4_output_tensors) allocates + tensors with the same shapes and dtypes as the C++ functional variant + (scaled_fp4_quant_func). + """ + from vllm._custom_ops import create_fp4_output_tensors + + torch.set_default_device("cuda:0") + m, n = shape + input_tensor = torch.randn((m, n), dtype=torch.bfloat16) + input_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda:0") + + # C++ functional variant allocates internally + cpp_out, cpp_scale = torch.ops._C.scaled_fp4_quant( + input_tensor, input_scale, is_sf_swizzled_layout + ) + + # Python utility + py_out, py_scale = create_fp4_output_tensors( + m, n, torch.device("cuda:0"), is_sf_swizzled_layout + ) + + assert py_out.shape == cpp_out.shape, ( + f"Output shape mismatch: Python {py_out.shape} vs C++ {cpp_out.shape}" + ) + assert py_out.dtype == cpp_out.dtype, ( + f"Output dtype mismatch: Python {py_out.dtype} vs C++ {cpp_out.dtype}" + ) + assert py_scale.shape == cpp_scale.shape, ( + f"Scale shape mismatch: Python {py_scale.shape} vs C++ {cpp_scale.shape}" + ) + assert py_scale.dtype == cpp_scale.dtype, ( + f"Scale dtype mismatch: Python {py_scale.dtype} vs C++ {cpp_scale.dtype}" + ) + + @pytest.mark.parametrize("pad_shape", PAD_SHAPES) @torch.inference_mode() def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fdc468d3b..63f347d89 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -29,6 +29,81 @@ else: from torch.library import impl_abstract as register_fake +# scaled_fp4_quant functional + out variant for torch.compile buffer management + + +def create_fp4_scale_tensor( + m: int, + n: int, + device: torch.device, + is_sf_swizzled_layout: bool, +) -> torch.Tensor: + """ + Allocate the output scale tensor for scaled_fp4_quant. 
+ + When is_sf_swizzled_layout=True, we use rounded values to store the + swizzled scales. Due to the requirement of the Tensor Core, the minimum + tile is 128x4 for the scales. So, we first pad the scales to multiples + of 128 (rows) and 4 (cols). Then, the scales (in float8_e4m3fn) are + packed into an int32 for every 4 values. More: + https://docs.nvidia.com/cuda/parallel-thread-execution/ + #tcgen05-mma-scale-factor-b-layout-4x + """ + from vllm.utils.math_utils import round_up + + block_size = 16 + if is_sf_swizzled_layout: + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + return torch.empty( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) + else: + return torch.empty((m, n // block_size), device=device, dtype=torch.uint8) + + +def create_fp4_output_tensors( + m: int, + n: int, + device: torch.device, + is_sf_swizzled_layout: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Allocate both output tensors for scaled_fp4_quant: + (quantized_output, output_scale). + + Must match the C++ scaled_fp4_quant_func allocation exactly. 
+ """ + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + output_scale = create_fp4_scale_tensor(m, n, device, is_sf_swizzled_layout) + return output, output_scale + + +if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "scaled_fp4_quant"): + + @register_fake("_C::scaled_fp4_quant") + def _scaled_fp4_quant_fake( + input: torch.Tensor, + input_scale: torch.Tensor, + is_sf_swizzled_layout: bool, + ) -> tuple[torch.Tensor, torch.Tensor]: + n = input.shape[-1] + m = input.numel() // n + return create_fp4_output_tensors(m, n, input.device, is_sf_swizzled_layout) + + @register_fake("_C::scaled_fp4_quant.out") + def _scaled_fp4_quant_out_fake( + input: torch.Tensor, + input_scale: torch.Tensor, + is_sf_swizzled_layout: bool, + *, + output: torch.Tensor, + output_scale: torch.Tensor, + ) -> None: + return None + + # page attention ops def paged_attention_v1( out: torch.Tensor, @@ -1644,7 +1719,6 @@ def scaled_fp4_quant( input = input.reshape(other_dims, input.shape[-1]) m, n = input.shape block_size = 16 - device = input.device assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}." assert input.dtype in (torch.float16, torch.bfloat16), ( @@ -1658,26 +1732,16 @@ def scaled_fp4_quant( input, input_global_scale ) else: - # Two fp4 values will be packed into an uint8. - output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) - if is_sf_swizzled_layout: - # We use the rounded values to store the swizzled values. Due to the - # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. - # So, we first pad the scales to multiples of 128 and 4. Then, the scales - # (in float8_e4m3fn) are packed into an int32 for every 4 values. 
More: - # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = round_up(m, 128) - scale_n = n // block_size - rounded_n = round_up(scale_n, 4) - output_scale = torch.empty( - (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 - ) - else: - output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8) - - torch.ops._C.scaled_fp4_quant( - output, input, output_scale, input_global_scale, is_sf_swizzled_layout + # Pre-allocate and call .out variant (same behavior as old in-place API) + output, output_scale = create_fp4_output_tensors( + m, n, input.device, is_sf_swizzled_layout + ) + torch.ops._C.scaled_fp4_quant.out( + input, + input_global_scale, + is_sf_swizzled_layout, + output=output, + output_scale=output_scale, ) output_scale = output_scale.view(torch.float8_e4m3fn) diff --git a/vllm/compilation/passes/fusion/act_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py index e14100384..911775f69 100644 --- a/vllm/compilation/passes/fusion/act_quant_fusion.py +++ b/vllm/compilation/passes/fusion/act_quant_fusion.py @@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): result_silu_mul = self.silu_and_mul_matcher(input) at = auto_functionalized( self.QUANT_OP, - output=result, input=result_silu_mul, - output_scale=output_scale, input_scale=scale, is_sf_swizzled_layout=True, + output=result, + output_scale=output_scale, ) return at[1], at[2] diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 44dc3d67b..f141a7c17 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -47,7 +47,7 @@ if find_spec("flashinfer"): pass if hasattr(torch.ops._C, "scaled_fp4_quant"): - STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default + STATIC_FP4_QUANT_OP = 
torch.ops._C.scaled_fp4_quant.out # Max size of the input tensor per world size per device capability # to use flashinfer fused allreduce @@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): rms = self.rmsnorm_matcher(all_reduce, weight) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, - output=quant_result, input=rms, - output_scale=output_scale, input_scale=input_global_scale, is_sf_swizzled_layout=True, + output=quant_result, + output_scale=output_scale, ) # quant_out, allreduce_output, output_scale @@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, - output=quant_result, input=rms, - output_scale=output_scale, input_scale=input_global_scale, is_sf_swizzled_layout=True, + output=quant_result, + output_scale=output_scale, ) # quant_out, allreduce_output, output_scale diff --git a/vllm/compilation/passes/fusion/attn_quant_fusion.py b/vllm/compilation/passes/fusion/attn_quant_fusion.py index 5e6bf28c0..0e1b846af 100644 --- a/vllm/compilation/passes/fusion/attn_quant_fusion.py +++ b/vllm/compilation/passes/fusion/attn_quant_fusion.py @@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ) at2 = auto_functionalized( self.QUANT_OP, - output=output_quant, input=attn_out_view, - output_scale=output_scale, input_scale=input_scale, is_sf_swizzled_layout=True, + output=output_quant, + output_scale=output_scale, ) output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) return at2[1], output_scale_view diff --git a/vllm/compilation/passes/fusion/matcher_utils.py b/vllm/compilation/passes/fusion/matcher_utils.py index 03f680552..ec36c12d1 100644 --- a/vllm/compilation/passes/fusion/matcher_utils.py +++ b/vllm/compilation/passes/fusion/matcher_utils.py @@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { } if 
current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501 if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 diff --git a/vllm/compilation/passes/fusion/rms_quant_fusion.py b/vllm/compilation/passes/fusion/rms_quant_fusion.py index 2d084783d..95ce7b22e 100644 --- a/vllm/compilation/passes/fusion/rms_quant_fusion.py +++ b/vllm/compilation/passes/fusion/rms_quant_fusion.py @@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default + QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 -- GitLab From 7a49742b8867e7d310abfd85c944e54d090e9301 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Mon, 16 Mar 2026 19:46:20 -0400 Subject: [PATCH 010/223] [CI/Build] Add common tool call parser test suite (#27599) Signed-off-by: Ben Browning --- .../test_gigachat3_tool_parser.py | 2 +- .../test_hunyuan_a13b_tool_parser.py | 2 +- .../test_llama4_pythonic_tool_parser.py | 2 +- .../tool_parsers/test_olmo3_tool_parser.py | 2 +- .../tool_parsers/test_pythonic_tool_parser.py | 2 +- tests/tool_parsers/common_tests.py | 378 ++++++++++++++++++ tests/tool_parsers/conftest.py | 12 + .../test_deepseekv3_tool_parser.py | 92 +++++ .../test_granite_20b_fc_tool_parser.py | 76 ++++ .../tool_parsers/test_granite_tool_parser.py | 118 ++++++ .../test_internlm2_tool_parser.py | 122 ++++++ 
.../tool_parsers/test_longcat_tool_parser.py | 101 +++++ .../tool_parsers/test_phi4mini_tool_parser.py | 110 +++++ .../tool_parsers/test_qwen3xml_tool_parser.py | 75 ++++ tests/tool_parsers/test_step3_tool_parser.py | 112 ++++++ .../openai => }/tool_parsers/utils.py | 0 16 files changed, 1201 insertions(+), 5 deletions(-) create mode 100644 tests/tool_parsers/common_tests.py create mode 100644 tests/tool_parsers/conftest.py create mode 100644 tests/tool_parsers/test_deepseekv3_tool_parser.py create mode 100644 tests/tool_parsers/test_granite_20b_fc_tool_parser.py create mode 100644 tests/tool_parsers/test_granite_tool_parser.py create mode 100644 tests/tool_parsers/test_internlm2_tool_parser.py create mode 100644 tests/tool_parsers/test_longcat_tool_parser.py create mode 100644 tests/tool_parsers/test_phi4mini_tool_parser.py create mode 100644 tests/tool_parsers/test_qwen3xml_tool_parser.py create mode 100644 tests/tool_parsers/test_step3_tool_parser.py rename tests/{entrypoints/openai => }/tool_parsers/utils.py (100%) diff --git a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py index 634ec421f..99ab1e497 100644 --- a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py @@ -5,7 +5,7 @@ import json import pytest -from tests.entrypoints.openai.tool_parsers.utils import ( +from tests.tool_parsers.utils import ( run_tool_extraction, run_tool_extraction_streaming, ) diff --git a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py index 89c91c2ec..90f08bb82 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py @@ -7,7 +7,7 @@ from unittest.mock import MagicMock import pytest -from 
tests.entrypoints.openai.tool_parsers.utils import ( +from tests.tool_parsers.utils import ( run_tool_extraction, run_tool_extraction_streaming, ) diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py index 914348153..1328d0571 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest -from tests.entrypoints.openai.tool_parsers.utils import ( +from tests.tool_parsers.utils import ( run_tool_extraction, run_tool_extraction_streaming, ) diff --git a/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py index dbd7e1d48..4c418ba11 100644 --- a/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest -from tests.entrypoints.openai.tool_parsers.utils import ( +from tests.tool_parsers.utils import ( run_tool_extraction, run_tool_extraction_streaming, ) diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index 8ab4c5a5a..9d97c7f58 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest -from tests.entrypoints.openai.tool_parsers.utils import ( +from tests.tool_parsers.utils import ( run_tool_extraction, run_tool_extraction_streaming, ) diff --git a/tests/tool_parsers/common_tests.py b/tests/tool_parsers/common_tests.py new file mode 100644 index 000000000..925506aa7 --- /dev/null +++ 
b/tests/tool_parsers/common_tests.py @@ -0,0 +1,378 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from dataclasses import dataclass, field +from types import NoneType +from typing import Any + +import pytest + +from tests.tool_parsers.utils import run_tool_extraction +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParserManager + + +@dataclass +class ToolParserTestConfig: + """Configuration for a tool parser's common tests. + + This dataclass contains all the test data and expected results needed + to run the common test suite for a parser. Each parser test file + creates one instance of this config with parser-specific values. + + Attributes: + parser_name: Name used with ToolParserManager (e.g., "mistral") + + Test data (model outputs): + no_tool_calls_output: Plain text without any tool syntax + single_tool_call_output: One tool call with simple arguments + parallel_tool_calls_output: Multiple tool calls in one response + various_data_types_output: Tool with various data types + empty_arguments_output: Tool call with no parameters + surrounding_text_output: Tool call mixed with regular text + escaped_strings_output: Tool call with escaped chars + malformed_input_outputs: List of invalid inputs + + Expected results: + single_tool_call_expected_name: Expected function name + single_tool_call_expected_args: Expected arguments dict + parallel_tool_calls_count: Number of tools in parallel test + parallel_tool_calls_names: Function names in order + single_tool_call_expected_content: Content field when tool called + parallel_tool_calls_expected_content: Content for parallel test + + xfail markers: + xfail_streaming: Mapping test name to xfail reason (streaming only) + xfail_nonstreaming: Mapping test name to xfail reason (non-streaming) + + Special flags: + allow_empty_or_json_empty_args: True if "" or "{}" both valid for empty args + 
supports_typed_arguments: True if the parser supports typed function arguments + """ + + # Parser identification + parser_name: str + + # Test data - model outputs for each common test + no_tool_calls_output: str + single_tool_call_output: str + parallel_tool_calls_output: str + various_data_types_output: str + empty_arguments_output: str + surrounding_text_output: str + escaped_strings_output: str + malformed_input_outputs: list[str] + + # Expected results for specific tests (optional overrides) + single_tool_call_expected_name: str = "get_weather" + single_tool_call_expected_args: dict[str, Any] = field( + default_factory=lambda: {"city": "Tokyo"} + ) + parallel_tool_calls_count: int = 2 + parallel_tool_calls_names: list[str] = field( + default_factory=lambda: ["get_weather", "get_time"] + ) + + # xfail configuration - maps test name to xfail reason + xfail_streaming: dict[str, str] = field(default_factory=dict) + xfail_nonstreaming: dict[str, str] = field(default_factory=dict) + + # Content expectations (some parsers strip content, others don't) + single_tool_call_expected_content: str | None = None + parallel_tool_calls_expected_content: str | None = None + + # Special assertions for edge cases + allow_empty_or_json_empty_args: bool = True # "{}" or "" for empty args + supports_typed_arguments: bool = True + + +class ToolParserTests: + """Mixin class providing common test suite for tool parsers. + + To use this mixin in a parser test file: + + 1. Create a test_config fixture that returns a ToolParserTestConfig instance + 2. Inherit from this class + 3. Add parser-specific tests as additional methods + + Example: + class TestMistralToolParser(ToolParserTests): + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + return ToolParserTestConfig( + parser_name="mistral", + no_tool_calls_output="Plain text...", + # ... other config ... 
+ ) + + # Parser-specific tests + def test_mistral_specific_feature(self, tool_parser): + # Custom test logic + pass + """ + + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + """Override this to provide parser-specific configuration.""" + raise NotImplementedError( + "Subclass must provide test_config fixture returning ToolParserTestConfig" + ) + + @pytest.fixture + def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike: + """Override this to provide parser-specific tokenizer.""" + return default_tokenizer + + @pytest.fixture + def tool_parser(self, test_config: ToolParserTestConfig, tokenizer: TokenizerLike): + return ToolParserManager.get_tool_parser(test_config.parser_name)(tokenizer) + + @pytest.fixture(params=[True, False]) + def streaming(self, request: pytest.FixtureRequest) -> bool: + return request.param + + def test_no_tool_calls( + self, + request: pytest.FixtureRequest, + tool_parser: Any, + test_config: ToolParserTestConfig, + streaming: bool, + ): + """Verify parser handles plain text without tool syntax.""" + # Apply xfail markers if configured + test_name = "test_no_tool_calls" + self.apply_xfail_mark(request, test_config, test_name, streaming) + + content, tool_calls = run_tool_extraction( + tool_parser, test_config.no_tool_calls_output, streaming=streaming + ) + assert content == test_config.no_tool_calls_output, ( + f"Expected content to match input, got {content}" + ) + assert len(tool_calls) == 0, f"Expected no tool calls, got {len(tool_calls)}" + + def test_single_tool_call_simple_args( + self, + request: pytest.FixtureRequest, + tool_parser: Any, + test_config: ToolParserTestConfig, + streaming: bool, + ): + """Verify parser extracts one tool with simple arguments.""" + # Apply xfail markers if configured + test_name = "test_single_tool_call_simple_args" + self.apply_xfail_mark(request, test_config, test_name, streaming) + + content, tool_calls = run_tool_extraction( + tool_parser, 
test_config.single_tool_call_output, streaming=streaming + ) + + # Content check (some parsers strip it) + if test_config.single_tool_call_expected_content is not None: + assert content == test_config.single_tool_call_expected_content + + assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}" + assert tool_calls[0].type == "function" + assert tool_calls[0].function.name == test_config.single_tool_call_expected_name + + args = json.loads(tool_calls[0].function.arguments) + for key, value in test_config.single_tool_call_expected_args.items(): + assert args.get(key) == value, ( + f"Expected {key}={value}, got {args.get(key)}" + ) + + def test_parallel_tool_calls( + self, + request: pytest.FixtureRequest, + tool_parser: Any, + test_config: ToolParserTestConfig, + streaming: bool, + ): + """Verify parser handles multiple tools in one response.""" + # Apply xfail markers if configured + test_name = "test_parallel_tool_calls" + self.apply_xfail_mark(request, test_config, test_name, streaming) + + content, tool_calls = run_tool_extraction( + tool_parser, + test_config.parallel_tool_calls_output, + streaming=streaming, + ) + + assert len(tool_calls) == test_config.parallel_tool_calls_count, ( + f"Expected {test_config.parallel_tool_calls_count} " + f"tool calls, got {len(tool_calls)}" + ) + + # Verify tool names match expected + for i, expected_name in enumerate(test_config.parallel_tool_calls_names): + assert tool_calls[i].type == "function" + assert tool_calls[i].function.name == expected_name + + # Verify unique IDs + ids = [tc.id for tc in tool_calls] + assert len(ids) == len(set(ids)), "Tool call IDs should be unique" + + def test_various_data_types( + self, + request: pytest.FixtureRequest, + tool_parser: Any, + test_config: ToolParserTestConfig, + streaming: bool, + ): + """Verify parser handles all JSON types in arguments.""" + # Apply xfail markers if configured + test_name = "test_various_data_types" + self.apply_xfail_mark(request, 
test_config, test_name, streaming) + + content, tool_calls = run_tool_extraction( + tool_parser, + test_config.various_data_types_output, + streaming=streaming, + ) + assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}" + + args = json.loads(tool_calls[0].function.arguments) + # Verify all expected fields present + required_fields_types = { + "string_field": str, + "int_field": int, + "float_field": float, + "bool_field": bool, + "null_field": NoneType, + "array_field": list, + "object_field": dict, + } + for required_field, expected_type in required_fields_types.items(): + assert required_field in args, ( + f"Expected field '{required_field}' in arguments" + ) + if test_config.supports_typed_arguments: + found_type = type(args[required_field]) + assert found_type is expected_type, ( + f"Expected field '{required_field}' to have type {expected_type}, " + f"got {found_type}" + ) + + def test_empty_arguments( + self, + request: pytest.FixtureRequest, + tool_parser: Any, + test_config: ToolParserTestConfig, + streaming: bool, + ): + """Verify parser handles parameterless tool calls.""" + # Apply xfail markers if configured + test_name = "test_empty_arguments" + self.apply_xfail_mark(request, test_config, test_name, streaming) + + content, tool_calls = run_tool_extraction( + tool_parser, test_config.empty_arguments_output, streaming=streaming + ) + assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}" + + args = tool_calls[0].function.arguments + if test_config.allow_empty_or_json_empty_args: + assert args in ["{}", ""], f"Expected empty args, got {args}" + else: + assert args == "{}", f"Expected {{}}, got {args}" + + def test_surrounding_text( + self, + request: pytest.FixtureRequest, + tool_parser: Any, + test_config: ToolParserTestConfig, + streaming: bool, + ): + """Verify parser extracts tools from mixed content.""" + # Apply xfail markers if configured + test_name = "test_surrounding_text" + 
self.apply_xfail_mark(request, test_config, test_name, streaming) + + content, tool_calls = run_tool_extraction( + tool_parser, test_config.surrounding_text_output, streaming=streaming + ) + assert len(tool_calls) >= 1, ( + f"Expected at least 1 tool call, got {len(tool_calls)}" + ) + + def test_escaped_strings( + self, + request: pytest.FixtureRequest, + tool_parser: Any, + test_config: ToolParserTestConfig, + streaming: bool, + ): + """Verify parser handles escaped characters in arguments.""" + # Apply xfail markers if configured + test_name = "test_escaped_strings" + self.apply_xfail_mark(request, test_config, test_name, streaming) + + content, tool_calls = run_tool_extraction( + tool_parser, test_config.escaped_strings_output, streaming=streaming + ) + assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}" + + args = json.loads(tool_calls[0].function.arguments) + # At minimum, verify we can parse and have expected fields + # Exact escaping behavior varies by parser + assert len(args) > 0, "Expected some arguments with escaped strings" + + def test_malformed_input( + self, + request: pytest.FixtureRequest, + tool_parser: Any, + test_config: ToolParserTestConfig, + streaming: bool, + ): + """Verify parser gracefully handles invalid syntax.""" + # Apply xfail markers if configured + test_name = "test_malformed_input" + self.apply_xfail_mark(request, test_config, test_name, streaming) + + for malformed_input in test_config.malformed_input_outputs: + # Should not raise exception + content, tool_calls = run_tool_extraction( + tool_parser, malformed_input, streaming=streaming + ) + # Parser should handle gracefully (exact behavior varies) + + def test_streaming_reconstruction( + self, + request: pytest.FixtureRequest, + tool_parser: Any, + test_config: ToolParserTestConfig, + ): + """Verify streaming produces same result as non-streaming.""" + test_name = "test_streaming_reconstruction" + self.apply_xfail_mark(request, test_config, test_name, 
True) + + test_output = test_config.single_tool_call_output + + # Non-streaming result + content_non, tools_non = run_tool_extraction( + tool_parser, test_output, streaming=False + ) + + # Streaming result + content_stream, tools_stream = run_tool_extraction( + tool_parser, test_output, streaming=True + ) + + # Compare results + assert content_non == content_stream, "Content should match between modes" + assert len(tools_non) == len(tools_stream), "Tool count should match" + if len(tools_non) > 0: + assert tools_non[0].function.name == tools_stream[0].function.name + assert tools_non[0].function.arguments == tools_stream[0].function.arguments + + def apply_xfail_mark(self, request, test_config, test_name, streaming): + reason = None + if streaming and test_name in test_config.xfail_streaming: + reason = test_config.xfail_streaming[test_name] + elif not streaming and test_name in test_config.xfail_nonstreaming: + reason = test_config.xfail_nonstreaming[test_name] + if reason is not None: + mark = pytest.mark.xfail(reason=reason, strict=True) + request.node.add_marker(mark) diff --git a/tests/tool_parsers/conftest.py b/tests/tool_parsers/conftest.py new file mode 100644 index 000000000..89609b257 --- /dev/null +++ b/tests/tool_parsers/conftest.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from vllm.tokenizers import TokenizerLike + + +@pytest.fixture(scope="module") +def default_tokenizer() -> TokenizerLike: + return AutoTokenizer.from_pretrained("gpt2") diff --git a/tests/tool_parsers/test_deepseekv3_tool_parser.py b/tests/tool_parsers/test_deepseekv3_tool_parser.py new file mode 100644 index 000000000..27fbae092 --- /dev/null +++ b/tests/tool_parsers/test_deepseekv3_tool_parser.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import pytest + 
+from tests.tool_parsers.common_tests import ( + ToolParserTestConfig, + ToolParserTests, +) +from vllm.tokenizers import TokenizerLike, get_tokenizer + + +class TestDeepSeekV3ToolParser(ToolParserTests): + @pytest.fixture(scope="class") + def tokenizer(self) -> TokenizerLike: + return get_tokenizer("deepseek-ai/DeepSeek-V3") + + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + return ToolParserTestConfig( + parser_name="deepseek_v3", + # Test data + no_tool_calls_output=( + "How can I help you today? I can check weather for you." + ), + single_tool_call_output="""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"city": "Tokyo", "unit": "celsius"} +```<|tool▁call▁end|><|tool▁calls▁end|>""", + parallel_tool_calls_output="""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"city": "Tokyo", "unit": "celsius"} +```<|tool▁call▁end|><|tool▁call▁begin|>function<|tool▁sep|>search_hotels +```json +{"location": "Tokyo", "check_in": "2025-01-15"} +```<|tool▁call▁end|><|tool▁calls▁end|>""", + various_data_types_output=( + """<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test_function +```json +""" + """{"string_field": "hello", "int_field": 42, "float_field": 3.14, """ + """"bool_field": true, "null_field": null, """ + """"array_field": ["a", "b", "c"], """ + """"object_field": {"nested": "value"}, """ + """"empty_array": [], "empty_object": {}} +```<|tool▁call▁end|><|tool▁calls▁end|>""" + ), + empty_arguments_output="""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_time +```json +{} +```<|tool▁call▁end|><|tool▁calls▁end|>""", + surrounding_text_output=( + """Let me check the weather for you.""" + """<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"city": "Paris"} +```<|tool▁call▁end|><|tool▁calls▁end|>""" + ), + escaped_strings_output=( + """<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>send_message +```json 
+""" + """{"text": "He said \\"hello\\"", "path": "C:\\\\Users\\\\file", """ + """"newline": "line1\\nline2"} +```<|tool▁call▁end|><|tool▁calls▁end|>""" + ), + malformed_input_outputs=[ + """<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"city": "Tokyo" +```<|tool▁call▁end|><|tool▁calls▁end|>""", + """<|tool▁calls▁begin|>function<|tool▁sep|>get_weather +```json +{"city": "Tokyo"} +```<|tool▁calls▁end|>""", + ], + # Expected results + single_tool_call_expected_name="get_weather", + single_tool_call_expected_args={"city": "Tokyo", "unit": "celsius"}, + single_tool_call_expected_content=None, + parallel_tool_calls_count=2, + parallel_tool_calls_names=["get_weather", "search_hotels"], + # xfail markers + xfail_streaming={}, + xfail_nonstreaming={ + "test_malformed_input": ( + "Parser sets tools_called=True even when tool_calls is " + "empty (detects start token but fails to parse)" + ), + }, + ) diff --git a/tests/tool_parsers/test_granite_20b_fc_tool_parser.py b/tests/tool_parsers/test_granite_20b_fc_tool_parser.py new file mode 100644 index 000000000..857c5a5bf --- /dev/null +++ b/tests/tool_parsers/test_granite_20b_fc_tool_parser.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from tests.tool_parsers.common_tests import ( + ToolParserTestConfig, + ToolParserTests, +) + + +class TestGranite20bFcToolParser(ToolParserTests): + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + return ToolParserTestConfig( + parser_name="granite-20b-fc", + # Test data + no_tool_calls_output="This is a regular response without any tool calls.", + single_tool_call_output=( + ' {"name": "get_weather", ' + '"arguments": {"city": "Tokyo"}}' + ), + parallel_tool_calls_output=( + ' {"name": "get_weather", ' + '"arguments": {"city": "Tokyo"}}\n' + ' {"name": "get_time", ' + '"arguments": {"timezone": "Asia/Tokyo"}}' + ), + 
various_data_types_output=""" { + "name": "test_function", + "arguments": { + "string_field": "hello", + "int_field": 42, + "float_field": 3.14, + "bool_field": true, + "null_field": null, + "array_field": ["a", "b", "c"], + "object_field": {"nested": "value"}, + "empty_array": [], + "empty_object": {} + } +}""", + empty_arguments_output=( + ' {"name": "refresh", "arguments": {}}' + ), + surrounding_text_output="""Let me check the weather for you. + {"name": "get_weather", "arguments": {"city": "Tokyo"}}""", + escaped_strings_output=""" { + "name": "test_function", + "arguments": { + "quoted": "He said \\"hello\\"", + "path": "C:\\\\Users\\\\file.txt", + "newline": "line1\\nline2", + "unicode": "emoji: 🎉" + } +}""", + malformed_input_outputs=[ + ' {"name": "func", "arguments": {', + ' [{"name": "func", "arguments": {}}]', + '{"name": "func", "arguments": {}}', + ' {"name": 123}', + ], + # Expected results + single_tool_call_expected_name="get_weather", + single_tool_call_expected_args={"city": "Tokyo"}, + single_tool_call_expected_content=None, + parallel_tool_calls_count=2, + parallel_tool_calls_names=["get_weather", "get_time"], + # xfail markers + xfail_streaming={ + "test_surrounding_text": ( + "Granite 20B FC streaming requires at start" + ), + }, + xfail_nonstreaming={}, + ) diff --git a/tests/tool_parsers/test_granite_tool_parser.py b/tests/tool_parsers/test_granite_tool_parser.py new file mode 100644 index 000000000..2046c11c5 --- /dev/null +++ b/tests/tool_parsers/test_granite_tool_parser.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import pytest + +from tests.tool_parsers.common_tests import ( + ToolParserTestConfig, + ToolParserTests, +) +from tests.tool_parsers.utils import run_tool_extraction + + +class TestGraniteToolParser(ToolParserTests): + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + return ToolParserTestConfig( + parser_name="granite", 
+ # Test data + no_tool_calls_output="This is a regular response without any tool calls.", + single_tool_call_output=( + '<|tool_call|> [{"name": "get_weather", ' + '"arguments": {"city": "Tokyo"}}]' + ), + parallel_tool_calls_output="""<|tool_call|> [ + {"name": "get_weather", "arguments": {"city": "Tokyo"}}, + {"name": "get_time", "arguments": {"timezone": "Asia/Tokyo"}} +]""", + various_data_types_output=""" [{ + "name": "test_function", + "arguments": { + "string_field": "hello", + "int_field": 42, + "float_field": 3.14, + "bool_field": true, + "null_field": null, + "array_field": ["a", "b", "c"], + "object_field": {"nested": "value"}, + "empty_array": [], + "empty_object": {} + } +}]""", + empty_arguments_output=( + '<|tool_call|> [{"name": "refresh", "arguments": {}}]' + ), + surrounding_text_output="""Let me check the weather for you. +<|tool_call|> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}] +I'll get that information.""", + escaped_strings_output=""" [{ + "name": "test_function", + "arguments": { + "quoted": "He said \\"hello\\"", + "path": "C:\\\\Users\\\\file.txt", + "newline": "line1\\nline2", + "unicode": "emoji: 🎉" + } +}]""", + malformed_input_outputs=[ + '<|tool_call|> [{"name": "func", "arguments": {', + '<|tool_call|> {"name": "func", "arguments": {}}', # Not an array + '[{"name": "func", "arguments": "not a dict"}]', + 'Some text [{"name": "func"}]', # JSON but not tool call format + ], + # Expected results + single_tool_call_expected_name="get_weather", + single_tool_call_expected_args={"city": "Tokyo"}, + # Granite strips content when tool calls present + single_tool_call_expected_content=None, + parallel_tool_calls_count=2, + parallel_tool_calls_names=["get_weather", "get_time"], + # xfail markers + xfail_streaming={ + "test_malformed_input": ( + "Streaming mode incorrectly creates tool call from malformed JSON" + ), + "test_surrounding_text": ( + "Parser doesn't handle surrounding text correctly in streaming" + ), + 
"test_streaming_reconstruction": ( + "Streaming mode doesn't strip <|tool_call|> marker from content" + ), + }, + xfail_nonstreaming={ + "test_surrounding_text": ( + "Parser doesn't handle surrounding text correctly in non-streaming" + ), + }, + ) + + # Granite-Specific Tests + + @pytest.mark.parametrize("streaming", [True, False]) + def test_granite_token_prefix_format(self, tool_parser, streaming): + """Verify parser handles Granite 3.0 <|tool_call|> token format.""" + single_tool_call_token = ( + '<|tool_call|> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]' + ) + content, tool_calls = run_tool_extraction( + tool_parser, single_tool_call_token, streaming=streaming + ) + assert len(tool_calls) == 1, ( + f"Expected 1 tool call from token format, got {len(tool_calls)}" + ) + assert tool_calls[0].function.name == "get_weather" + + @pytest.mark.parametrize("streaming", [True, False]) + def test_granite_string_prefix_format(self, tool_parser, streaming): + """Verify parser handles Granite 3.1 string format.""" + single_tool_call_string = ( + ' [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]' + ) + content, tool_calls = run_tool_extraction( + tool_parser, single_tool_call_string, streaming=streaming + ) + assert len(tool_calls) == 1, ( + f"Expected 1 tool call from string format, got {len(tool_calls)}" + ) + assert tool_calls[0].function.name == "get_weather" diff --git a/tests/tool_parsers/test_internlm2_tool_parser.py b/tests/tool_parsers/test_internlm2_tool_parser.py new file mode 100644 index 000000000..2e5069dbe --- /dev/null +++ b/tests/tool_parsers/test_internlm2_tool_parser.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock + +import pytest + +from tests.tool_parsers.common_tests import ( + ToolParserTestConfig, + ToolParserTests, +) +from vllm.tokenizers import TokenizerLike + + +class 
TestInternLM2ToolParser(ToolParserTests): + @pytest.fixture + def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike: + """Add some internlm2 specific tokens to the default vocab.""" + + tokenizer_vocab = default_tokenizer.get_vocab() + default_tokenizer.get_vocab = MagicMock() + tokenizer_vocab.update( + { + "<|action_start|>": 92540, + "<|plugin|>": 92541, + "<|action_end|>": 92542, + } + ) + default_tokenizer.get_vocab.return_value = tokenizer_vocab + return default_tokenizer + + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + return ToolParserTestConfig( + parser_name="internlm", + # Test data + no_tool_calls_output="This is a regular response without any tool calls.", + single_tool_call_output=( + '<|action_start|><|plugin|>{"name": "get_weather", ' + '"parameters": {"city": "Tokyo"}}<|action_end|>' + ), + # InternLM2 doesn't support parallel calls + parallel_tool_calls_output=( + '<|action_start|><|plugin|>{"name": "get_weather", ' + '"parameters": {"city": "Tokyo"}}<|action_end|>' + ), + various_data_types_output="""<|action_start|><|plugin|>{ + "name": "test_function", + "parameters": { + "string_field": "hello", + "int_field": 42, + "float_field": 3.14, + "bool_field": true, + "null_field": null, + "array_field": ["a", "b", "c"], + "object_field": {"nested": "value"}, + "empty_array": [], + "empty_object": {} + } +}<|action_end|>""", + empty_arguments_output=( + '<|action_start|><|plugin|>{"name": "refresh", ' + '"parameters": {}}<|action_end|>' + ), + surrounding_text_output=( + "Let me check the weather for you. 
" + '<|action_start|><|plugin|>{"name": "get_weather", ' + '"parameters": {"city": "Tokyo"}}<|action_end|>' + ), + escaped_strings_output="""<|action_start|><|plugin|>{ + "name": "test_function", + "parameters": { + "quoted": "He said \\"hello\\"", + "path": "C:\\\\Users\\\\file.txt", + "newline": "line1\\nline2", + "unicode": "emoji: 🎉" + } +}<|action_end|>""", + malformed_input_outputs=[ + '<|action_start|><|plugin|>{"name": "func", "parameters": {', + ( + '<|action_start|><|plugin|>{"name": "func", ' + '"parameters": "not a dict"}<|action_end|>' + ), + "<|action_start|><|plugin|>not json<|action_end|>", + "<|action_start|><|plugin|>", + '<|action_start|>{"name": "func"}', + ], + # Expected results + single_tool_call_expected_name="get_weather", + single_tool_call_expected_args={"city": "Tokyo"}, + single_tool_call_expected_content=None, + parallel_tool_calls_count=1, # InternLM2 only supports single tool calls + parallel_tool_calls_names=["get_weather"], + # Parser-specific settings + allow_empty_or_json_empty_args=True, + # xfail markers + xfail_streaming={ + "test_single_tool_call_simple_args": ( + "InternLM2 streaming not fully implemented" + ), + "test_parallel_tool_calls": ( + "InternLM2 streaming not fully implemented" + ), + "test_various_data_types": ( + "InternLM2 streaming not fully implemented" + ), + "test_empty_arguments": ("InternLM2 streaming not fully implemented"), + "test_surrounding_text": ("InternLM2 streaming not fully implemented"), + "test_escaped_strings": ("InternLM2 streaming not fully implemented"), + "test_streaming_reconstruction": ( + "InternLM2 streaming parser returns '<|action_start|' as " + "content instead of None - streaming/non-streaming inconsistency" + ), + }, + xfail_nonstreaming={ + "test_malformed_input": ( + "InternLM2 parser raises JSONDecodeError on malformed JSON " + "instead of gracefully handling it" + ), + }, + ) diff --git a/tests/tool_parsers/test_longcat_tool_parser.py 
b/tests/tool_parsers/test_longcat_tool_parser.py new file mode 100644 index 000000000..e2fad4341 --- /dev/null +++ b/tests/tool_parsers/test_longcat_tool_parser.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock + +import pytest + +from tests.tool_parsers.common_tests import ( + ToolParserTestConfig, + ToolParserTests, +) +from vllm.tokenizers import TokenizerLike + + +class TestLongCatToolParser(ToolParserTests): + @pytest.fixture + def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike: + """Add some longcat specific tokens to the default vocab.""" + tokenizer = default_tokenizer + tokenizer_vocab = tokenizer.get_vocab() + tokenizer.get_vocab = MagicMock() + tokenizer_vocab.update( + { + "": 32000, + "": 32001, + } + ) + tokenizer.get_vocab.return_value = tokenizer_vocab + return tokenizer + + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + return ToolParserTestConfig( + parser_name="longcat", + # Test data + no_tool_calls_output="This is a regular response without any tool calls.", + single_tool_call_output=( + '{"name": "get_weather", ' + '"arguments": {"city": "Tokyo"}}' + ), + parallel_tool_calls_output=( + '{"name": "get_weather", ' + '"arguments": {"city": "Tokyo"}}\n' + '{"name": "get_time", ' + '"arguments": {"timezone": "Asia/Tokyo"}}' + ), + various_data_types_output="""{ + "name": "test_function", + "arguments": { + "string_field": "hello", + "int_field": 42, + "float_field": 3.14, + "bool_field": true, + "null_field": null, + "array_field": ["a", "b", "c"], + "object_field": {"nested": "value"}, + "empty_array": [], + "empty_object": {} + } +}""", + empty_arguments_output=( + '{"name": "refresh", "arguments": {}}' + "" + ), + surrounding_text_output=( + "Let me check the weather for you.\n" + '{"name": "get_weather", ' + '"arguments": {"city": "Tokyo"}}\n' + "Here is the result." 
+ ), + escaped_strings_output="""{ + "name": "test_function", + "arguments": { + "quoted": "He said \\"hello\\"", + "path": "C:\\\\Users\\\\file.txt", + "newline": "line1\\nline2", + "unicode": "emoji: 🎉" + } +}""", + malformed_input_outputs=[ + '{"name": "func", "arguments": {', + ( + '{"name": "func", ' + '"arguments": "not a dict"}' + ), + "Some text with invalid json", + ], + # Expected results + single_tool_call_expected_name="get_weather", + single_tool_call_expected_args={"city": "Tokyo"}, + single_tool_call_expected_content=None, + parallel_tool_calls_count=2, + parallel_tool_calls_names=["get_weather", "get_time"], + # xfail markers + xfail_streaming={ + "test_malformed_input": "Streaming has complex buffering behavior", + }, + xfail_nonstreaming={}, + # Configuration + allow_empty_or_json_empty_args=True, + ) diff --git a/tests/tool_parsers/test_phi4mini_tool_parser.py b/tests/tool_parsers/test_phi4mini_tool_parser.py new file mode 100644 index 000000000..eff9fa9bb --- /dev/null +++ b/tests/tool_parsers/test_phi4mini_tool_parser.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock + +import pytest + +from tests.tool_parsers.common_tests import ( + ToolParserTestConfig, + ToolParserTests, +) +from vllm.tokenizers import TokenizerLike + + +class TestPhi4MiniToolParser(ToolParserTests): + @pytest.fixture + def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike: + """Add some phi4mini specific tokens to the default vocab.""" + + tokenizer = default_tokenizer + tokenizer_vocab = tokenizer.get_vocab() + tokenizer.get_vocab = MagicMock() + tokenizer_vocab.update( + { + "functools": 32000, + } + ) + tokenizer.get_vocab.return_value = tokenizer_vocab + return tokenizer + + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + return ToolParserTestConfig( + parser_name="phi4_mini_json", + # Test data + 
no_tool_calls_output="This is a regular response without any tool calls.", + single_tool_call_output=( + 'functools[{"name": "get_weather", "arguments": {"city": "Tokyo"}}]' + ), + parallel_tool_calls_output="""functools[ + {"name": "get_weather", "arguments": {"city": "Tokyo"}}, + {"name": "get_time", "arguments": {"timezone": "Asia/Tokyo"}} +]""", + various_data_types_output="""functools[{ + "name": "test_function", + "arguments": { + "string_field": "hello", + "int_field": 42, + "float_field": 3.14, + "bool_field": true, + "null_field": null, + "array_field": ["a", "b", "c"], + "object_field": {"nested": "value"}, + "empty_array": [], + "empty_object": {} + } +}]""", + empty_arguments_output='functools[{"name": "refresh", "arguments": {}}]', + surrounding_text_output="""Let me check the weather for you. +functools[{"name": "get_weather", "arguments": {"city": "Tokyo"}}] +Would you like to know more?""", + escaped_strings_output="""functools[{ + "name": "test_function", + "arguments": { + "quoted": "He said \\"hello\\"", + "path": "C:\\\\Users\\\\file.txt", + "newline": "line1\\nline2", + "unicode": "emoji: 🎉" + } +}]""", + malformed_input_outputs=[ + 'functools[{"name": "func", "arguments": {', + 'functools[{"name": "func", "arguments": "not a dict"}]', + 'functools{"name": "func"}', # Missing brackets + 'functools[{"name": "func"}]', # Missing arguments/parameters + "functools[] This is just text", # Empty functools + "functools[ This is just text ]", # functools with invalid JSON + ], + # Expected results + single_tool_call_expected_name="get_weather", + single_tool_call_expected_args={"city": "Tokyo"}, + # Phi-4 Mini strips content when tool calls present + single_tool_call_expected_content=None, + parallel_tool_calls_count=2, + parallel_tool_calls_names=["get_weather", "get_time"], + parallel_tool_calls_expected_content=None, + # xfail markers + xfail_streaming={ + "test_no_tool_calls": "Phi4 Mini streaming not implemented", + 
"test_single_tool_call_simple_args": ( + "Phi4 Mini streaming not implemented" + ), + "test_parallel_tool_calls": "Phi4 Mini streaming not implemented", + "test_various_data_types": "Phi4 Mini streaming not implemented", + "test_empty_arguments": "Phi4 Mini streaming not implemented", + "test_surrounding_text": "Phi4 Mini streaming not implemented", + "test_escaped_strings": "Phi4 Mini streaming not implemented", + "test_streaming_reconstruction": "Phi4 Mini streaming not implemented", + }, + xfail_nonstreaming={ + "test_various_data_types": ( + "Phi4MiniJsonToolParser regex has nesting limitations " + "with nested objects" + ), + "test_malformed_input": ( + "Phi4MiniJsonToolParser incorrectly sets " + "tools_called=True on empty array" + ), + }, + ) diff --git a/tests/tool_parsers/test_qwen3xml_tool_parser.py b/tests/tool_parsers/test_qwen3xml_tool_parser.py new file mode 100644 index 000000000..3771b8afd --- /dev/null +++ b/tests/tool_parsers/test_qwen3xml_tool_parser.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import pytest + +from tests.tool_parsers.common_tests import ( + ToolParserTestConfig, + ToolParserTests, +) + + +class TestQwen3xmlToolParser(ToolParserTests): + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + return ToolParserTestConfig( + parser_name="qwen3_xml", + # Test data + no_tool_calls_output="This is a regular response without any tool calls.", + single_tool_call_output="\n\nTokyo\n\n", + parallel_tool_calls_output="\n\nTokyo\n\n\n\nAsia/Tokyo\n\n", + various_data_types_output=( + "\n\n" + "hello\n" + "42\n" + "3.14\n" + "true\n" + "null\n" + '["a", "b", "c"]\n' + '{"nested": "value"}\n' + "\n" + ), + empty_arguments_output="\n\n\n", + surrounding_text_output=( + "Let me check the weather for you.\n\n" + "\n\n" + "Tokyo\n" + "\n\n\n" + "I will get that information." 
+ ), + escaped_strings_output=( + "\n\n" + 'He said "hello"\n' + "C:\\Users\\file.txt\n" + "line1\nline2\n" + "\n" + ), + malformed_input_outputs=[ + "", + "", + ], + # Expected results + single_tool_call_expected_name="get_weather", + single_tool_call_expected_args={"city": "Tokyo"}, + parallel_tool_calls_count=2, + parallel_tool_calls_names=["get_weather", "get_time"], + # xfail markers - Qwen3XML has systematic streaming issues + xfail_streaming={ + "test_single_tool_call_simple_args": ( + "Qwen3XML streaming has systematic issues" + ), + "test_parallel_tool_calls": "Qwen3XML streaming has systematic issues", + "test_various_data_types": "Qwen3XML streaming has systematic issues", + "test_empty_arguments": "Qwen3XML streaming has systematic issues", + "test_surrounding_text": "Qwen3XML streaming has systematic issues", + "test_escaped_strings": "Qwen3XML streaming has systematic issues", + "test_malformed_input": ( + "Qwen3XML parser is lenient with malformed input" + ), + "test_streaming_reconstruction": ( + "Qwen3XML streaming reconstruction has known issues" + ), + }, + supports_typed_arguments=False, + ) diff --git a/tests/tool_parsers/test_step3_tool_parser.py b/tests/tool_parsers/test_step3_tool_parser.py new file mode 100644 index 000000000..9ea17d65a --- /dev/null +++ b/tests/tool_parsers/test_step3_tool_parser.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import pytest + +from tests.tool_parsers.common_tests import ( + ToolParserTestConfig, + ToolParserTests, +) +from vllm.tokenizers import TokenizerLike, get_tokenizer + + +class TestStep3ToolParser(ToolParserTests): + @pytest.fixture(scope="class") + def tokenizer(self) -> TokenizerLike: + return get_tokenizer("stepfun-ai/step3") + + @pytest.fixture + def test_config(self) -> ToolParserTestConfig: + return ToolParserTestConfig( + parser_name="step3", + # Test data + no_tool_calls_output="This is a regular response 
without any tool calls.", + single_tool_call_output=( + "<|tool_calls_begin|><|tool_call_begin|>" + '' + 'Tokyo' + "<|tool_call_end|><|tool_calls_end|>" + ), + parallel_tool_calls_output=( + "<|tool_calls_begin|><|tool_call_begin|>" + '' + 'Tokyo' + "<|tool_call_end|><|tool_sep|>" + '<|tool_call_begin|>' + 'Asia/Tokyo' + "<|tool_call_end|><|tool_calls_end|>" + ), + various_data_types_output=( + "<|tool_calls_begin|><|tool_call_begin|>" + '' + 'hello' + '42' + '3.14' + 'true' + 'null' + '' + '["a", "b", "c"]' + '' + '{"nested": "value"}' + "<|tool_call_end|><|tool_calls_end|>" + ), + empty_arguments_output=( + "<|tool_calls_begin|><|tool_call_begin|>" + '' + "<|tool_call_end|><|tool_calls_end|>" + ), + surrounding_text_output=( + "Let me check the weather for you.\n\n" + "<|tool_calls_begin|><|tool_call_begin|>" + '' + 'Tokyo' + "<|tool_call_end|><|tool_calls_end|>\n\n" + "I'll get that information." + ), + escaped_strings_output=( + "<|tool_calls_begin|><|tool_call_begin|>" + '' + 'He said "hello"' + 'C:\\Users\\file.txt' + 'line1\nline2' + "<|tool_call_end|><|tool_calls_end|>" + ), + malformed_input_outputs=[ + ( + "<|tool_calls_begin|><|tool_call_begin|>" + '' + ), + ( + '<|tool_call_begin|>' + "<|tool_call_end|>" + ), + ], + # Expected results + single_tool_call_expected_name="get_weather", + single_tool_call_expected_args={"city": "Tokyo"}, + parallel_tool_calls_count=2, + parallel_tool_calls_names=["get_weather", "get_time"], + # xfail markers + xfail_nonstreaming={ + "test_single_tool_call_simple_args": ( + "Step3 parser non-streaming has bugs" + ), + "test_parallel_tool_calls": ("Step3 parser non-streaming has bugs"), + "test_various_data_types": "Step3 parser non-streaming has bugs", + "test_empty_arguments": "Step3 parser non-streaming has bugs", + "test_surrounding_text": "Step3 parser non-streaming has bugs", + "test_escaped_strings": "Step3 parser non-streaming has bugs", + }, + xfail_streaming={ + "test_parallel_tool_calls": ( + "Step3 parser has 
significant bugs in both streaming " + "and non-streaming" + ), + "test_streaming_reconstruction": ( + "Step3 parser non-streaming has bugs, so streaming " + "doesn't match non-streaming" + ), + }, + supports_typed_arguments=False, + ) diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/tool_parsers/utils.py similarity index 100% rename from tests/entrypoints/openai/tool_parsers/utils.py rename to tests/tool_parsers/utils.py -- GitLab From 061980c36a7b78e5d8ea96893b79fd0b9c11a20e Mon Sep 17 00:00:00 2001 From: Walter Beller-Morales Date: Mon, 16 Mar 2026 19:55:53 -0400 Subject: [PATCH 011/223] [Feature][Frontend] add support for Cohere Embed v2 API (#37074) Signed-off-by: walterbm --- docs/serving/openai_compatible_server.md | 134 ++++++++ .../pooling/embed/test_cohere_online.py | 310 +++++++++++++++++ .../embed/test_cohere_online_vision.py | 135 ++++++++ .../embed/test_cohere_openai_parity.py | 102 ++++++ .../pooling/embed/test_io_processor.py | 208 ++++++++++++ .../pooling/embed/test_protocol.py | 129 +++++++ vllm/entrypoints/pooling/base/protocol.py | 10 +- vllm/entrypoints/pooling/classify/protocol.py | 2 + vllm/entrypoints/pooling/embed/api_router.py | 31 +- .../entrypoints/pooling/embed/io_processor.py | 319 +++++++++++++++++- vllm/entrypoints/pooling/embed/protocol.py | 170 +++++++++- vllm/entrypoints/pooling/embed/serving.py | 64 +++- vllm/entrypoints/pooling/pooling/protocol.py | 3 + vllm/entrypoints/pooling/score/protocol.py | 2 + vllm/entrypoints/pooling/typing.py | 2 + vllm/renderers/params.py | 26 +- 16 files changed, 1608 insertions(+), 39 deletions(-) create mode 100644 tests/entrypoints/pooling/embed/test_cohere_online.py create mode 100644 tests/entrypoints/pooling/embed/test_cohere_online_vision.py create mode 100644 tests/entrypoints/pooling/embed/test_cohere_openai_parity.py create mode 100644 tests/entrypoints/pooling/embed/test_io_processor.py create mode 100644 tests/entrypoints/pooling/embed/test_protocol.py diff --git 
a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 45af2b693..cf44a1bfe 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -72,6 +72,9 @@ In addition, we have the following custom APIs: - Only applicable to [classification models](../models/pooling_models.md). - [Score API](#score-api) (`/score`) - Applicable to [embedding models and cross-encoder models](../models/pooling_models.md). +- [Cohere Embed API](#cohere-embed-api) (`/v2/embed`) + - Compatible with [Cohere's Embed API](https://docs.cohere.com/reference/embed) + - Works with any [embedding model](../models/pooling_models.md), including multimodal models. - [Re-rank API](#re-rank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`) - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/) - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank) @@ -429,6 +432,137 @@ these extra parameters are supported instead: --8<-- "vllm/entrypoints/pooling/base/protocol.py:embed-extra-params" ``` +### Cohere Embed API + +Our API is also compatible with [Cohere's Embed v2 API](https://docs.cohere.com/reference/embed), which adds support for modern embedding features such as truncation, output dimensions, embedding types, and input types. This endpoint works with any embedding model (including multimodal models).
+ +#### Cohere Embed API request parameters + +| Parameter | Type | Required | Description | +| --------- | ---- | -------- | ----------- | +| `model` | string | Yes | Model name | +| `input_type` | string | No | Prompt prefix key (model-dependent, see below) | +| `texts` | list[string] | No | Text inputs (use one of `texts`, `images`, or `inputs`) | +| `images` | list[string] | No | Base64 data URI images | +| `inputs` | list[object] | No | Mixed text and image content objects | +| `embedding_types` | list[string] | No | Output types (default: `["float"]`) | +| `output_dimension` | int | No | Truncate embeddings to this dimension (Matryoshka) | +| `truncate` | string | No | `END`, `START`, or `NONE` (default: `END`) | + +#### Text embedding + +```bash +curl -X POST "http://localhost:8000/v2/embed" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Snowflake/snowflake-arctic-embed-m-v1.5", + "input_type": "query", + "texts": ["Hello world", "How are you?"], + "embedding_types": ["float"] + }' +``` + +??? console "Response" + + ```json + { + "id": "embd-...", + "embeddings": { + "float": [ + [0.012, -0.034, ...], + [0.056, 0.078, ...] + ] + }, + "texts": ["Hello world", "How are you?"], + "meta": { + "api_version": {"version": "2"}, + "billed_units": {"input_tokens": 12} + } + } + ``` + +#### Mixed text and image inputs + +For multimodal models, you can embed images by passing base64 data URIs. The `inputs` field accepts a list of objects with mixed text and image content: + +```bash +curl -X POST "http://localhost:8000/v2/embed" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "google/siglip-so400m-patch14-384", + "inputs": [ + { + "content": [ + {"type": "text", "text": "A photo of a cat"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}} + ] + } + ], + "embedding_types": ["float"] + }' +``` + +#### Embedding types + +The `embedding_types` parameter controls the output format. 
Multiple types can be requested in a single call: + +| Type | Description | +| ---- | ----------- | +| `float` | Raw float32 embeddings (default) | +| `binary` | Bit-packed signed binary | +| `ubinary` | Bit-packed unsigned binary | +| `base64` | Little-endian float32 encoded as base64 | + +```bash +curl -X POST "http://localhost:8000/v2/embed" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Snowflake/snowflake-arctic-embed-m-v1.5", + "input_type": "query", + "texts": ["What is machine learning?"], + "embedding_types": ["float", "binary"] + }' +``` + +??? console "Response" + + ```json + { + "id": "embd-...", + "embeddings": { + "float": [[0.012, -0.034, ...]], + "binary": [[42, -117, ...]] + }, + "texts": ["What is machine learning?"], + "meta": { + "api_version": {"version": "2"}, + "billed_units": {"input_tokens": 8} + } + } + ``` + +#### Truncation + +The `truncate` parameter controls how inputs exceeding the model's maximum sequence length are handled: + +| Value | Behavior | +| ----- | --------- | +| `END` (default) | Keep the first tokens, drop the end | +| `START` | Keep the last tokens, drop the beginning | +| `NONE` | Return an error if the input is too long | + +#### Input type and prompt prefixes + +The `input_type` field selects a prompt prefix to prepend to each text input. The available values +depend on the model: + +- **Models with `task_instructions` in `config.json`**: The keys from the `task_instructions` dict are + the valid `input_type` values and the corresponding value is prepended to each text. +- **Models with `config_sentence_transformers.json` prompts**: The keys from the `prompts` dict are + the valid `input_type` values. For example, `Snowflake/snowflake-arctic-embed-xs` defines `"query"`, + so setting `input_type: "query"` prepends `"Represent this sentence for searching relevant passages: "`. +- **Other models**: `input_type` is not accepted and will raise a validation error if passed. 
+ ### Transcriptions API Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription); diff --git a/tests/entrypoints/pooling/embed/test_cohere_online.py b/tests/entrypoints/pooling/embed/test_cohere_online.py new file mode 100644 index 000000000..fc313819f --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_cohere_online.py @@ -0,0 +1,310 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the Cohere /v2/embed API with generic (non-Cohere) models. + +Validates that the Cohere v2 embed endpoint works correctly with standard +embedding models, covering text embedding, embedding type conversions, +response structure, batching, normalisation, and semantic similarity. +""" + +import base64 +import struct + +import numpy as np +import pytest +import requests + +from tests.utils import RemoteOpenAIServer + +DTYPE = "bfloat16" + +MODELS: list[tuple[str, list[str]]] = [ + ("intfloat/multilingual-e5-small", []), + ( + "Snowflake/snowflake-arctic-embed-m-v1.5", + [ + "--trust_remote_code", + "--hf_overrides", + '{"matryoshka_dimensions":[256]}', + ], + ), +] + + +@pytest.fixture(scope="module", params=MODELS, ids=lambda m: m[0]) +def model_config(request): + return request.param + + +@pytest.fixture(scope="module") +def model_name(model_config): + return model_config[0] + + +@pytest.fixture(scope="module") +def server(model_config): + name, extra_args = model_config + args = [ + "--runner", + "pooling", + "--dtype", + DTYPE, + "--enforce-eager", + "--max-model-len", + "512", + "--gpu-memory-utilization", + "0.02", + ] + extra_args + with RemoteOpenAIServer(name, args) as remote_server: + yield remote_server + + +def _cohere_embed( + server: RemoteOpenAIServer, + model_name: str, + texts: list[str] | None = None, + images: list[str] | None = None, + input_type: str | None = None, + embedding_types: list[str] | None 
= None, +) -> dict: + body: dict = {"model": model_name} + if input_type is not None: + body["input_type"] = input_type + if texts is not None: + body["texts"] = texts + if images is not None: + body["images"] = images + if embedding_types is not None: + body["embedding_types"] = embedding_types + resp = requests.post(server.url_for("/v2/embed"), json=body) + resp.raise_for_status() + return resp.json() + + +def _openai_embed( + server: RemoteOpenAIServer, model_name: str, texts: list[str] +) -> dict: + body = {"model": model_name, "input": texts, "encoding_format": "float"} + resp = requests.post(server.url_for("/v1/embeddings"), json=body) + resp.raise_for_status() + return resp.json() + + +def _cosine_sim(a: list[float], b: list[float]) -> float: + va, vb = np.array(a), np.array(b) + return float(np.dot(va, vb) / (np.linalg.norm(va) * np.linalg.norm(vb))) + + +# ----------------------------------------------------------- +# Text embedding tests +# ----------------------------------------------------------- + + +def test_basic_embed(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, model_name, texts=["hello world"], embedding_types=["float"] + ) + assert "embeddings" in r + assert len(r["embeddings"]["float"]) == 1 + assert len(r["embeddings"]["float"][0]) > 0 + + +def test_unsupported_input_type_rejected(server: RemoteOpenAIServer, model_name: str): + """An input_type not defined in the model's prompt config should be + rejected with a 400 error.""" + body = { + "model": model_name, + "input_type": "nonexistent_type", + "texts": ["hello world"], + "embedding_types": ["float"], + } + resp = requests.post(server.url_for("/v2/embed"), json=body) + assert resp.status_code == 400 + assert "Unsupported input_type" in resp.json()["error"]["message"] + + +def test_omitted_input_type_accepted(server: RemoteOpenAIServer, model_name: str): + """Omitting input_type should always work (no prompt prefix applied).""" + body = { + "model": model_name, 
+ "texts": ["hello world"], + "embedding_types": ["float"], + } + resp = requests.post(server.url_for("/v2/embed"), json=body) + assert resp.status_code == 200 + data = resp.json() + assert len(data["embeddings"]["float"]) == 1 + + +def test_v1_v2_parity(server: RemoteOpenAIServer, model_name: str): + """v1 (OpenAI) and v2 (Cohere) endpoints should produce the same + float embeddings for a generic model.""" + texts = ["hello world"] + v2 = _cohere_embed(server, model_name, texts=texts, embedding_types=["float"]) + v1 = _openai_embed(server, model_name, texts) + cos = _cosine_sim(v2["embeddings"]["float"][0], v1["data"][0]["embedding"]) + assert cos > 0.9999, f"v1/v2 parity failed, cosine={cos}" + + +def test_embedding_types(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, + model_name, + texts=["test"], + embedding_types=["float", "binary", "ubinary"], + ) + dim = len(r["embeddings"]["float"][0]) + assert len(r["embeddings"]["binary"][0]) == dim // 8 + assert len(r["embeddings"]["ubinary"][0]) == dim // 8 + + +def test_response_structure(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed(server, model_name, texts=["test"], embedding_types=["float"]) + assert "id" in r + assert "embeddings" in r + assert "texts" in r + assert r["texts"] == ["test"] + assert "meta" in r + assert r["meta"]["api_version"]["version"] == "2" + assert "billed_units" in r["meta"] + assert r["meta"]["billed_units"]["input_tokens"] > 0 + assert r["meta"]["billed_units"]["image_tokens"] == 0 + + +def test_batch(server: RemoteOpenAIServer, model_name: str): + texts = ["apple", "banana", "cherry"] + r = _cohere_embed(server, model_name, texts=texts, embedding_types=["float"]) + assert len(r["embeddings"]["float"]) == 3 + dim = len(r["embeddings"]["float"][0]) + for emb in r["embeddings"]["float"]: + assert len(emb) == dim + + +def test_l2_normalized(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, model_name, texts=["hello 
world"], embedding_types=["float"] + ) + emb = np.array(r["embeddings"]["float"][0]) + assert abs(float(np.linalg.norm(emb)) - 1.0) < 0.01 + + +def test_semantic_similarity(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, + model_name, + texts=["machine learning", "deep learning", "chocolate cake recipe"], + embedding_types=["float"], + ) + embs = r["embeddings"]["float"] + cos_related = _cosine_sim(embs[0], embs[1]) + cos_unrelated = _cosine_sim(embs[0], embs[2]) + assert cos_related > cos_unrelated + + +def test_missing_input_returns_error(server: RemoteOpenAIServer, model_name: str): + body = {"model": model_name} + resp = requests.post(server.url_for("/v2/embed"), json=body) + assert resp.status_code == 400 + + +def test_base64_embedding_type(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, + model_name, + texts=["test encoding"], + embedding_types=["float", "base64"], + ) + float_emb = r["embeddings"]["float"][0] + b64_str = r["embeddings"]["base64"][0] + decoded = struct.unpack(f"<{len(float_emb)}f", base64.b64decode(b64_str)) + np.testing.assert_allclose(float_emb, decoded, rtol=1e-5) + + +# ----------------------------------------------------------- +# Truncation tests +# ----------------------------------------------------------- + + +def _cohere_embed_raw( + server: RemoteOpenAIServer, + body: dict, +) -> requests.Response: + return requests.post(server.url_for("/v2/embed"), json=body) + + +def test_truncate_end_succeeds(server: RemoteOpenAIServer, model_name: str): + """truncate=END should silently truncate long input.""" + long_text = " ".join(["word"] * 2000) + body = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "END", + } + resp = _cohere_embed_raw(server, body) + assert resp.status_code == 200 + data = resp.json() + assert len(data["embeddings"]["float"]) == 1 + + +def test_truncate_start_succeeds(server: RemoteOpenAIServer, model_name: str): + 
"""truncate=START should silently truncate long input from the start.""" + long_text = " ".join(["word"] * 2000) + body = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "START", + } + resp = _cohere_embed_raw(server, body) + assert resp.status_code == 200 + data = resp.json() + assert len(data["embeddings"]["float"]) == 1 + + +def test_truncate_none_rejects_long_input(server: RemoteOpenAIServer, model_name: str): + """truncate=NONE should error when input exceeds model context.""" + long_text = " ".join(["word"] * 2000) + body = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "NONE", + } + resp = _cohere_embed_raw(server, body) + assert resp.status_code == 400 + + +def test_truncate_start_vs_end_differ(server: RemoteOpenAIServer, model_name: str): + """START and END truncation should produce different embeddings + when the input is long enough to actually be truncated. + + We construct input with distinct tokens at the start vs end + so that keeping different halves produces different embeddings. 
+ """ + start_words = " ".join([f"alpha{i}" for i in range(300)]) + end_words = " ".join([f"omega{i}" for i in range(300)]) + long_text = start_words + " " + end_words + + body_end = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "END", + } + body_start = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "START", + } + r_end = _cohere_embed_raw(server, body_end).json() + r_start = _cohere_embed_raw(server, body_start).json() + + emb_end = r_end["embeddings"]["float"][0] + emb_start = r_start["embeddings"]["float"][0] + cos = _cosine_sim(emb_end, emb_start) + assert cos < 0.99, ( + f"START and END truncation should produce different embeddings " + f"for long input, but cosine similarity was {cos}" + ) diff --git a/tests/entrypoints/pooling/embed/test_cohere_online_vision.py b/tests/entrypoints/pooling/embed/test_cohere_online_vision.py new file mode 100644 index 000000000..ab874e4e2 --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_cohere_online_vision.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the Cohere /v2/embed API with a multimodal model (SigLIP). + +Validates image embedding, batching, normalisation, and embedding type +conversions through the /v2/embed endpoint. 
+""" + +import base64 +import struct +import zlib + +import numpy as np +import pytest +import requests + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "google/siglip-so400m-patch14-384" +DTYPE = "bfloat16" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--runner", + "pooling", + "--dtype", + DTYPE, + "--enforce-eager", + "--max-model-len", + "64", + "--gpu-memory-utilization", + "0.3", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +def _make_tiny_png(r: int, g: int, b: int, w: int = 2, h: int = 2) -> str: + raw = b"" + for _ in range(h): + raw += b"\x00" + bytes([r, g, b]) * w + compressed = zlib.compress(raw) + + def chunk(ctype: bytes, cdata: bytes) -> bytes: + c = ctype + cdata + return ( + struct.pack(">I", len(cdata)) + + c + + struct.pack(">I", zlib.crc32(c) & 0xFFFFFFFF) + ) + + ihdr = struct.pack(">IIBBBBB", w, h, 8, 2, 0, 0, 0) + png = ( + b"\x89PNG\r\n\x1a\n" + + chunk(b"IHDR", ihdr) + + chunk(b"IDAT", compressed) + + chunk(b"IEND", b"") + ) + return "data:image/png;base64," + base64.b64encode(png).decode() + + +def _cohere_embed( + server: RemoteOpenAIServer, + texts: list[str] | None = None, + images: list[str] | None = None, + embedding_types: list[str] | None = None, +) -> dict: + body: dict = {"model": MODEL_NAME} + if texts is not None: + body["texts"] = texts + if images is not None: + body["images"] = images + if embedding_types is not None: + body["embedding_types"] = embedding_types + resp = requests.post(server.url_for("/v2/embed"), json=body) + resp.raise_for_status() + return resp.json() + + +def test_image_embed(server: RemoteOpenAIServer): + img_uri = _make_tiny_png(255, 0, 0) + r = _cohere_embed( + server, + images=[img_uri], + embedding_types=["float"], + ) + assert "embeddings" in r + assert len(r["embeddings"]["float"]) == 1 + assert len(r["embeddings"]["float"][0]) > 0 + assert r["meta"]["billed_units"]["image_tokens"] > 0 + assert 
r["meta"]["billed_units"]["input_tokens"] == 0 + + +def test_image_batch(server: RemoteOpenAIServer): + red = _make_tiny_png(255, 0, 0) + blue = _make_tiny_png(0, 0, 255) + r = _cohere_embed( + server, + images=[red, blue], + embedding_types=["float"], + ) + assert len(r["embeddings"]["float"]) == 2 + + +def test_image_l2_normalized(server: RemoteOpenAIServer): + img_uri = _make_tiny_png(0, 255, 0) + r = _cohere_embed( + server, + images=[img_uri], + embedding_types=["float"], + ) + emb = np.array(r["embeddings"]["float"][0]) + assert abs(float(np.linalg.norm(emb)) - 1.0) < 0.01 + + +def test_image_embedding_types(server: RemoteOpenAIServer): + img_uri = _make_tiny_png(128, 128, 128) + r = _cohere_embed( + server, + images=[img_uri], + embedding_types=["float", "binary", "ubinary"], + ) + dim = len(r["embeddings"]["float"][0]) + assert len(r["embeddings"]["binary"][0]) == dim // 8 + assert len(r["embeddings"]["ubinary"][0]) == dim // 8 + + +def test_text_embed_on_multimodal(server: RemoteOpenAIServer): + """SigLIP also supports text-only embedding via /v2/embed.""" + r = _cohere_embed(server, texts=["hello world"], embedding_types=["float"]) + assert "embeddings" in r + assert len(r["embeddings"]["float"]) == 1 + assert len(r["embeddings"]["float"][0]) > 0 diff --git a/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py b/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py new file mode 100644 index 000000000..d23e1461b --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Parity test between Cohere /v2/embed and OpenAI /v1/embeddings. + +Verifies that both endpoints produce identical float embeddings when +no prompt prefix is applied (input_type omitted for Cohere /v2/embed). 
+""" + +import numpy as np +import pytest +import requests + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "BAAI/bge-base-en-v1.5" +DTYPE = "bfloat16" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--runner", + "pooling", + "--dtype", + DTYPE, + "--enforce-eager", + "--max-model-len", + "512", + "--gpu-memory-utilization", + "0.02", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +def _cohere_embed( + server: RemoteOpenAIServer, + texts: list[str], +) -> list[list[float]]: + body = { + "model": MODEL_NAME, + "texts": texts, + "embedding_types": ["float"], + } + resp = requests.post(server.url_for("/v2/embed"), json=body) + resp.raise_for_status() + return resp.json()["embeddings"]["float"] + + +def _openai_embed( + server: RemoteOpenAIServer, + texts: list[str], +) -> list[list[float]]: + body = {"model": MODEL_NAME, "input": texts, "encoding_format": "float"} + resp = requests.post(server.url_for("/v1/embeddings"), json=body) + resp.raise_for_status() + return [item["embedding"] for item in resp.json()["data"]] + + +def test_single_text_parity(server: RemoteOpenAIServer): + """A single text should produce identical embeddings via both APIs.""" + texts = ["the quick brown fox jumps over the lazy dog"] + v2 = _cohere_embed(server, texts) + v1 = _openai_embed(server, texts) + np.testing.assert_allclose(v2[0], v1[0], rtol=1e-5) + + +def test_batch_parity(server: RemoteOpenAIServer): + """A batch of texts should produce identical embeddings via both APIs, + in the same order.""" + texts = [ + "machine learning", + "deep learning", + "natural language processing", + ] + v2 = _cohere_embed(server, texts) + v1 = _openai_embed(server, texts) + assert len(v2) == len(v1) == 3 + for i in range(3): + np.testing.assert_allclose(v2[i], v1[i], rtol=1e-5, err_msg=f"index {i}") + + +def test_token_count_parity(server: RemoteOpenAIServer): + """Both APIs should report the same prompt token count.""" + texts 
= ["hello world"] + v2_resp = requests.post( + server.url_for("/v2/embed"), + json={ + "model": MODEL_NAME, + "texts": texts, + "embedding_types": ["float"], + }, + ) + v1_resp = requests.post( + server.url_for("/v1/embeddings"), + json={"model": MODEL_NAME, "input": texts, "encoding_format": "float"}, + ) + v2_resp.raise_for_status() + v1_resp.raise_for_status() + v2_tokens = v2_resp.json()["meta"]["billed_units"]["input_tokens"] + v1_tokens = v1_resp.json()["usage"]["prompt_tokens"] + assert v2_tokens == v1_tokens diff --git a/tests/entrypoints/pooling/embed/test_io_processor.py b/tests/entrypoints/pooling/embed/test_io_processor.py new file mode 100644 index 000000000..e7db0df1e --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_io_processor.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for EmbedIOProcessor.""" + +import pytest + +from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor +from vllm.entrypoints.pooling.embed.protocol import ( + CohereEmbedRequest, +) + + +class TestResolveTruncation: + """Unit tests for EmbedIOProcessor._resolve_cohere_truncation.""" + + @staticmethod + def _make_request(**kwargs) -> CohereEmbedRequest: + defaults = { + "model": "test", + "input_type": "search_document", + "texts": ["hello"], + } + return CohereEmbedRequest(**(defaults | kwargs)) + + def test_truncate_end_default(self): + req = self._make_request() + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == -1 + assert side is None + + def test_truncate_end_explicit(self): + req = self._make_request(truncate="END") + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == -1 + assert side is None + + def test_truncate_end_with_max_tokens(self): + req = self._make_request(truncate="END", max_tokens=128) + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == 128 + 
assert side is None + + def test_truncate_none(self): + req = self._make_request(truncate="NONE") + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens is None + assert side is None + + def test_truncate_none_with_max_tokens(self): + """truncate=NONE should NOT set truncate_prompt_tokens; the + max_tokens limit is enforced separately via _check_cohere_max_tokens.""" + req = self._make_request(truncate="NONE", max_tokens=10) + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens is None + assert side is None + + def test_truncate_start(self): + req = self._make_request(truncate="START") + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == -1 + assert side == "left" + + def test_truncate_start_with_max_tokens(self): + req = self._make_request(truncate="START", max_tokens=64) + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == 64 + assert side == "left" + + +class TestApplyStPrompt: + """Unit tests for EmbedIOProcessor._apply_task_instruction.""" + + @staticmethod + def _make_handler(task_instructions: dict[str, str] | None): + handler = object.__new__(EmbedIOProcessor) + handler.task_instructions = task_instructions + return handler + + def test_no_prompts_configured(self): + handler = self._make_handler(None) + texts = ["hello", "world"] + assert handler._apply_task_instruction(texts, "query") is texts + + def test_matching_input_type(self): + handler = self._make_handler({"query": "search_query: "}) + result = handler._apply_task_instruction(["hello"], "query") + assert result == ["search_query: hello"] + + def test_non_matching_input_type(self): + handler = self._make_handler({"query": "search_query: "}) + texts = ["hello"] + assert handler._apply_task_instruction(texts, "document") is texts + + def test_multiple_texts(self): + handler = self._make_handler( + {"query": "Represent this sentence for searching: "} + ) + result =
handler._apply_task_instruction(["a", "b", "c"], "query") + assert result == [ + "Represent this sentence for searching: a", + "Represent this sentence for searching: b", + "Represent this sentence for searching: c", + ] + + def test_empty_prefix_returns_unchanged(self): + handler = self._make_handler({"passage": ""}) + texts = ["hello"] + assert handler._apply_task_instruction(texts, "passage") is texts + + +class TestLoadTaskInstructions: + """Unit tests for EmbedIOProcessor._load_task_instructions.""" + + def test_no_attribute(self): + class FakeConfig: + pass + + assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None + + def test_with_task_instructions(self): + class FakeConfig: + task_instructions = { + "retrieval.query": "Represent the query: ", + "retrieval.passage": "", + } + + result = EmbedIOProcessor._load_task_instructions(FakeConfig()) + assert result == { + "retrieval.query": "Represent the query: ", + "retrieval.passage": "", + } + + def test_empty_dict(self): + class FakeConfig: + task_instructions = {} + + assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None + + def test_non_dict(self): + class FakeConfig: + task_instructions = "not a dict" + + assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None + + +class TestCheckMaxTokens: + """Unit tests for EmbedIOProcessor._check_cohere_max_tokens.""" + + @staticmethod + def _fake_output(n_tokens: int): + class _Out: + def __init__(self, n: int): + self.prompt_token_ids = list(range(n)) + + return _Out(n_tokens) + + def test_none_check_is_noop(self): + outs = [self._fake_output(100)] + EmbedIOProcessor._check_cohere_max_tokens(outs, None) + + def test_within_limit(self): + outs = [self._fake_output(5), self._fake_output(3)] + EmbedIOProcessor._check_cohere_max_tokens(outs, 5) + + def test_exceeds_limit(self): + outs = [self._fake_output(3), self._fake_output(10)] + with pytest.raises(ValueError, match="exceeds max_tokens=5"): + 
EmbedIOProcessor._check_cohere_max_tokens(outs, 5) + + def test_exact_limit(self): + outs = [self._fake_output(5)] + EmbedIOProcessor._check_cohere_max_tokens(outs, 5) + + +class TestValidateInputType: + """Unit tests for EmbedIOProcessor._validate_input_type.""" + + @staticmethod + def _make_handler(task_instructions: dict[str, str] | None): + handler = object.__new__(EmbedIOProcessor) + handler.task_instructions = task_instructions + return handler + + def test_none_input_type_always_accepted(self): + handler = self._make_handler(None) + handler._validate_input_type(None) + handler_with = self._make_handler({"query": "q: "}) + handler_with._validate_input_type(None) + + def test_no_prompts_rejects(self): + handler = self._make_handler(None) + with pytest.raises(ValueError, match="does not define any input_type"): + handler._validate_input_type("anything") + + def test_known_type_accepted(self): + handler = self._make_handler({"query": "q: ", "document": "d: "}) + handler._validate_input_type("query") + handler._validate_input_type("document") + + def test_unknown_type_rejected(self): + handler = self._make_handler({"query": "q: ", "document": "d: "}) + with pytest.raises(ValueError, match="Unsupported input_type 'other'"): + handler._validate_input_type("other") + + def test_error_lists_supported(self): + handler = self._make_handler({"a": "", "b": ""}) + with pytest.raises(ValueError, match="Supported values: a, b"): + handler._validate_input_type("z") diff --git a/tests/entrypoints/pooling/embed/test_protocol.py b/tests/entrypoints/pooling/embed/test_protocol.py new file mode 100644 index 000000000..f2bd5d2cc --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_protocol.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for Cohere embed protocol: build_typed_embeddings and its +underlying packing helpers, plus Cohere-specific serving helpers.""" + +import base64 
+import struct + +import numpy as np +import pytest + +from vllm.entrypoints.pooling.embed.protocol import ( + build_typed_embeddings, +) + + +@pytest.fixture +def sample_embeddings() -> list[list[float]]: + return [ + [0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8], + [-0.05, 0.15, -0.25, 0.35, -0.45, 0.55, -0.65, 0.75], + ] + + +class TestBuildTypedEmbeddingsFloat: + def test_float_passthrough(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["float"]) + assert result.float == sample_embeddings + assert result.binary is None + + def test_empty_input(self): + result = build_typed_embeddings([], ["float"]) + assert result.float == [] + + +class TestBuildTypedEmbeddingsBinary: + def test_binary_packing(self): + # 8 values: positive->1, negative->0 => bits: 10101010 = 0xAA = 170 + # signed: 170 - 128 = 42 + embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]] + result = build_typed_embeddings(embs, ["binary"]) + assert result.binary is not None + assert result.binary[0] == [42] + + def test_ubinary_packing(self): + embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]] + result = build_typed_embeddings(embs, ["ubinary"]) + assert result.ubinary is not None + assert result.ubinary[0] == [170] # 0b10101010 + + def test_binary_all_positive(self): + embs = [[0.1] * 8] + result = build_typed_embeddings(embs, ["binary"]) + assert result.binary is not None + # all bits = 1 => 0xFF = 255, signed: 255 - 128 = 127 + assert result.binary[0] == [127] + + def test_binary_all_negative(self): + embs = [[-0.1] * 8] + result = build_typed_embeddings(embs, ["binary"]) + assert result.binary is not None + # all bits = 0, signed: 0 - 128 = -128 + assert result.binary[0] == [-128] + + def test_binary_dimension_is_eighth(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["binary"]) + assert result.binary is not None + for orig, packed in zip(sample_embeddings, result.binary): + assert len(packed) 
== len(orig) // 8 + + def test_zero_treated_as_positive(self): + embs = [[0.0] * 8] + result = build_typed_embeddings(embs, ["binary"]) + assert result.binary is not None + # 0.0 >= 0 is True, so bit=1 for all => 127 (signed) + assert result.binary[0] == [127] + + def test_non_multiple_of_8_raises(self): + embs = [[0.1] * 7] + with pytest.raises(ValueError, match="multiple of 8"): + build_typed_embeddings(embs, ["binary"]) + + def test_ubinary_non_multiple_of_8_raises(self): + embs = [[0.1] * 10] + with pytest.raises(ValueError, match="multiple of 8"): + build_typed_embeddings(embs, ["ubinary"]) + + +class TestBuildTypedEmbeddingsBase64: + def test_base64_roundtrip(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["base64"]) + assert result.base64 is not None + assert len(result.base64) == 2 + + for orig, b64_str in zip(sample_embeddings, result.base64): + decoded = base64.b64decode(b64_str) + n = len(orig) + values = struct.unpack(f"<{n}f", decoded) + np.testing.assert_allclose(orig, values, rtol=1e-5) + + def test_base64_byte_length(self): + embs = [[0.1, 0.2, 0.3]] + result = build_typed_embeddings(embs, ["base64"]) + assert result.base64 is not None + raw = base64.b64decode(result.base64[0]) + assert len(raw) == 3 * 4 # 3 floats * 4 bytes each + + +class TestBuildTypedEmbeddingsMultiple: + def test_all_types_at_once(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings( + sample_embeddings, + ["float", "binary", "ubinary", "base64"], + ) + assert result.float is not None + assert result.binary is not None + assert result.ubinary is not None + assert result.base64 is not None + + def test_subset_types(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["float", "binary"]) + assert result.float is not None + assert result.binary is not None + assert result.ubinary is None + assert result.base64 is None + + def test_unknown_type_ignored(self, 
sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["float", "unknown_type"]) + assert result.float is not None diff --git a/vllm/entrypoints/pooling/base/protocol.py b/vllm/entrypoints/pooling/base/protocol.py index 50be58374..2f547df8d 100644 --- a/vllm/entrypoints/pooling/base/protocol.py +++ b/vllm/entrypoints/pooling/base/protocol.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Annotated, Any +from typing import Annotated, Any, Literal from pydantic import Field, model_validator @@ -24,6 +24,14 @@ class PoolingBasicRequestMixin(OpenAIBaseModel): # --8<-- [start:pooling-common-extra-params] truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None + truncation_side: Literal["left", "right"] | None = Field( + default=None, + description=( + "Which side to truncate from when truncate_prompt_tokens is active. " + "'right' keeps the first N tokens. " + "'left' keeps the last N tokens." 
+ ), + ) request_id: str = Field( default_factory=random_uuid, description=( diff --git a/vllm/entrypoints/pooling/classify/protocol.py b/vllm/entrypoints/pooling/classify/protocol.py index bfc38ebef..fe8c898e0 100644 --- a/vllm/entrypoints/pooling/classify/protocol.py +++ b/vllm/entrypoints/pooling/classify/protocol.py @@ -32,6 +32,7 @@ class ClassificationCompletionRequest( max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -54,6 +55,7 @@ class ClassificationChatRequest( max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", diff --git a/vllm/entrypoints/pooling/embed/api_router.py b/vllm/entrypoints/pooling/embed/api_router.py index f88999468..390efc6a1 100644 --- a/vllm/entrypoints/pooling/embed/api_router.py +++ b/vllm/entrypoints/pooling/embed/api_router.py @@ -7,12 +7,12 @@ from fastapi import APIRouter, Depends, Request from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.utils import validate_json_request -from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest -from vllm.entrypoints.pooling.embed.serving import ServingEmbedding -from vllm.entrypoints.utils import ( - load_aware_call, - with_cancellation, +from vllm.entrypoints.pooling.embed.protocol import ( + CohereEmbedRequest, + EmbeddingRequest, ) +from vllm.entrypoints.pooling.embed.serving import ServingEmbedding +from vllm.entrypoints.utils import load_aware_call, with_cancellation router = APIRouter() @@ -40,3 +40,24 @@ async def create_embedding( raise 
NotImplementedError("The model does not support Embeddings API") return await handler(request, raw_request) + + +@router.post( + "/v2/embed", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def create_cohere_embedding( + request: CohereEmbedRequest, + raw_request: Request, +): + handler = embedding(raw_request) + if handler is None: + raise NotImplementedError("The model does not support Embeddings API") + + return await handler(request, raw_request) diff --git a/vllm/entrypoints/pooling/embed/io_processor.py b/vllm/entrypoints/pooling/embed/io_processor.py index 22ece7542..9342013bf 100644 --- a/vllm/entrypoints/pooling/embed/io_processor.py +++ b/vllm/entrypoints/pooling/embed/io_processor.py @@ -1,14 +1,37 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, cast +from collections.abc import Sequence +from typing import Any, Literal, cast import torch - +from openai.types.chat import ( + ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam, +) +from openai.types.chat.chat_completion_content_part_image_param import ImageURL + +from vllm import PoolingParams +from vllm.entrypoints.chat_utils import ( + ChatCompletionContentPartParam, + ChatCompletionMessageParam, + CustomChatCompletionMessageParam, +) from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor +from vllm.entrypoints.pooling.embed.protocol import ( + CohereEmbedInput, + CohereEmbedRequest, + EmbeddingChatRequest, + EmbeddingCompletionRequest, +) from vllm.entrypoints.pooling.typing import PoolingServeContext from vllm.inputs.data import ProcessorInputs, token_inputs +from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.renderers import 
merge_kwargs from vllm.utils.collection_utils import chunk_list +from vllm.utils.mistral import is_mistral_tokenizer + +logger = init_logger(__name__) class EmbedIOProcessor(PoolingIOProcessor): @@ -21,16 +44,45 @@ class EmbedIOProcessor(PoolingIOProcessor): self.pooler_config = self.model_config.pooler_config self.enable_chunked_processing = self.pooler_config.enable_chunked_processing - ################################################################# - # Long Text Embedding with Chunked Processing - # PTAL: examples/pooling/embed/openai_embedding_long_text + # Load task instructions from HF config or sentence-transformers config + self.task_instructions: dict[str, str] | None = self._load_task_instructions( + self.model_config.hf_config + ) or self._load_st_prompts(self.model_config.model, self.model_config.revision) + if self.task_instructions: + logger.info( + "Loaded prompt prefixes for input_type: %s", + list(self.task_instructions.keys()), + ) def pre_process_online(self, ctx: PoolingServeContext): - super().pre_process_online(ctx) + if isinstance(ctx.request, CohereEmbedRequest): + self._pre_process_cohere_online(ctx) + else: + super().pre_process_online(ctx) + + if self.enable_chunked_processing: + self._pre_process_chunked(ctx) + + def post_process_online( + self, + ctx: PoolingServeContext, + ): + if ctx.final_res_batch is None: + raise ValueError("Final response batch not available") if not self.enable_chunked_processing: - return None + self._enforce_cohere_max_tokens(ctx) + return super().post_process_online(ctx) + self._post_process_chunked(ctx) + self._enforce_cohere_max_tokens(ctx) + + ################################################################# + # Long Text Embedding with Chunked Processing + # PTAL: examples/pooling/embed/openai_embedding_long_text + ################################################################# + + def _pre_process_chunked(self, ctx: PoolingServeContext) -> None: if ctx.engine_prompts is None: raise ValueError("Engine 
prompts not available") @@ -61,18 +113,10 @@ class EmbedIOProcessor(PoolingIOProcessor): ctx.engine_prompts = chunked_engine_prompts ctx.prompt_request_ids = prompt_request_ids - return None - def post_process_online( - self, - ctx: PoolingServeContext, - ): - if ctx.final_res_batch is None: - raise ValueError("Final response batch not available") - - if not self.enable_chunked_processing: - return super().post_process_online(ctx) + return None + def _post_process_chunked(self, ctx: PoolingServeContext) -> None: # Online aggregation for chunked requests to # minimize memory usage # Track aggregation state for each prompt @@ -195,4 +239,245 @@ class EmbedIOProcessor(PoolingIOProcessor): raise ValueError(f"Result not found for prompt {prompt_idx}") ctx.final_res_batch = final_res_batch + return None + + ################################################################# + # Cohere Request Preprocessing & Postprocessing + ################################################################# + + @staticmethod + def _load_task_instructions(hf_config: Any) -> dict[str, str] | None: + """Extract ``task_instructions`` from the HF model config.""" + ti = getattr(hf_config, "task_instructions", None) + if not isinstance(ti, dict) or not ti: + return None + return {k: v for k, v in ti.items() if isinstance(v, str)} + + @staticmethod + def _load_st_prompts( + model: str | Any, + revision: str | None, + ) -> dict[str, str] | None: + """Load the ``prompts`` mapping from ``config_sentence_transformers.json``.""" + from vllm.transformers_utils.repo_utils import get_hf_file_to_dict + + try: + cfg = get_hf_file_to_dict( + "config_sentence_transformers.json", str(model), revision + ) + except (ValueError, OSError): + return None + + if cfg is None: + return None + prompts = cfg.get("prompts") + if not isinstance(prompts, dict) or not prompts: + return None + return {k: v for k, v in prompts.items() if isinstance(v, str)} + + @staticmethod + def _mixed_input_to_messages( + inp:
CohereEmbedInput, + *, + task_prefix: str | None = None, + ) -> list[ChatCompletionMessageParam]: + """Build chat messages from a mixed text+image input. + + When *task_prefix* is given, it is prepended to each text part. + """ + parts: list[ChatCompletionContentPartParam] = [] + for item in inp.content: + if item.type == "text" and item.text is not None: + text = task_prefix + item.text if task_prefix else item.text + parts.append(ChatCompletionContentPartTextParam(type="text", text=text)) + elif item.type == "image_url" and item.image_url is not None: + parts.append( + ChatCompletionContentPartImageParam( + type="image_url", + image_url=ImageURL(url=item.image_url["url"]), + ) + ) + return [CustomChatCompletionMessageParam(role="user", content=parts)] + + @staticmethod + def _check_cohere_max_tokens( + outputs: list[PoolingRequestOutput], + max_tokens_check: int | None, + ) -> None: + """Raise if any output exceeds *max_tokens_check* tokens. + + Used to enforce ``truncate=NONE`` with an explicit ``max_tokens``: + the pipeline runs without truncation and we reject afterwards. + """ + if max_tokens_check is None: + return + for out in outputs: + n = len(out.prompt_token_ids) + if n > max_tokens_check: + raise ValueError( + f"Input of {n} tokens exceeds max_tokens={max_tokens_check} " + "with truncate=NONE. Set truncate to END or START to " + "allow truncation." 
+ ) + + @staticmethod + def _resolve_cohere_truncation( + request: CohereEmbedRequest, + ) -> tuple[int | None, Literal["left", "right"] | None]: + """Return ``(truncate_prompt_tokens, truncation_side)``.""" + if request.truncate == "NONE": + return None, None + if request.truncate == "START": + tokens = request.max_tokens if request.max_tokens is not None else -1 + return tokens, "left" + if request.max_tokens is not None: + return request.max_tokens, None + return -1, None + + def create_pooling_params(self, request): + if isinstance(request, CohereEmbedRequest): + return PoolingParams( + task="embed", + dimensions=request.output_dimension, + ) + return super().create_pooling_params(request) + + def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None: + """Convert a ``CohereEmbedRequest`` into engine prompts. + + For texts, a single batched completion request path is used. + For images and mixed inputs, conversations are batch-rendered + through the chat template in one ``render_chat`` call. 
+ """ + request = ctx.request + assert isinstance(request, CohereEmbedRequest) + + if request.texts is None and request.images is None and request.inputs is None: + raise ValueError("One of texts, images, or inputs must be provided") + + truncate_prompt_tokens, truncation_side = self._resolve_cohere_truncation( + request + ) + input_type = request.input_type + self._validate_input_type(input_type) + + if request.images is not None: + all_messages: list[list[ChatCompletionMessageParam]] = [ + [ + CustomChatCompletionMessageParam( + role="user", + content=[{"type": "image_url", "image_url": {"url": uri}}], + ) + ] + for uri in request.images + ] + ctx.engine_prompts = self._batch_render_chat( + request, all_messages, truncate_prompt_tokens, truncation_side + ) + + elif request.inputs is not None: + task_prefix = self._get_task_instruction_prefix(input_type) + all_messages = [ + self._mixed_input_to_messages(inp, task_prefix=task_prefix) + for inp in request.inputs + ] + ctx.engine_prompts = self._batch_render_chat( + request, all_messages, truncate_prompt_tokens, truncation_side + ) + + else: + prefixed = self._apply_task_instruction(request.texts or [], input_type) + proxy = EmbeddingCompletionRequest( + model=request.model, + input=prefixed, + dimensions=request.output_dimension, + encoding_format="float", + truncate_prompt_tokens=truncate_prompt_tokens, + truncation_side=truncation_side, + ) + ctx.engine_prompts = self._preprocess_completion_online( + proxy, prompt_input=proxy.input, prompt_embeds=None + ) + + def _batch_render_chat( + self, + request: CohereEmbedRequest, + all_messages: Sequence[list[ChatCompletionMessageParam]], + truncate_prompt_tokens: int | None, + truncation_side: Literal["left", "right"] | None, + ) -> list[ProcessorInputs]: + """Batch-render multiple conversations through the chat template.""" + if not all_messages: + return [] + + proxy = EmbeddingChatRequest( + model=request.model, + messages=list(all_messages[0]), + 
dimensions=request.output_dimension, + encoding_format="float", + truncate_prompt_tokens=truncate_prompt_tokens, + truncation_side=truncation_side, + ) + + renderer = self.renderer + mm_config = self.model_config.multimodal_config + + tok_params = proxy.build_tok_params(self.model_config) + chat_params = proxy.build_chat_params( + self.chat_template, + self.chat_template_content_format, + ).with_defaults( + merge_kwargs( + None, + dict( + tools=None, + tokenize=is_mistral_tokenizer(renderer.tokenizer), + ), + ), + default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), + ) + + _, engine_prompts = renderer.render_chat(all_messages, chat_params, tok_params) + return engine_prompts + + def _validate_input_type(self, input_type: str | None) -> None: + """Raise if *input_type* is not supported by this model.""" + if input_type is None: + return + if self.task_instructions is None: + raise ValueError( + f"Unsupported input_type {input_type!r}. " + "This model does not define any input_type task instructions." + ) + if input_type not in self.task_instructions: + supported = ", ".join(sorted(self.task_instructions)) + raise ValueError( + f"Unsupported input_type {input_type!r}. Supported values: {supported}" + ) + + def _apply_task_instruction( + self, + texts: list[str], + input_type: str | None, + ) -> list[str]: + """Prepend the task-instruction prefix for *input_type*. + + Returns *texts* unchanged when no matching prefix is configured. 
+ """ + prefix = self._get_task_instruction_prefix(input_type) + if not prefix: + return texts + return [prefix + t for t in texts] + + def _get_task_instruction_prefix(self, input_type: str | None) -> str | None: + """Return the task-instruction prefix for *input_type*, or ``None``.""" + if not self.task_instructions or input_type is None: + return None + return self.task_instructions.get(input_type) or None + + def _enforce_cohere_max_tokens(self, ctx: PoolingServeContext) -> None: + if isinstance(ctx.request, CohereEmbedRequest): + request = ctx.request + if request.truncate == "NONE" and request.max_tokens is not None: + self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens) diff --git a/vllm/entrypoints/pooling/embed/protocol.py b/vllm/entrypoints/pooling/embed/protocol.py index 4b47c6522..b02f91dfa 100644 --- a/vllm/entrypoints/pooling/embed/protocol.py +++ b/vllm/entrypoints/pooling/embed/protocol.py @@ -1,9 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Embedding API protocol models for OpenAI and Cohere formats. 
+ +OpenAI: https://platform.openai.com/docs/api-reference/embeddings +Cohere: https://docs.cohere.com/reference/embed +""" + +import base64 +import builtins +import struct import time -from typing import TypeAlias +from collections.abc import Sequence +from typing import Literal, TypeAlias -from pydantic import Field +from pydantic import BaseModel, Field from vllm import PoolingParams from vllm.config import ModelConfig @@ -17,6 +27,10 @@ from vllm.entrypoints.pooling.base.protocol import ( from vllm.renderers import TokenizeParams from vllm.utils import random_uuid +# --------------------------------------------------------------------------- +# OpenAI /v1/embeddings — request models +# --------------------------------------------------------------------------- + def _get_max_total_output_tokens( model_config: ModelConfig, @@ -50,6 +64,7 @@ class EmbeddingCompletionRequest( max_total_tokens=max_total_tokens, max_output_tokens=max_output_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -79,6 +94,7 @@ class EmbeddingChatRequest( max_total_tokens=max_total_tokens, max_output_tokens=max_output_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -96,6 +112,11 @@ class EmbeddingChatRequest( EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest +# --------------------------------------------------------------------------- +# OpenAI /v1/embeddings — response models +# --------------------------------------------------------------------------- + + class EmbeddingResponseData(OpenAIBaseModel): index: int object: str = "embedding" @@ -106,7 +127,7 @@ class 
EmbeddingResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") object: str = "list" created: int = Field(default_factory=lambda: int(time.time())) - model: str + model: str | None = None data: list[EmbeddingResponseData] usage: UsageInfo @@ -115,3 +136,146 @@ class EmbeddingBytesResponse(OpenAIBaseModel): content: list[bytes] headers: dict[str, str] | None = None media_type: str = "application/octet-stream" + + +# --------------------------------------------------------------------------- +# Cohere /v2/embed — request models +# --------------------------------------------------------------------------- + +CohereEmbeddingType = Literal[ + "float", + "binary", + "ubinary", + "base64", +] +CohereTruncate = Literal["NONE", "START", "END"] + + +class CohereEmbedContent(BaseModel): + type: Literal["text", "image_url"] + text: str | None = None + image_url: dict[str, str] | None = None + + +class CohereEmbedInput(BaseModel): + content: list[CohereEmbedContent] + + +class CohereEmbedRequest(BaseModel): + model: str | None = None + input_type: str | None = None + texts: list[str] | None = None + images: list[str] | None = None + inputs: list[CohereEmbedInput] | None = None + output_dimension: int | None = None + embedding_types: list[CohereEmbeddingType] | None = None + truncate: CohereTruncate = "END" + max_tokens: int | None = None + priority: int = 0 + + +# --------------------------------------------------------------------------- +# Cohere /v2/embed — response models +# --------------------------------------------------------------------------- + + +class CohereApiVersion(BaseModel): + version: str = "2" + + +class CohereBilledUnits(BaseModel): + input_tokens: int | None = None + image_tokens: int | None = None + + +class CohereMeta(BaseModel): + api_version: CohereApiVersion = Field(default_factory=CohereApiVersion) + billed_units: CohereBilledUnits | None = None + + +class CohereEmbedByTypeEmbeddings(BaseModel): + # The field 
name ``float`` shadows the builtin type, so the annotation + # must use ``builtins.float`` to avoid a self-referential type error. + float: list[list[builtins.float]] | None = None + binary: list[list[int]] | None = None + ubinary: list[list[int]] | None = None + base64: list[str] | None = None + + +class CohereEmbedResponse(BaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + embeddings: CohereEmbedByTypeEmbeddings + texts: list[str] | None = None + meta: CohereMeta | None = None + response_type: Literal["embeddings_by_type"] = "embeddings_by_type" + + +# --------------------------------------------------------------------------- +# Cohere embedding type conversion helpers +# --------------------------------------------------------------------------- + +_UNSIGNED_TO_SIGNED_DIFF = 1 << 7 # 128 + + +def _pack_binary_embeddings( + float_embeddings: list[list[float]], + signed: bool, +) -> list[list[int]]: + """Bit-pack float embeddings: non-negative -> 1, negative -> 0. + + Each bit is shifted left by ``7 - idx%8``, and every 8 bits are packed + into one byte. + """ + result: list[list[int]] = [] + for embedding in float_embeddings: + dim = len(embedding) + if dim % 8 != 0: + raise ValueError( + "Embedding dimension must be a multiple of 8 for binary " + f"embedding types, but got {dim}." 
+ ) + packed_len = dim // 8 + packed: list[int] = [] + byte_val = 0 + for idx, value in enumerate(embedding): + bit = 1 if value >= 0 else 0 + byte_val += bit << (7 - idx % 8) + if (idx + 1) % 8 == 0: + if signed: + byte_val -= _UNSIGNED_TO_SIGNED_DIFF + packed.append(byte_val) + byte_val = 0 + assert len(packed) == packed_len + result.append(packed) + return result + + +def _encode_base64_embeddings( + float_embeddings: list[list[float]], +) -> list[str]: + """Encode float embeddings as base64 (little-endian float32).""" + result: list[str] = [] + for embedding in float_embeddings: + buf = struct.pack(f"<{len(embedding)}f", *embedding) + result.append(base64.b64encode(buf).decode("utf-8")) + return result + + +def build_typed_embeddings( + float_embeddings: list[list[float]], + embedding_types: Sequence[str], +) -> CohereEmbedByTypeEmbeddings: + """Convert float embeddings to all requested Cohere embedding types.""" + result = CohereEmbedByTypeEmbeddings() + + for emb_type in embedding_types: + if emb_type == "float": + result.float = float_embeddings + elif emb_type == "binary": + result.binary = _pack_binary_embeddings(float_embeddings, signed=True) + elif emb_type == "ubinary": + result.ubinary = _pack_binary_embeddings(float_embeddings, signed=False) + elif emb_type == "base64": + result.base64 = _encode_base64_embeddings(float_embeddings) + + return result diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index c4ecf2683..f0c331645 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -5,7 +5,7 @@ from collections.abc import Callable from functools import partial from typing import Literal, TypeAlias, cast -from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.responses import JSONResponse, Response, StreamingResponse from typing_extensions import assert_never from vllm.config import ModelConfig @@ -14,10 +14,15 @@ from 
vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor from vllm.entrypoints.pooling.embed.protocol import ( + CohereBilledUnits, + CohereEmbedRequest, + CohereEmbedResponse, + CohereMeta, EmbeddingBytesResponse, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, + build_typed_embeddings, ) from vllm.entrypoints.pooling.typing import PoolingServeContext from vllm.entrypoints.pooling.utils import ( @@ -26,24 +31,23 @@ from vllm.entrypoints.pooling.utils import ( encode_pooling_output_float, get_json_response_cls, ) +from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput from vllm.renderers import BaseRenderer from vllm.utils.serial_utils import EmbedDType, Endianness +logger = init_logger(__name__) + JSONResponseCLS = get_json_response_cls() EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest] class ServingEmbedding(PoolingServing): - """ - Embedding API similar to OpenAI's API. - - See https://platform.openai.com/docs/api-reference/embeddings/create - for the API specification. This API mimics the OpenAI Embedding API. 
- """ + """Embedding API supporting both OpenAI and Cohere formats.""" request_id_prefix = "embd" + io_processor: EmbedIOProcessor def init_io_processor( self, @@ -58,6 +62,14 @@ class ServingEmbedding(PoolingServing): ) async def _build_response( + self, + ctx: PoolingServeContext, + ) -> Response: + if isinstance(ctx.request, CohereEmbedRequest): + return self._build_cohere_response_from_ctx(ctx) + return await self._build_openai_response(ctx) + + async def _build_openai_response( self, ctx: EmbeddingServeContext, ) -> JSONResponse | StreamingResponse: @@ -66,7 +78,7 @@ class ServingEmbedding(PoolingServing): endianness = ctx.request.endianness if encoding_format == "float" or encoding_format == "base64": - return self._request_output_to_embed_json_response( + return self._openai_json_response( ctx.final_res_batch, ctx.request_id, ctx.created_time, @@ -77,7 +89,7 @@ class ServingEmbedding(PoolingServing): ) if encoding_format == "bytes" or encoding_format == "bytes_only": - return self._request_output_to_to_embed_bytes_response( + return self._openai_bytes_response( ctx.final_res_batch, ctx.request_id, ctx.created_time, @@ -89,7 +101,7 @@ class ServingEmbedding(PoolingServing): assert_never(encoding_format) - def _request_output_to_embed_json_response( + def _openai_json_response( self, final_res_batch: list[PoolingRequestOutput], request_id: str, @@ -139,7 +151,7 @@ class ServingEmbedding(PoolingServing): ) return JSONResponseCLS(content=response.model_dump()) - def _request_output_to_to_embed_bytes_response( + def _openai_bytes_response( self, final_res_batch: list[PoolingRequestOutput], request_id: str, @@ -177,3 +189,33 @@ class ServingEmbedding(PoolingServing): headers=response.headers, media_type=response.media_type, ) + + @staticmethod + def _build_cohere_response_from_ctx( + ctx: PoolingServeContext, + ) -> JSONResponse: + request = ctx.request + assert isinstance(request, CohereEmbedRequest) + + all_floats = [encode_pooling_output_float(out) for out in 
ctx.final_res_batch] + total_tokens = sum(len(out.prompt_token_ids) for out in ctx.final_res_batch) + + image_tokens = total_tokens if request.images is not None else 0 + texts_echo = request.texts + + embedding_types = request.embedding_types or ["float"] + embeddings_obj = build_typed_embeddings(all_floats, embedding_types) + + input_tokens = total_tokens - image_tokens + response = CohereEmbedResponse( + id=ctx.request_id, + embeddings=embeddings_obj, + texts=texts_echo, + meta=CohereMeta( + billed_units=CohereBilledUnits( + input_tokens=input_tokens, + image_tokens=image_tokens, + ), + ), + ) + return JSONResponse(content=response.model_dump(exclude_none=True)) diff --git a/vllm/entrypoints/pooling/pooling/protocol.py b/vllm/entrypoints/pooling/pooling/protocol.py index b99f98959..098690db2 100644 --- a/vllm/entrypoints/pooling/pooling/protocol.py +++ b/vllm/entrypoints/pooling/pooling/protocol.py @@ -36,6 +36,7 @@ class PoolingCompletionRequest( max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -61,6 +62,7 @@ class PoolingChatRequest( max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -88,6 +90,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=not model_config.is_encoder_decoder, 
max_total_tokens_param="max_model_len", diff --git a/vllm/entrypoints/pooling/score/protocol.py b/vllm/entrypoints/pooling/score/protocol.py index 643eeed36..2aea1bd7b 100644 --- a/vllm/entrypoints/pooling/score/protocol.py +++ b/vllm/entrypoints/pooling/score/protocol.py @@ -30,6 +30,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), max_total_tokens_param="max_model_len", ) @@ -105,6 +106,7 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin): max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), max_total_tokens_param="max_model_len", ) diff --git a/vllm/entrypoints/pooling/typing.py b/vllm/entrypoints/pooling/typing.py index 74ed9b50c..f9f361824 100644 --- a/vllm/entrypoints/pooling/typing.py +++ b/vllm/entrypoints/pooling/typing.py @@ -15,6 +15,7 @@ from vllm.entrypoints.pooling.classify.protocol import ( ClassificationResponse, ) from vllm.entrypoints.pooling.embed.protocol import ( + CohereEmbedRequest, EmbeddingBytesResponse, EmbeddingChatRequest, EmbeddingCompletionRequest, @@ -50,6 +51,7 @@ AnyPoolingRequest: TypeAlias = ( | IOProcessorRequest | RerankRequest | ScoreRequest + | CohereEmbedRequest ) AnyPoolingResponse: TypeAlias = ( diff --git a/vllm/renderers/params.py b/vllm/renderers/params.py index 54da0f3b5..a2c95690c 100644 --- a/vllm/renderers/params.py +++ b/vllm/renderers/params.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, 
Literal, TypeVar from vllm.exceptions import VLLMValidationError from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt @@ -153,6 +153,14 @@ class TokenizeParams: - `-1` maps to `max_input_tokens`. """ + truncation_side: Literal["left", "right"] | None = None + """ + Which side to truncate from when ``truncate_prompt_tokens`` is active: + - ``"right"`` keeps the first N tokens (truncate from the end). + - ``"left"`` keeps the last N tokens (truncate from the start). + - ``None`` falls back to the tokenizer default. + """ + do_lower_case: bool = False """Whether to normalize text to lower case before tokenization.""" @@ -271,6 +279,7 @@ class TokenizeParams: ), pad_prompt_tokens=pad_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=do_lower_case, add_special_tokens=add_special_tokens, needs_detokenization=needs_detokenization, @@ -286,6 +295,16 @@ class TokenizeParams: # while still failing `self._token_len_check` as expected by users max_length = self.max_input_tokens + 1 + # Left-side truncation requires the full token sequence so we can + # slice from the end in _token_truncation. Disable HF-level + # truncation (which would incorrectly truncate from the right for + # pooling models) and let _token_truncation handle it. 
+ if self.truncation_side == "left": + return dict( + truncation=False, + add_special_tokens=self.add_special_tokens, + ) + return dict( truncation=max_length is not None, max_length=max_length, @@ -375,7 +394,10 @@ class TokenizeParams: if max_length == 0: return tokens[:0] - if getattr(tokenizer, "truncation_side", "left") == "left": + side = self.truncation_side or ( + tokenizer.truncation_side if tokenizer is not None else None + ) + if side == "left": return tokens[-max_length:] return tokens[:max_length] -- GitLab From 5db91f0aaf3566b1d9f8b0720065eb9009296d98 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 17 Mar 2026 01:08:56 +0100 Subject: [PATCH 012/223] Fix some Mistral parser issues (#37209) Signed-off-by: juliendenize --- .../openai/chat_completion/serving.py | 13 +++-- vllm/tokenizers/mistral.py | 53 ++++++++++--------- vllm/tool_parsers/mistral_tool_parser.py | 10 ++-- 3 files changed, 42 insertions(+), 34 deletions(-) diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 2eb550c3e..ad7982b61 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -310,11 +310,14 @@ class OpenAIServingChat(OpenAIServing): trace_headers=trace_headers, ) else: - reasoning_ended = ( - reasoning_parser.is_reasoning_end(prompt_token_ids or []) - if reasoning_parser - else None - ) + if not request.include_reasoning: + reasoning_ended = True + elif reasoning_parser: + reasoning_ended = reasoning_parser.is_reasoning_end( + prompt_token_ids or [] + ) + else: + reasoning_ended = None generator = self.engine_client.generate( engine_prompt, diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index ca61edeb8..e20f1edd4 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -15,8 +15,15 @@ from mistral_common.protocol.instruct.validator import 
ValidationMode from mistral_common.tokens.tokenizers.base import ( SpecialTokenPolicy, SpecialTokens, + Tokenizer, +) +from mistral_common.tokens.tokenizers.instruct import ( + InstructTokenizerBase, + InstructTokenizerV13, +) +from mistral_common.tokens.tokenizers.mistral import ( + MistralTokenizer as MistralCommonTokenizer, ) -from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13 from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, ) @@ -26,21 +33,20 @@ from pydantic import ValidationError from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.logger import init_logger +from vllm.tokenizers.protocol import TokenizerLike -from .protocol import TokenizerLike +try: + # Transformers v5 + from transformers.tokenization_mistral_common import MistralCommonBackend +except ImportError: + # Transformers v4 + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as MistralCommonBackend, + ) if TYPE_CHECKING: from transformers import BatchEncoding - try: - # Transformers v5 - from transformers.tokenization_mistral_common import MistralCommonBackend - except ImportError: - # Transformers v4 - from transformers.tokenization_mistral_common import ( - MistralCommonTokenizer as MistralCommonBackend, - ) - logger = init_logger(__name__) @@ -235,15 +241,6 @@ class MistralTokenizer(TokenizerLike): download_dir: str | None = None, **kwargs, ) -> "MistralTokenizer": - try: - # Transformers v5 - from transformers.tokenization_mistral_common import MistralCommonBackend - except ImportError: - # Transformers v4 - from transformers.tokenization_mistral_common import ( - MistralCommonTokenizer as MistralCommonBackend, - ) - tokenizer = MistralCommonBackend.from_pretrained( path_or_repo_id, *args, @@ -255,13 +252,13 @@ class MistralTokenizer(TokenizerLike): return cls(tokenizer) - def __init__(self, 
tokenizer: "MistralCommonBackend") -> None: + def __init__(self, tokenizer: MistralCommonBackend) -> None: super().__init__() - self.transformers_tokenizer = tokenizer - self.mistral = tokenizer.tokenizer - self.instruct = self.mistral.instruct_tokenizer - self.tokenizer = self.instruct.tokenizer + self.transformers_tokenizer: MistralCommonBackend = tokenizer + self.mistral: MistralCommonTokenizer = tokenizer.tokenizer + self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer + self.tokenizer: Tokenizer = self.instruct.tokenizer mode = self.mistral._chat_completion_request_validator._mode if mode != ValidationMode.test: @@ -483,7 +480,11 @@ class MistralTokenizer(TokenizerLike): return self.transformers_tokenizer.convert_tokens_to_ids(tokens) def convert_tokens_to_string(self, tokens: list[str]) -> str: - to_decode_special_tokens = {SpecialTokens.tool_calls} + to_decode_special_tokens = { + SpecialTokens.tool_calls, + SpecialTokens.begin_think, + SpecialTokens.end_think, + } if self.is_tekken: assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) tokens = [ diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index baab4ade0..56ba245ce 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -241,7 +241,10 @@ class MistralToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: - if self.bot_token_id not in current_token_ids: + has_bot_token = ( + self.bot_token_id in current_token_ids or self.bot_token in current_text + ) + if not has_bot_token: # if the tool call token is not in the tokens generated so far, # append output to contents since it's not a tool return DeltaMessage(content=delta_text) @@ -275,7 +278,8 @@ class MistralToolParser(ToolParser): additional_content: str = "" if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START: # this is the first tool call - assert 
self.bot_token_id in delta_token_ids + if self.bot_token not in delta_text: + return DeltaMessage(content=delta_text) if not delta_text.startswith(self.bot_token): additional_content += delta_text.split(self.bot_token)[0] delta_text = self.bot_token + "".join( @@ -411,7 +415,7 @@ class MistralToolParser(ToolParser): index=self.current_tool_id, type="function" ) current_tool_call_modified = False - if self.bot_token_id in delta_token_ids: + if self.bot_token_id in delta_token_ids or self.bot_token in delta_text: # this is the first tool call if not delta_text.startswith(self.bot_token): content = delta_text.split(self.bot_token)[0] -- GitLab From 45f526d65237d9073a5f3be166b306580687f210 Mon Sep 17 00:00:00 2001 From: Harry Huang Date: Tue, 17 Mar 2026 08:38:52 +0800 Subject: [PATCH 013/223] [BugFix] Correct max memory usage for multiple KV-cache groups (#36030) Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 41 ++++++++++++++++++++++++++++ vllm/v1/core/kv_cache_utils.py | 6 ++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 08463a280..8153fed69 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -43,6 +43,7 @@ from vllm.v1.kv_cache_interface import ( KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, + MambaSpec, MLAAttentionSpec, SlidingWindowSpec, UniformTypeKVCacheSpecs, @@ -157,6 +158,24 @@ def new_chunked_local_attention_spec( ) +def new_mamba_spec( + block_size=16, + shapes=((2, 512), (3, 32, 32)), + dtypes=(torch.float32, torch.float32), + num_speculative_blocks=2, + mamba_cache_mode="none", + page_size_padded=None, +): + return MambaSpec( + block_size=block_size, + shapes=shapes, + dtypes=dtypes, + page_size_padded=page_size_padded, + mamba_cache_mode=mamba_cache_mode, + num_speculative_blocks=num_speculative_blocks, + ) + + @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def 
test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils @@ -2010,6 +2029,28 @@ def test_auto_fit_max_model_len(): assert vllm_config.model_config.max_model_len > 0 +def test_auto_fit_max_model_len_with_hybrid(): + """Test that auto-fit works with hybrid KV cache specs.""" + # Create config with original_max_model_len=-1 to trigger auto-fit + model_config = ModelConfig(max_model_len=8192) + # Simulate the user passing -1 by setting original_max_model_len + model_config.original_max_model_len = -1 + vllm_config = VllmConfig(model_config=model_config) + + mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 # 16KB per block per layer + gamma = 2 + kv_cache_specs = { + "layer_1": new_mamba_spec(num_speculative_blocks=gamma), + "layer_2": new_kv_cache_spec(), + } + + available_memory = mem_per_block_per_layer * (1024 // 16 + 1 + gamma) + _kv_cache_configs = get_kv_cache_configs( + vllm_config, [kv_cache_specs], [available_memory] + ) + assert vllm_config.model_config.max_model_len == 1024 + + def test_auto_fit_max_model_len_not_triggered(): """Test that auto-fit is not triggered when original_max_model_len is not -1.""" model_config = ModelConfig(max_model_len=16) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3da3d7e7b..83ada0530 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1356,8 +1356,10 @@ def _max_memory_usage_bytes_from_groups( page_size = get_uniform_page_size( [group.kv_cache_spec for group in kv_cache_groups] ) - any_spec = kv_cache_groups[0].kv_cache_spec - blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size) + blocks_needed = sum( + cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size) + for group in kv_cache_groups + ) return group_size * page_size * blocks_needed -- GitLab From 6c1cfbad325067c4afa12c87992f45a58ce0614b Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Date: Tue, 17 Mar 
2026 04:48:42 +0400 Subject: [PATCH 014/223] Support non-contiguous KV cache in TRTLLM fp8 dequant kernel (#36867) Signed-off-by: Vadim Gimpelson Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: Pavani Majety --- .../attention/test_trtllm_kvfp8_dequant.py | 434 ++++++++++++++++++ vllm/v1/attention/backends/flashinfer.py | 83 ++-- 2 files changed, 491 insertions(+), 26 deletions(-) create mode 100644 tests/kernels/attention/test_trtllm_kvfp8_dequant.py diff --git a/tests/kernels/attention/test_trtllm_kvfp8_dequant.py b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py new file mode 100644 index 000000000..a2ea372c0 --- /dev/null +++ b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py @@ -0,0 +1,434 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Standalone unit tests for trtllm_prefill_attn_kvfp8_dequant. + +Tests both contiguous and non-contiguous (cross-layer unified) KV cache +layouts against a pure-PyTorch reference implementation. 
+""" + +import pytest +import torch + +from vllm.platforms import current_platform + +FP8_DTYPE = current_platform.fp8_dtype() + +NUM_BLOCKS = 128 + + +def to_float8(x, dtype=None): + if dtype is None: + dtype = FP8_DTYPE + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +def make_contiguous_kv_cache(num_blocks, num_kv_heads, block_size, head_size): + """Create a standard contiguous fp8 KV cache (HND layout).""" + raw = torch.randn( + num_blocks, + 2, + num_kv_heads, + block_size, + head_size, + dtype=torch.bfloat16, + device="cuda", + ) + kv_cache, scale = to_float8(raw) + return kv_cache, scale + + +def make_cross_layer_kv_cache( + num_blocks, + num_kv_heads, + block_size, + head_size, + num_layers=4, +): + """ + Create a non-contiguous per-layer view mimicking cross-layer allocation. + + Physical layout: (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size) + Returned view: (num_blocks, 2, num_kv_heads, block_size, head_size) + with non-contiguous strides on dims 0, 1, 2 (they skip over num_layers). 
+ """ + raw = torch.randn( + num_blocks, + 2, + num_kv_heads, + num_layers, + block_size, + head_size, + dtype=torch.bfloat16, + device="cuda", + ) + fp8_full, scale = to_float8(raw) + layer_view = fp8_full[:, :, :, 0, :, :] + assert not layer_view.is_contiguous(), ( + f"Expected non-contiguous view, got strides {layer_view.stride()}" + ) + return layer_view, scale + + +def ref_dequant(kv_cache, block_tables, k_scale, v_scale, dequant_dtype): + """Pure PyTorch reference: gather pages and dequantize fp8 -> dequant_dtype.""" + batch_size, num_pages_per_seq = block_tables.shape + s = kv_cache.shape + out = torch.zeros( + batch_size * num_pages_per_seq + 1, + s[1], + s[2], + s[3], + s[4], + dtype=dequant_dtype, + device=kv_cache.device, + ) + for b in range(batch_size): + for p in range(num_pages_per_seq): + page_idx = block_tables[b, p].item() + if page_idx <= 0: + continue + mock_idx = b * num_pages_per_seq + p + 1 + out[mock_idx, 0] = (kv_cache[page_idx, 0].float() * k_scale.item()).to( + dequant_dtype + ) + out[mock_idx, 1] = (kv_cache[page_idx, 1].float() * v_scale.item()).to( + dequant_dtype + ) + return out + + +@pytest.mark.parametrize("num_kv_heads", [1, 8]) +@pytest.mark.parametrize("head_size", [64, 128]) +@pytest.mark.parametrize("block_size", [16, 32]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("num_pages_per_seq", [3, 8]) +@pytest.mark.parametrize("contiguous", [True, False]) +@torch.inference_mode() +def test_trtllm_kvfp8_dequant( + num_kv_heads: int, + head_size: int, + block_size: int, + batch_size: int, + num_pages_per_seq: int, + contiguous: bool, +): + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + + if contiguous: + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + else: + kv_cache, scale = make_cross_layer_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + 
+ k_scale = scale.clone() + v_scale = scale.clone() + + block_tables = torch.randint( + 1, + NUM_BLOCKS, + (batch_size, num_pages_per_seq), + dtype=torch.int32, + ) + + mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + expected_bt = torch.arange( + 1, + batch_size * num_pages_per_seq + 1, + dtype=torch.int32, + device="cuda", + ).reshape(batch_size, num_pages_per_seq) + torch.testing.assert_close(mock_block_table, expected_bt) + + # Page 0 is padding (never written), compare only pages 1+ + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_block_tables_with_zero_pages(): + """Pages with index <= 0 must be skipped (early return in kernel).""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 64 + + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + # Mix of valid pages and zeros (padding) + block_tables = torch.tensor( + [[5, 0, 10], [0, 0, 0], [3, 7, 0]], + dtype=torch.int32, + device="cuda", + ) + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + # Only compare pages that were actually written (non-zero page indices) + for b in range(block_tables.shape[0]): + for p in range(block_tables.shape[1]): + if block_tables[b, p].item() > 0: + idx = b * block_tables.shape[1] + p + 1 + torch.testing.assert_close( + mock_kv_cache[idx], + ref[idx], + atol=1e-3, + rtol=1e-3, + ) + + +@torch.inference_mode() +def test_all_zero_block_tables(): + """All-zero 
block_tables: kernel should write nothing.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 4, 16, 64 + + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + block_tables = torch.zeros(2, 4, dtype=torch.int32, device="cuda") + + # Should not crash even though no pages are valid + mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + assert mock_kv_cache.shape[0] == 2 * 4 + 1 + assert mock_block_table.shape == (2, 4) + + +@torch.inference_mode() +def test_different_k_v_scales(): + """Verify K and V are dequantized with independent scales.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 64 + + kv_cache, _ = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + k_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + v_scale = torch.tensor([2.0], dtype=torch.float32, device="cuda") + + block_tables = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda") + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_single_page_per_seq(): + """Minimum grid dim 1 = 1 page per sequence.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 128 + + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, 
+ num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + block_tables = torch.tensor([[5], [10], [20]], dtype=torch.int32, device="cuda") + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_large_page_indices(): + """Page indices near the top of the buffer stress offset arithmetic.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 128 + large_num_blocks = 32768 + + kv_cache, scale = make_contiguous_kv_cache( + large_num_blocks, + num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + # Use page indices near the top of the buffer + block_tables = torch.tensor( + [[large_num_blocks - 1, large_num_blocks - 2, 1]], + dtype=torch.int32, + device="cuda", + ) + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_large_block_size(): + """block_size=64 -> HEAD_STRIDE=8192, large tl.arange per thread block.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 4, 64, 128 + + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + block_tables = torch.randint( + 1, + NUM_BLOCKS, + (2, 4), + dtype=torch.int32, + device="cuda", + ) + + 
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_cross_layer_many_layers(): + """ + Non-contiguous with 36 layers -- matches real gpt-oss-120b. + Strides are far from contiguous (factor of 36 in the gaps). + """ + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 64 + num_layers = 36 + + kv_cache, scale = make_cross_layer_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + num_layers=num_layers, + ) + k_scale = v_scale = scale.clone() + + block_tables = torch.randint( + 1, + NUM_BLOCKS, + (4, 6), + dtype=torch.int32, + device="cuda", + ) + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 595f4ffa5..411ec746c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -96,8 +96,13 @@ def _trtllm_prefill_attn_kvfp8_dequant( mock_kv_cache_ptr, k_scale_ptr, v_scale_ptr, - K_CACHE_STRIDE: tl.constexpr, - KV_CACHE_STRIDE: tl.constexpr, + src_stride_page, + src_stride_kv, + src_stride_head, + DST_K_CACHE_STRIDE: tl.constexpr, + DST_KV_CACHE_STRIDE: tl.constexpr, + HEAD_STRIDE: tl.constexpr, + NUM_KV_HEADS: tl.constexpr, ): batch_idx = tl.program_id(0).to(tl.int64) mock_block_table_idx = tl.program_id(1).to(tl.int64) @@ -108,31 +113,42 @@ def _trtllm_prefill_attn_kvfp8_dequant( return 
dequant_dtype = mock_kv_cache_ptr.dtype.element_ty - # Dequantize K k_scale_val = tl.load(k_scale_ptr) - offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) - fp8_vals = tl.load(kv_cache_ptr + offset) - dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val - mock_cache_offset = ( - batch_idx * block_table_stride + mock_block_table_idx + 1 - ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) - dequantized_vals = dequantized_vals.to(dequant_dtype) - tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) - - # Dequantize V v_scale_val = tl.load(v_scale_ptr) - offset = ( - orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) - ) - fp8_vals = tl.load(kv_cache_ptr + offset) - dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val - mock_cache_offset = ( - (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE - + K_CACHE_STRIDE - + tl.arange(0, K_CACHE_STRIDE) - ) - dequantized_vals = dequantized_vals.to(dequant_dtype) - tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + mock_page_idx = batch_idx * block_table_stride + mock_block_table_idx + 1 + head_offsets = tl.arange(0, HEAD_STRIDE) + + for h in range(NUM_KV_HEADS): + h_off = tl.cast(h, tl.int64) + + # Read K from source (supports non-contiguous page/kv/head strides) + src_k = orig_page_num * src_stride_page + h_off * src_stride_head + head_offsets + fp8_k = tl.load(kv_cache_ptr + src_k) + dequant_k = (fp8_k.to(tl.float32) * k_scale_val).to(dequant_dtype) + + # Write K to contiguous mock cache + dst_k = mock_page_idx * DST_KV_CACHE_STRIDE + h * HEAD_STRIDE + head_offsets + tl.store(mock_kv_cache_ptr + dst_k, dequant_k) + + # Read V from source (offset by src_stride_kv for the V half) + src_v = ( + orig_page_num * src_stride_page + + src_stride_kv + + h_off * src_stride_head + + head_offsets + ) + fp8_v = tl.load(kv_cache_ptr + src_v) + dequant_v = (fp8_v.to(tl.float32) * v_scale_val).to(dequant_dtype) + + # Write V 
to contiguous mock cache + dst_v = ( + mock_page_idx * DST_KV_CACHE_STRIDE + + DST_K_CACHE_STRIDE + + h * HEAD_STRIDE + + head_offsets + ) + tl.store(mock_kv_cache_ptr + dst_v, dequant_v) def trtllm_prefill_attn_kvfp8_dequant( @@ -146,8 +162,18 @@ def trtllm_prefill_attn_kvfp8_dequant( s = kv_cache.shape assert s[1] == 2 assert dequant_dtype in (torch.bfloat16, torch.float16) - k_cache_stride = s[2] * s[3] * s[4] + + num_kv_heads, block_size, head_size = s[2], s[3], s[4] + head_stride = block_size * head_size + k_cache_stride = num_kv_heads * head_stride kv_cache_stride = k_cache_stride * s[1] + + strides = kv_cache.stride() + assert strides[3] == head_size and strides[4] == 1, ( + "For kv cache layouts, (block_size, head_size) " + f"dimensions must be contiguous, got strides {strides}" + ) + new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) # mock kv cache contains just the pages needed by this prefill mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device) @@ -166,8 +192,13 @@ def trtllm_prefill_attn_kvfp8_dequant( mock_kv_cache, k_scale, v_scale, + strides[0], + strides[1], + strides[2], k_cache_stride, kv_cache_stride, + head_stride, + num_kv_heads, ) return mock_kv_cache, mock_block_table -- GitLab From 0a0a1a198be88e1782b52fa31738896468200a76 Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim <62023335+kyuyeunk@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:04:15 -0700 Subject: [PATCH 015/223] Add ability to replace oot ops when using lora (#37181) Signed-off-by: Kyuyeun Kim --- vllm/lora/layers/column_parallel_linear.py | 7 ++++--- vllm/lora/layers/replicated_linear.py | 3 ++- vllm/lora/layers/row_parallel_linear.py | 3 ++- vllm/lora/layers/vocal_parallel_embedding.py | 3 ++- vllm/model_executor/custom_op.py | 5 +++-- .../layers/attention/mm_encoder_attention.py | 6 +++--- 6 files changed, 16 insertions(+), 11 deletions(-) diff --git a/vllm/lora/layers/column_parallel_linear.py 
b/vllm/lora/layers/column_parallel_linear.py index eaed6e226..f49a3fcbb 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -9,6 +9,7 @@ from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.utils import divide +from vllm.model_executor.custom_op import maybe_get_oot_by_class from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, @@ -155,9 +156,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - if type(source_layer) is ColumnParallelLinear: + if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear): return True - if type(source_layer) is MergedColumnParallelLinear: + if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear): if len(packed_modules_list) != 1: return False # Exclude layers with 3+ output sizes - those are handled by @@ -606,7 +607,7 @@ class MergedColumnParallelLinearVariableSliceWithLoRA( ) -> bool: # Support MergedColumnParallelLinear with 3 or more slices # (2 slices are handled by MergedColumnParallelLinearWithLoRA) - if type(source_layer) is not MergedColumnParallelLinear: + if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear): return False # If packed_modules_list has 3+ items, use this class diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py index 62bac546c..f1f499b84 100644 --- a/vllm/lora/layers/replicated_linear.py +++ b/vllm/lora/layers/replicated_linear.py @@ -7,6 +7,7 @@ import torch.nn as nn from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig +from vllm.model_executor.custom_op import maybe_get_oot_by_class from vllm.model_executor.layers.linear import ReplicatedLinear from .base_linear import 
BaseLinearLayerWithLoRA @@ -55,7 +56,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - return type(source_layer) is ReplicatedLinear + return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear) def slice_lora_a( self, lora_a: torch.Tensor | list[torch.Tensor | None] diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py index 8de5822db..9460b687f 100644 --- a/vllm/lora/layers/row_parallel_linear.py +++ b/vllm/lora/layers/row_parallel_linear.py @@ -11,6 +11,7 @@ from vllm.distributed import ( split_tensor_along_last_dim, tensor_model_parallel_all_reduce, ) +from vllm.model_executor.custom_op import maybe_get_oot_by_class from vllm.model_executor.layers.linear import RowParallelLinear from vllm.platforms import current_platform @@ -89,7 +90,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - return type(source_layer) is RowParallelLinear + return type(source_layer) is maybe_get_oot_by_class(RowParallelLinear) # The following layer is based on the tensor parallelism strategy given in diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index efc5a1771..05e7cfa06 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig +from vllm.model_executor.custom_op import maybe_get_oot_by_class from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.platforms import current_platform @@ -132,7 +133,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - return type(source_layer) 
is VocabParallelEmbedding + return type(source_layer) is maybe_get_oot_by_class(VocabParallelEmbedding) @property def weight(self): diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index b8e372e88..a1514c920 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -22,10 +22,11 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} -def get_oot_class_by_name(class_name: str) -> type | None: +def maybe_get_oot_by_class(class_type: type) -> type: + class_name = class_type.__name__ if class_name in op_registry_oot: return op_registry_oot[class_name] - return None + return class_type class PluggableLayer(nn.Module): diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index bc0687ed2..46d461c38 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -6,7 +6,7 @@ import numpy as np import torch from vllm.logger import init_logger -from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name +from vllm.model_executor.custom_op import CustomOp, maybe_get_oot_by_class from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.utils.math_utils import round_up from vllm.v1.attention.backends.fa_utils import get_flash_attn_version @@ -125,7 +125,7 @@ class MMEncoderAttention(CustomOp): cu_seqlens: np.ndarray, device: torch.device, ) -> torch.Tensor | None: - if (oot_class := get_oot_class_by_name(cls.__name__)) is not None: + if (oot_class := maybe_get_oot_by_class(cls)) is not cls: return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined] if attn_backend != AttentionBackendEnum.FLASHINFER: @@ -149,7 +149,7 @@ class MMEncoderAttention(CustomOp): tp_size: int, device: 
torch.device, ) -> torch.Tensor: - if (oot_class := get_oot_class_by_name(cls.__name__)) is not None: + if (oot_class := maybe_get_oot_by_class(cls)) is not cls: return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined] attn_backend, cu_seqlens, hidden_size, tp_size, device ) -- GitLab From f04d5226f837ae76daf442a2a3f2b161c4287242 Mon Sep 17 00:00:00 2001 From: Flora Feng <4florafeng@gmail.com> Date: Mon, 16 Mar 2026 23:24:34 -0400 Subject: [PATCH 016/223] [CI] Fix flaky tool_use chat completion tests with deterministic seed (#37027) Signed-off-by: sfeng33 <4florafeng@gmail.com> --- tests/tool_use/test_chat_completions.py | 5 +++++ tests/tool_use/test_parallel_tool_calls.py | 7 +++++++ tests/tool_use/test_tool_calls.py | 5 +++++ tests/tool_use/utils.py | 2 ++ 4 files changed, 19 insertions(+) diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py index 07b7933f6..e5bb47587 100644 --- a/tests/tool_use/test_chat_completions.py +++ b/tests/tool_use/test_chat_completions.py @@ -6,6 +6,7 @@ import pytest from .utils import ( MESSAGES_WITHOUT_TOOLS, + SEED, WEATHER_TOOL, ServerConfig, ensure_system_prompt, @@ -27,6 +28,7 @@ async def test_chat_completion_without_tools( max_completion_tokens=150, model=model_name, logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason @@ -47,6 +49,7 @@ async def test_chat_completion_without_tools( max_completion_tokens=150, model=model_name, logprobs=False, + seed=SEED, stream=True, ) chunks: list[str] = [] @@ -97,6 +100,7 @@ async def test_chat_completion_with_tools( model=model_name, tools=[WEATHER_TOOL], logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason @@ -118,6 +122,7 @@ async def test_chat_completion_with_tools( model=model_name, logprobs=False, tools=[WEATHER_TOOL], + seed=SEED, stream=True, ) diff --git 
a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index 77084ec2d..ed8c80d36 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -10,6 +10,7 @@ from .utils import ( MESSAGES_ASKING_FOR_PARALLEL_TOOLS, MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, + SEED, WEATHER_TOOL, ServerConfig, ) @@ -39,6 +40,7 @@ async def test_parallel_tool_calls( model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] @@ -76,6 +78,7 @@ async def test_parallel_tool_calls( max_completion_tokens=200, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, stream=True, ) @@ -166,6 +169,7 @@ async def test_parallel_tool_calls_with_results( model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] @@ -184,6 +188,7 @@ async def test_parallel_tool_calls_with_results( model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, stream=True, ) @@ -229,6 +234,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, parallel_tool_calls=False, ) @@ -247,6 +253,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI): max_completion_tokens=200, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, parallel_tool_calls=False, stream=True, ) diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index 6614b6415..f719a886c 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -10,6 +10,7 @@ from .utils import ( MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, SEARCH_TOOL, + SEED, WEATHER_TOOL, ) @@ -27,6 +28,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, ) 
choice = chat_completion.choices[0] @@ -71,6 +73,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): max_completion_tokens=100, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, stream=True, ) @@ -154,6 +157,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] @@ -171,6 +175,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, stream=True, ) diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index de7284a30..5a03f53ec 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -42,6 +42,8 @@ def ensure_system_prompt( # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. +SEED = 42 + ARGS: list[str] = [ "--enable-auto-tool-choice", "--max-model-len", -- GitLab From 384dc7f77b61ba98555df11c122fae759d6ef97e Mon Sep 17 00:00:00 2001 From: Flora Feng <4florafeng@gmail.com> Date: Mon, 16 Mar 2026 23:31:23 -0400 Subject: [PATCH 017/223] [Refactor] Relocate completion and chat completion tests (#37125) Signed-off-by: sfeng33 <4florafeng@gmail.com> --- .../scripts/hardware_ci/run-amd-test.sh | 8 +++---- .buildkite/test-amd.yaml | 24 +++++++++---------- .buildkite/test_areas/entrypoints.yaml | 2 +- .buildkite/test_areas/model_executor.yaml | 4 ++-- .buildkite/test_areas/plugins.yaml | 2 +- .github/mergify.yml | 2 +- requirements/rocm-test.txt | 2 +- tests/distributed/test_distributed_oot.py | 4 +++- tests/entrypoints/llm/test_chat.py | 3 +-- tests/entrypoints/llm/test_mm_cache_stats.py | 3 +-- .../{ => chat_completion}/test_audio.py | 3 +-- .../test_audio_in_video.py | 4 ++-- .../test_default_mm_loras.py | 4 ++-- .../test_oot_registration.py | 2 +- .../{ => chat_completion}/test_root_path.py | 2 +- .../{ 
=> chat_completion}/test_video.py | 3 +-- .../{ => chat_completion}/test_vision.py | 3 +-- .../test_vision_embeds.py | 3 +-- .../entrypoints/openai/completion/__init__.py | 0 .../{ => completion}/test_completion_error.py | 0 .../test_completion_with_prompt_embeds.py | 2 +- .../{ => completion}/test_lora_resolvers.py | 0 .../test_prompt_validation.py | 3 +-- .../openai/{ => completion}/test_shutdown.py | 0 .../test_tensorizer_entrypoint.py | 3 +-- .../test_token_in_token_out.py | 3 +-- 26 files changed, 41 insertions(+), 48 deletions(-) rename tests/entrypoints/openai/{ => chat_completion}/test_audio.py (99%) rename tests/entrypoints/openai/{ => chat_completion}/test_audio_in_video.py (98%) rename tests/entrypoints/openai/{ => chat_completion}/test_default_mm_loras.py (97%) rename tests/entrypoints/openai/{ => chat_completion}/test_oot_registration.py (96%) rename tests/entrypoints/openai/{ => chat_completion}/test_root_path.py (98%) rename tests/entrypoints/openai/{ => chat_completion}/test_video.py (99%) rename tests/entrypoints/openai/{ => chat_completion}/test_vision.py (99%) rename tests/entrypoints/openai/{ => chat_completion}/test_vision_embeds.py (99%) create mode 100644 tests/entrypoints/openai/completion/__init__.py rename tests/entrypoints/openai/{ => completion}/test_completion_error.py (100%) rename tests/entrypoints/openai/{ => completion}/test_completion_with_prompt_embeds.py (99%) rename tests/entrypoints/openai/{ => completion}/test_lora_resolvers.py (100%) rename tests/entrypoints/openai/{ => completion}/test_prompt_validation.py (98%) rename tests/entrypoints/openai/{ => completion}/test_shutdown.py (100%) rename tests/entrypoints/openai/{ => completion}/test_tensorizer_entrypoint.py (98%) rename tests/entrypoints/openai/{ => completion}/test_token_in_token_out.py (98%) diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 1c43c404d..407e3c5a6 100755 --- 
a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -333,15 +333,15 @@ apply_rocm_test_overrides() { # --- Entrypoint ignores --- if [[ $cmds == *" entrypoints/openai "* ]]; then cmds=${cmds//" entrypoints/openai "/" entrypoints/openai \ - --ignore=entrypoints/openai/test_audio.py \ - --ignore=entrypoints/openai/test_shutdown.py \ + --ignore=entrypoints/openai/chat_completion/test_audio.py \ + --ignore=entrypoints/openai/completion/test_shutdown.py \ --ignore=entrypoints/openai/test_completion.py \ --ignore=entrypoints/openai/test_models.py \ --ignore=entrypoints/openai/test_lora_adapters.py \ --ignore=entrypoints/openai/test_return_tokens_as_ids.py \ - --ignore=entrypoints/openai/test_root_path.py \ + --ignore=entrypoints/openai/chat_completion/test_root_path.py \ --ignore=entrypoints/openai/test_tokenization.py \ - --ignore=entrypoints/openai/test_prompt_validation.py "} + --ignore=entrypoints/openai/completion/test_prompt_validation.py "} fi if [[ $cmds == *" entrypoints/llm "* ]]; then diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 7f8020540..eb331aaf9 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -162,7 +162,7 @@ steps: - tests/entrypoints/test_chat_utils commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ 
--ignore=entrypoints/openai/responses - pytest -v -s entrypoints/test_chat_utils.py - label: Entrypoints Integration Test (API Server 2) @@ -674,12 +674,12 @@ steps: - vllm/config/model.py - vllm/model_executor - tests/model_executor - - tests/entrypoints/openai/test_tensorizer_entrypoint.py + - tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py commands: - apt-get update && apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s model_executor - - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py + - pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py - label: Benchmarks # 11min timeout_in_minutes: 20 @@ -1143,7 +1143,7 @@ steps: - pytest -v -s plugins_tests/test_scheduler_plugins.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - - pytest -v -s entrypoints/openai/test_oot_registration.py + - pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py - pytest -v -s models/test_oot_registration.py - pytest -v -s plugins/lora_resolvers @@ -1502,7 +1502,7 @@ steps: - tests/entrypoints/test_chat_utils commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses - pytest -v -s entrypoints/test_chat_utils.py - label: Entrypoints 
Integration Test (API Server 2) @@ -2133,12 +2133,12 @@ steps: - vllm/config/model.py - vllm/model_executor - tests/model_executor - - tests/entrypoints/openai/test_tensorizer_entrypoint.py + - tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py commands: - apt-get update && apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s model_executor - - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py + - pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py - label: Benchmarks # 11min timeout_in_minutes: 20 @@ -2735,7 +2735,7 @@ steps: - pytest -v -s plugins_tests/test_scheduler_plugins.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins @@ -3257,7 +3257,7 @@ steps: - tests/entrypoints/test_chat_utils commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses - pytest -v -s entrypoints/test_chat_utils.py - 
label: Entrypoints Integration Test (API Server 2) @@ -3872,12 +3872,12 @@ steps: - vllm/config/model.py - vllm/model_executor - tests/model_executor - - tests/entrypoints/openai/test_tensorizer_entrypoint.py + - tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py commands: - apt-get update && apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s model_executor - - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py + - pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py - label: Benchmarks # 11min timeout_in_minutes: 20 @@ -4508,7 +4508,7 @@ steps: - pytest -v -s plugins_tests/test_scheduler_plugins.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins diff --git a/.buildkite/test_areas/entrypoints.yaml b/.buildkite/test_areas/entrypoints.yaml index 9de9c3fd2..ac6be8e14 100644 --- a/.buildkite/test_areas/entrypoints.yaml +++ b/.buildkite/test_areas/entrypoints.yaml @@ -34,7 +34,7 @@ steps: - tests/entrypoints/test_chat_utils commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py 
--ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses - pytest -v -s entrypoints/test_chat_utils.py mirror: amd: diff --git a/.buildkite/test_areas/model_executor.yaml b/.buildkite/test_areas/model_executor.yaml index 996c8bb8b..496ecca39 100644 --- a/.buildkite/test_areas/model_executor.yaml +++ b/.buildkite/test_areas/model_executor.yaml @@ -9,9 +9,9 @@ steps: - vllm/config/model.py - vllm/model_executor - tests/model_executor - - tests/entrypoints/openai/test_tensorizer_entrypoint.py + - tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py commands: - apt-get update && apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s model_executor - - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py + - pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py diff --git a/.buildkite/test_areas/plugins.yaml b/.buildkite/test_areas/plugins.yaml index 7e7727fce..8e0eb0284 100644 --- a/.buildkite/test_areas/plugins.yaml +++ b/.buildkite/test_areas/plugins.yaml @@ -36,6 +36,6 @@ steps: - pytest -v -s plugins_tests/test_scheduler_plugins.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins diff --git a/.github/mergify.yml b/.github/mergify.yml index c6d1f1fed..8e9cb790b 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -381,7 +381,7 @@ pull_request_rules: - or: - 
files~=^vllm/model_executor/model_loader/tensorizer.py - files~=^vllm/model_executor/model_loader/tensorizer_loader.py - - files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py + - files~=^tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py - files~=^tests/model_executor/model_loader/tensorizer_loader/ actions: assign: diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index 9014ab1ea..9a7bd9f59 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -50,7 +50,7 @@ av==16.1.0 blobfile==3.0.0 # Multi-Modal Models Test decord==0.6.0 - # video processing, required by entrypoints/openai/test_video.py + # video processing, required by entrypoints/openai/chat_completion/test_video.py rapidfuzz==3.12.1 # OpenAI compatibility and testing diff --git a/tests/distributed/test_distributed_oot.py b/tests/distributed/test_distributed_oot.py index ea7a88abd..9bd7603e7 100644 --- a/tests/distributed/test_distributed_oot.py +++ b/tests/distributed/test_distributed_oot.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server +from tests.entrypoints.openai.chat_completion.test_oot_registration import ( + run_and_test_dummy_opt_api_server, +) def test_distributed_oot(dummy_opt_path: str): diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 20ed73e26..7d8a09852 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -4,12 +4,11 @@ import weakref import pytest +from tests.entrypoints.openai.chat_completion.test_vision import TEST_IMAGE_ASSETS from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory from vllm.sampling_params import SamplingParams -from ..openai.test_vision import TEST_IMAGE_ASSETS - @pytest.fixture(scope="function") def text_llm(): diff --git 
a/tests/entrypoints/llm/test_mm_cache_stats.py b/tests/entrypoints/llm/test_mm_cache_stats.py index e5ee99124..62c6aa9f7 100644 --- a/tests/entrypoints/llm/test_mm_cache_stats.py +++ b/tests/entrypoints/llm/test_mm_cache_stats.py @@ -6,13 +6,12 @@ import logging import pytest import regex as re +from tests.entrypoints.openai.chat_completion.test_vision import TEST_IMAGE_ASSETS from vllm import LLM from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.v1.metrics import loggers as stat_loggers from vllm.v1.metrics.reader import Counter, Metric -from ..openai.test_vision import TEST_IMAGE_ASSETS - def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]: return [ diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/chat_completion/test_audio.py similarity index 99% rename from tests/entrypoints/openai/test_audio.py rename to tests/entrypoints/openai/chat_completion/test_audio.py index 9fe1d906d..fa0f141af 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/chat_completion/test_audio.py @@ -7,11 +7,10 @@ import openai import pytest import pytest_asyncio +from tests.utils import RemoteOpenAIServer from vllm.assets.audio import AudioAsset from vllm.multimodal.utils import encode_audio_base64, encode_audio_url, fetch_audio -from ...utils import RemoteOpenAIServer - MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" TEST_AUDIO_URLS = [ AudioAsset("winning_call").url, diff --git a/tests/entrypoints/openai/test_audio_in_video.py b/tests/entrypoints/openai/chat_completion/test_audio_in_video.py similarity index 98% rename from tests/entrypoints/openai/test_audio_in_video.py rename to tests/entrypoints/openai/chat_completion/test_audio_in_video.py index 334d9a71e..769390309 100644 --- a/tests/entrypoints/openai/test_audio_in_video.py +++ b/tests/entrypoints/openai/chat_completion/test_audio_in_video.py @@ -8,8 +8,8 @@ import openai import pytest import pytest_asyncio -from ...conftest 
import VideoTestAssets -from ...utils import RemoteOpenAIServer +from tests.conftest import VideoTestAssets +from tests.utils import RemoteOpenAIServer MODEL_NAME = "Qwen/Qwen2.5-Omni-3B" diff --git a/tests/entrypoints/openai/test_default_mm_loras.py b/tests/entrypoints/openai/chat_completion/test_default_mm_loras.py similarity index 97% rename from tests/entrypoints/openai/test_default_mm_loras.py rename to tests/entrypoints/openai/chat_completion/test_default_mm_loras.py index dd8f9d67d..e285c8d31 100644 --- a/tests/entrypoints/openai/test_default_mm_loras.py +++ b/tests/entrypoints/openai/chat_completion/test_default_mm_loras.py @@ -8,8 +8,8 @@ import pytest import pytest_asyncio from huggingface_hub import snapshot_download -from ...conftest import AudioTestAssets -from ...utils import RemoteOpenAIServer +from tests.conftest import AudioTestAssets +from tests.utils import RemoteOpenAIServer # NOTE - the tests in this module are currently analogous to test_chat, but are # separated to avoid OOM killing due to module-scoped servers, since we diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/chat_completion/test_oot_registration.py similarity index 96% rename from tests/entrypoints/openai/test_oot_registration.py rename to tests/entrypoints/openai/chat_completion/test_oot_registration.py index ba463be1d..151373d82 100644 --- a/tests/entrypoints/openai/test_oot_registration.py +++ b/tests/entrypoints/openai/chat_completion/test_oot_registration.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from ...utils import VLLM_PATH, RemoteOpenAIServer +from tests.utils import VLLM_PATH, RemoteOpenAIServer chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" assert chatml_jinja_path.exists() diff --git a/tests/entrypoints/openai/test_root_path.py b/tests/entrypoints/openai/chat_completion/test_root_path.py similarity index 98% rename from 
tests/entrypoints/openai/test_root_path.py rename to tests/entrypoints/openai/chat_completion/test_root_path.py index 6bcb80878..9b3f30255 100644 --- a/tests/entrypoints/openai/test_root_path.py +++ b/tests/entrypoints/openai/chat_completion/test_root_path.py @@ -8,7 +8,7 @@ from typing import Any, NamedTuple import openai # use the official client for correctness check import pytest -from ...utils import RemoteOpenAIServer +from tests.utils import RemoteOpenAIServer # # any model with a chat template should work here MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/chat_completion/test_video.py similarity index 99% rename from tests/entrypoints/openai/test_video.py rename to tests/entrypoints/openai/chat_completion/test_video.py index 47450c30b..a5827c9f9 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/chat_completion/test_video.py @@ -7,11 +7,10 @@ import openai import pytest import pytest_asyncio +from tests.utils import RemoteOpenAIServer from vllm.multimodal.utils import encode_video_url, fetch_video from vllm.platforms import current_platform -from ...utils import RemoteOpenAIServer - MODEL_NAME = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" MAXIMUM_VIDEOS = 3 diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/chat_completion/test_vision.py similarity index 99% rename from tests/entrypoints/openai/test_vision.py rename to tests/entrypoints/openai/chat_completion/test_vision.py index c0d8b0532..6cb843342 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/chat_completion/test_vision.py @@ -8,12 +8,11 @@ import pytest import pytest_asyncio from transformers import AutoProcessor +from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer from vllm.multimodal.media import MediaWithBytes from vllm.multimodal.utils import encode_image_url, fetch_image from vllm.platforms import 
current_platform -from ...utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer - MODEL_NAME = "microsoft/Phi-3.5-vision-instruct" MAXIMUM_IMAGES = 2 diff --git a/tests/entrypoints/openai/test_vision_embeds.py b/tests/entrypoints/openai/chat_completion/test_vision_embeds.py similarity index 99% rename from tests/entrypoints/openai/test_vision_embeds.py rename to tests/entrypoints/openai/chat_completion/test_vision_embeds.py index b3da30102..82cb84bcc 100644 --- a/tests/entrypoints/openai/test_vision_embeds.py +++ b/tests/entrypoints/openai/chat_completion/test_vision_embeds.py @@ -8,10 +8,9 @@ import pytest import requests import torch +from tests.utils import RemoteOpenAIServer from vllm.utils.serial_utils import tensor2base64 -from ...utils import RemoteOpenAIServer - @pytest.mark.parametrize( "model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] diff --git a/tests/entrypoints/openai/completion/__init__.py b/tests/entrypoints/openai/completion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/entrypoints/openai/test_completion_error.py b/tests/entrypoints/openai/completion/test_completion_error.py similarity index 100% rename from tests/entrypoints/openai/test_completion_error.py rename to tests/entrypoints/openai/completion/test_completion_error.py diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/completion/test_completion_with_prompt_embeds.py similarity index 99% rename from tests/entrypoints/openai/test_completion_with_prompt_embeds.py rename to tests/entrypoints/openai/completion/test_completion_with_prompt_embeds.py index f8a19e40b..374e77245 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/completion/test_completion_with_prompt_embeds.py @@ -14,7 +14,7 @@ import torch from openai import BadRequestError from transformers import AutoConfig -from ...utils import RemoteOpenAIServer +from 
tests.utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "facebook/opt-125m" diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/completion/test_lora_resolvers.py similarity index 100% rename from tests/entrypoints/openai/test_lora_resolvers.py rename to tests/entrypoints/openai/completion/test_lora_resolvers.py diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/completion/test_prompt_validation.py similarity index 98% rename from tests/entrypoints/openai/test_prompt_validation.py rename to tests/entrypoints/openai/completion/test_prompt_validation.py index 5aff3b3c7..f44d13c55 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/completion/test_prompt_validation.py @@ -11,11 +11,10 @@ import pytest import regex as re import torch +from tests.utils import RemoteOpenAIServer from vllm.config import ModelConfig from vllm.renderers.embed_utils import safe_load_prompt_embeds -from ...utils import RemoteOpenAIServer - @pytest.mark.asyncio async def test_empty_prompt(): diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/completion/test_shutdown.py similarity index 100% rename from tests/entrypoints/openai/test_shutdown.py rename to tests/entrypoints/openai/completion/test_shutdown.py diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py similarity index 98% rename from tests/entrypoints/openai/test_tensorizer_entrypoint.py rename to tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py index 9ac9106db..29c0c2dc8 100644 --- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py +++ b/tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py @@ -9,6 +9,7 @@ import pytest import pytest_asyncio import torch.cuda +from tests.utils import RemoteOpenAIServer from vllm.engine.arg_utils 
import EngineArgs from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, @@ -17,8 +18,6 @@ from vllm.model_executor.model_loader.tensorizer import ( ) from vllm.platforms import current_platform -from ...utils import RemoteOpenAIServer - MODEL_NAME = "unsloth/llama-3.2-1b-Instruct" LORA_PATH = "davzoku/finqa_adapter_1b" diff --git a/tests/entrypoints/openai/test_token_in_token_out.py b/tests/entrypoints/openai/completion/test_token_in_token_out.py similarity index 98% rename from tests/entrypoints/openai/test_token_in_token_out.py rename to tests/entrypoints/openai/completion/test_token_in_token_out.py index c7f8abe27..8882ae624 100644 --- a/tests/entrypoints/openai/test_token_in_token_out.py +++ b/tests/entrypoints/openai/completion/test_token_in_token_out.py @@ -6,11 +6,10 @@ import tempfile import pytest +from tests.utils import RemoteOpenAIServer from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf from vllm.tokenizers import get_tokenizer -from ...utils import RemoteOpenAIServer - MODEL_NAME = "Qwen/Qwen3-0.6B" MODEL_PATH = os.path.join(tempfile.gettempdir(), "qwen3_06b") -- GitLab From 54a62a79f70982742a227c845b96148e6401d0e7 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Mon, 16 Mar 2026 22:34:49 -0500 Subject: [PATCH 018/223] [ROCm] Fix AttributeError for torch.compiler.skip_all_guards_unsafe on older PyTorch (#37219) Signed-off-by: Andreas Karatzas --- vllm/compilation/wrapper.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index ce85bae53..f5e62402a 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -112,7 +112,12 @@ class TorchCompileWithNoGuardsWrapper: entry.guard_type == "SHAPE_ENV" for entry in x ] else: - options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe + if hasattr(torch.compiler, "skip_all_guards_unsafe"): + # Torch 2.10+ provides skip_all_guards_unsafe + 
options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe + else: + # Equivalent fallback for older PyTorch: skip all guards + options["guard_filter_fn"] = lambda x: [False for _ in x] compiled_ptr: Any = self.forward # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False -- GitLab From 3e3d320c1b367264f654204da42aeaf478cf3972 Mon Sep 17 00:00:00 2001 From: Flora Feng <4florafeng@gmail.com> Date: Tue, 17 Mar 2026 01:14:52 -0400 Subject: [PATCH 019/223] [Refactor] Relocate responses API tests (#37241) Signed-off-by: sfeng33 <4florafeng@gmail.com> --- .../entrypoints/openai/responses/conftest.py | 38 ++++++++++++++++ .../openai/responses}/test_basic.py | 0 .../openai/responses}/test_function_call.py | 0 .../openai/responses/test_harmony.py | 3 +- .../openai/responses}/test_image.py | 0 .../openai/responses/test_mcp_tools.py | 2 +- .../openai/responses/test_parsable_context.py | 3 +- .../openai/responses/test_simple.py | 3 +- .../openai/responses}/test_stateful.py | 0 .../responses}/test_structured_output.py | 0 .../openai/serving_responses/__init__.py | 0 .../openai/serving_responses/conftest.py | 44 ------------------- 12 files changed, 45 insertions(+), 48 deletions(-) rename tests/{v1/entrypoints/openai/serving_responses => entrypoints/openai/responses}/test_basic.py (100%) rename tests/{v1/entrypoints/openai/serving_responses => entrypoints/openai/responses}/test_function_call.py (100%) rename tests/{v1/entrypoints/openai/serving_responses => entrypoints/openai/responses}/test_image.py (100%) rename tests/{v1/entrypoints/openai/serving_responses => entrypoints/openai/responses}/test_stateful.py (100%) rename tests/{v1/entrypoints/openai/serving_responses => entrypoints/openai/responses}/test_structured_output.py (100%) delete mode 100644 tests/v1/entrypoints/openai/serving_responses/__init__.py delete mode 100644 tests/v1/entrypoints/openai/serving_responses/conftest.py diff --git a/tests/entrypoints/openai/responses/conftest.py 
b/tests/entrypoints/openai/responses/conftest.py index 3d300849e..68fdbbba3 100644 --- a/tests/entrypoints/openai/responses/conftest.py +++ b/tests/entrypoints/openai/responses/conftest.py @@ -8,6 +8,9 @@ from collections.abc import Callable from typing import Any import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer logger = logging.getLogger(__name__) @@ -361,3 +364,38 @@ def log_response_diagnostics( ) return diagnostics + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + "--max-model-len", + "8192", + "--enforce-eager", # For faster startup. + "--enable-auto-tool-choice", + "--structured-outputs-config.backend", + "xgrammar", + "--tool-call-parser", + "hermes", + "--reasoning-parser", + "qwen3", + ] + + +@pytest.fixture(scope="module") +def server_with_store(default_server_args): + with RemoteOpenAIServer( + "Qwen/Qwen3-1.7B", + default_server_args, + env_dict={ + "VLLM_ENABLE_RESPONSES_API_STORE": "1", + "VLLM_SERVER_DEV_MODE": "1", + }, + ) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server_with_store): + async with server_with_store.get_async_client() as async_client: + yield async_client diff --git a/tests/v1/entrypoints/openai/serving_responses/test_basic.py b/tests/entrypoints/openai/responses/test_basic.py similarity index 100% rename from tests/v1/entrypoints/openai/serving_responses/test_basic.py rename to tests/entrypoints/openai/responses/test_basic.py diff --git a/tests/v1/entrypoints/openai/serving_responses/test_function_call.py b/tests/entrypoints/openai/responses/test_function_call.py similarity index 100% rename from tests/v1/entrypoints/openai/serving_responses/test_function_call.py rename to tests/entrypoints/openai/responses/test_function_call.py diff --git a/tests/entrypoints/openai/responses/test_harmony.py b/tests/entrypoints/openai/responses/test_harmony.py index 3bc041ba4..74f3360df 100644 --- 
a/tests/entrypoints/openai/responses/test_harmony.py +++ b/tests/entrypoints/openai/responses/test_harmony.py @@ -16,7 +16,8 @@ import requests from openai import InternalServerError, NotFoundError, OpenAI from openai_harmony import Message -from ....utils import RemoteOpenAIServer +from tests.utils import RemoteOpenAIServer + from .conftest import ( BASE_TEST_ENV, events_contain_type, diff --git a/tests/v1/entrypoints/openai/serving_responses/test_image.py b/tests/entrypoints/openai/responses/test_image.py similarity index 100% rename from tests/v1/entrypoints/openai/serving_responses/test_image.py rename to tests/entrypoints/openai/responses/test_image.py diff --git a/tests/entrypoints/openai/responses/test_mcp_tools.py b/tests/entrypoints/openai/responses/test_mcp_tools.py index 55445f188..eb3c5becc 100644 --- a/tests/entrypoints/openai/responses/test_mcp_tools.py +++ b/tests/entrypoints/openai/responses/test_mcp_tools.py @@ -9,9 +9,9 @@ import pytest_asyncio from openai import OpenAI from openai_harmony import ToolDescription, ToolNamespaceConfig +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.mcp.tool_server import MCPToolServer -from ....utils import RemoteOpenAIServer from .conftest import ( BASE_TEST_ENV, events_contain_type, diff --git a/tests/entrypoints/openai/responses/test_parsable_context.py b/tests/entrypoints/openai/responses/test_parsable_context.py index 280bacf47..292edda9a 100644 --- a/tests/entrypoints/openai/responses/test_parsable_context.py +++ b/tests/entrypoints/openai/responses/test_parsable_context.py @@ -9,7 +9,8 @@ import pytest import pytest_asyncio from openai import OpenAI -from ....utils import RemoteOpenAIServer +from tests.utils import RemoteOpenAIServer + from .conftest import ( BASE_TEST_ENV, has_output_type, diff --git a/tests/entrypoints/openai/responses/test_simple.py b/tests/entrypoints/openai/responses/test_simple.py index 744aa068a..1f382f61b 100644 --- a/tests/entrypoints/openai/responses/test_simple.py 
+++ b/tests/entrypoints/openai/responses/test_simple.py @@ -5,7 +5,8 @@ import pytest import pytest_asyncio from openai import OpenAI -from ....utils import RemoteOpenAIServer +from tests.utils import RemoteOpenAIServer + from .conftest import validate_streaming_event_stack MODEL_NAME = "Qwen/Qwen3-8B" diff --git a/tests/v1/entrypoints/openai/serving_responses/test_stateful.py b/tests/entrypoints/openai/responses/test_stateful.py similarity index 100% rename from tests/v1/entrypoints/openai/serving_responses/test_stateful.py rename to tests/entrypoints/openai/responses/test_stateful.py diff --git a/tests/v1/entrypoints/openai/serving_responses/test_structured_output.py b/tests/entrypoints/openai/responses/test_structured_output.py similarity index 100% rename from tests/v1/entrypoints/openai/serving_responses/test_structured_output.py rename to tests/entrypoints/openai/responses/test_structured_output.py diff --git a/tests/v1/entrypoints/openai/serving_responses/__init__.py b/tests/v1/entrypoints/openai/serving_responses/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/v1/entrypoints/openai/serving_responses/conftest.py b/tests/v1/entrypoints/openai/serving_responses/conftest.py deleted file mode 100644 index b948b6d05..000000000 --- a/tests/v1/entrypoints/openai/serving_responses/conftest.py +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest -import pytest_asyncio - -from tests.utils import RemoteOpenAIServer - -# Use a small reasoning model to test the responses API. -MODEL_NAME = "Qwen/Qwen3-1.7B" - - -@pytest.fixture(scope="module") -def default_server_args(): - return [ - "--max-model-len", - "8192", - "--enforce-eager", # For faster startup. 
- "--enable-auto-tool-choice", - "--structured-outputs-config.backend", - "xgrammar", - "--tool-call-parser", - "hermes", - "--reasoning-parser", - "qwen3", - ] - - -@pytest.fixture(scope="module") -def server_with_store(default_server_args): - with RemoteOpenAIServer( - MODEL_NAME, - default_server_args, - env_dict={ - "VLLM_ENABLE_RESPONSES_API_STORE": "1", - "VLLM_SERVER_DEV_MODE": "1", - }, - ) as remote_server: - yield remote_server - - -@pytest_asyncio.fixture -async def client(server_with_store): - async with server_with_store.get_async_client() as async_client: - yield async_client -- GitLab From 17c1bdf3719d9d8fdf4f13cb1468e5ed5f70d021 Mon Sep 17 00:00:00 2001 From: PatchyTIS <58251192+PatchouliTIS@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:19:55 +0800 Subject: [PATCH 020/223] [Bugfix] dtype mismatch in ngram gpu propose (#37246) Signed-off-by: PatchouliTaisa Co-authored-by: PatchouliTaisa --- vllm/v1/spec_decode/ngram_proposer_gpu.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py index 3ff841804..eb24a9c93 100644 --- a/vllm/v1/spec_decode/ngram_proposer_gpu.py +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -364,7 +364,9 @@ class NgramProposerGPU: ) token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter) - num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count + num_tokens_tmp = (num_tokens_no_spec + valid_sampled_tokens_count).to( + torch.int32 + ) # Compute validity masks. sampled_flags = valid_sampled_tokens_count > 0 @@ -437,7 +439,7 @@ class NgramProposerGPU: ) # Count valid tokens per request. - valid_sampled_tokens_count = valid_mask.sum(dim=1) + valid_sampled_tokens_count = valid_mask.sum(dim=1).to(torch.int32) # Rightmost valid index per row. 
last_valid_indices = valid_sampled_tokens_count - 1 -- GitLab From 20b14095a4e64e0cba71a40b264d0bc96ffb9c07 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Tue, 17 Mar 2026 01:24:40 -0400 Subject: [PATCH 021/223] [Bugfix] Fix loading Music Flamingo (#35535) Signed-off-by: Nick Cao --- vllm/model_executor/models/audioflamingo3.py | 6 ------ vllm/model_executor/models/musicflamingo.py | 11 ++++++++++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/audioflamingo3.py b/vllm/model_executor/models/audioflamingo3.py index e56997fb7..1a25dca2d 100644 --- a/vllm/model_executor/models/audioflamingo3.py +++ b/vllm/model_executor/models/audioflamingo3.py @@ -128,12 +128,6 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder): super().__init__(config) self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2) # self.layer_norm is already initialized in super().__init__ - # Keep a dummy freqs parameter for MusicFlamingo checkpoints. - self.pos_emb = nn.Module() - freqs = torch.empty(getattr(config, "num_mel_bins", 128)) - self.pos_emb.register_parameter( - "freqs", nn.Parameter(freqs, requires_grad=False) - ) def forward( self, diff --git a/vllm/model_executor/models/musicflamingo.py b/vllm/model_executor/models/musicflamingo.py index 161de4e24..84328d4cd 100644 --- a/vllm/model_executor/models/musicflamingo.py +++ b/vllm/model_executor/models/musicflamingo.py @@ -21,6 +21,7 @@ from vllm.multimodal.processing import BaseProcessingInfo from .audioflamingo3 import ( AudioFlamingo3DummyInputsBuilder, AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3MultiModalDataParser, AudioFlamingo3MultiModalProcessor, ) @@ -53,8 +54,16 @@ class MusicFlamingoProcessingInfo(BaseProcessingInfo): hf_processor = self.get_hf_processor(**kwargs) return hf_processor.feature_extractor + def get_data_parser(self): + feature_extractor = self.get_feature_extractor() + + return AudioFlamingo3MultiModalDataParser( + target_sr=feature_extractor.sampling_rate, + 
expected_hidden_size=self._get_expected_hidden_size(), + ) + def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {"audio": None} + return {"audio": 1} class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder): -- GitLab From 8a680463fab3bc9e6760417cd5c0a6aa58283065 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 17 Mar 2026 02:07:33 -0400 Subject: [PATCH 022/223] [Bugfix] Fix NemotronH MTP + Chunked Prefill (#35447) --- tests/v1/e2e/test_hybrid_chunked_prefill.py | 104 ++++++++++++++++++ .../layers/mamba/ops/mamba_ssm.py | 6 +- vllm/v1/attention/backends/mamba_attn.py | 10 +- vllm/v1/worker/gpu_model_runner.py | 27 ++++- vllm/v1/worker/mamba_utils.py | 42 +++++++ 5 files changed, 181 insertions(+), 8 deletions(-) create mode 100644 tests/v1/e2e/test_hybrid_chunked_prefill.py diff --git a/tests/v1/e2e/test_hybrid_chunked_prefill.py b/tests/v1/e2e/test_hybrid_chunked_prefill.py new file mode 100644 index 000000000..030081a38 --- /dev/null +++ b/tests/v1/e2e/test_hybrid_chunked_prefill.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm import SamplingParams +from vllm.platforms import current_platform + +from ...utils import large_gpu_mark, multi_gpu_marks + +# A trivial request with a short prompt to ensure we run a mixed batch +SMALL_MESSAGE = [ + { + "role": "user", + "content": "The secret beta value is 64. What is the secret beta?", + } +] + +# Sample prompt with a bunch of filler in between the critical fact and the request. +# Both parts need to be processed properly for the model to generate the correct answer +MESSAGES = [ + { + "role": "user", + "content": ( + "Important: The secret number is 42. " + "The sky is green in this hypothetical world. " + "Apples grow on trees in the forest. " + "Rivers flow through the valleys and mountains. " + "Birds sing songs in the early morning light. 
" + "The weather today is sunny with clear skies ahead. " + "Flowers bloom in the garden during spring season. " + "Now answer with ONLY the number and nothing else: " + "What is the secret number plus one?" + ), + } +] + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") +@pytest.mark.parametrize( + "model_name", + [ + pytest.param("Qwen/Qwen3.5-4B", marks=[large_gpu_mark(min_gb=40)]), + pytest.param( + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8", + marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=2), + ), + ], +) +@pytest.mark.parametrize("enable_prefix_caching", [False, True]) +def test_mtp_speculative_mixed_batch_short_prefill( + vllm_runner, model_name, enable_prefix_caching +): + """Test to ensure MTP speculative decoding correctly handles + short prefill chunks that fall below the reorder_batch_threshold.""" + + # Set so large that both prefills will be classified as decodes in a mixed batch + # note, with prefix caching we require chunk_size >= mamba_block_size + chunk_size = 256 if not enable_prefix_caching else 16384 + num_draft_tokens = 100 + + with vllm_runner( + model_name, + speculative_config={ + "method": "mtp", + "num_speculative_tokens": num_draft_tokens, + }, + max_num_batched_tokens=chunk_size, + max_model_len=512, + enforce_eager=True, + tensor_parallel_size=2, + trust_remote_code=True, + enable_chunked_prefill=True, + enable_prefix_caching=enable_prefix_caching, + mamba_cache_mode="align" if enable_prefix_caching else "none", + ) as llm: + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=128, + ) + + # First small message gets prefilled first, under normal conditions since the + # batch is not yet mixed. 
Then the second prefill arrives as a mixed batch, but + # is shorter than num_speculative_tokens, so it gets misclassified as a decode + # and processed with the wrong state management logic, causing the critical + # fact from the first chunk to be lost and the model to generate nonsense. + outputs = llm.get_llm().chat( + [SMALL_MESSAGE, MESSAGES], + sampling_params, + chat_template_kwargs={"enable_thinking": False}, + ) + + responses = [] + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + responses.append(generated_text) + + assert "64" in responses[0], ( + "The first response should contain the correct value of 64." + ) + assert "43" in responses[1], ( + "The second response should contain the correct value of 42+1=43." + ) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 22a99596a..1cd077758 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -334,13 +334,13 @@ def selective_state_update( dt_bias = dt_bias.unsqueeze(0) if out.dim() == 2: out = out.unsqueeze(1) - if num_accepted_tokens is not None: - assert state_batch_indices is not None and state_batch_indices.dim() == 2 - assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2 if state_batch_indices is not None and state_batch_indices.dim() == 1: state_batch_indices = state_batch_indices.unsqueeze(1) if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1: dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1) + if num_accepted_tokens is not None: + assert state_batch_indices is not None and state_batch_indices.dim() == 2 + assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2 _, nheads, dim, dstate = state.shape batch = x.shape[0] diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 
0364d6aee..bdb820eac 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -414,8 +414,11 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ] state_indices_tensor_p = state_indices_tensor_p[:, 0] - if num_decodes > 0 and self.use_spec_decode: - assert num_accepted_tokens is not None + # Sometimes even with specdec enabled we get single-token prefill chunks that + # should be treated as decodes but don't have num_accepted_tokens set. + # These should be fine to process as non-spec decodes since there's only + # one token, so no risk of placing accepted tokens in the wrong slot. + if num_decodes > 0 and self.use_spec_decode and num_accepted_tokens is not None: query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1] num_accepted_tokens = num_accepted_tokens[:num_decodes] @@ -501,9 +504,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs] state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID - if self.use_spec_decode: + if self.use_spec_decode and num_accepted_tokens is not None: assert query_start_loc_d is not None - assert num_accepted_tokens is not None query_start_loc_d = query_start_loc_d[: padded_bs + 1] self.decode_num_accepted_tokens[: metadata.num_decodes].copy_( num_accepted_tokens, non_blocking=True diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 98e1dab36..22459bc49 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -739,6 +739,19 @@ class GPUModelRunner( self.uniform_decode_query_len = 1 + self.num_spec_tokens + # When spec decode is active, the mamba backend classifies requests + # with query_len <= reorder_batch_threshold as "decodes". 
Prefill + # chunks that fall under this threshold get processed via the decode + # path, which stores intermediate states at sequential slots. We must + # set num_accepted_tokens to the chunk's query_len for those requests + # so the next iteration reads from the correct final-state slot. + # Prefills that went through the actual prefill path should keep the + # default value of 1 (the prefill path stores state at slot 0 only). + self.needs_prefill_as_decode_slots: bool = False + self.prefill_as_decode_num_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) @@ -1355,12 +1368,22 @@ class GPUModelRunner( .int() .argmax(-1) ) + spec_decode_active = bool(scheduler_output.scheduled_spec_decode_tokens) + if self.needs_prefill_as_decode_slots and spec_decode_active: + mamba_utils.update_accepted_tokens_for_prefill_as_decode( + self.input_batch, + self.prefill_as_decode_num_tokens, + self.num_accepted_tokens.gpu, + scheduler_output, + self.reorder_batch_threshold, + num_reqs, + ) + if self.cache_config.mamba_cache_mode == "align": for i, num_tokens in enumerate( self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() ): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens - mamba_utils.postprocess_mamba( scheduler_output, self.kv_cache_config, @@ -2024,6 +2047,8 @@ class GPUModelRunner( else 0 ) + if isinstance(builder, Mamba2AttentionMetadataBuilder): + self.needs_prefill_as_decode_slots = True extra_attn_metadata_args = {} if use_spec_decode and isinstance( builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 2bd5d2b3f..68172133e 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -266,3 +266,45 @@ def postprocess_mamba( if src_block_idx == dest_block_idx: num_accepted_tokens_cpu[i] = 1 
do_mamba_copy_block(copy_bufs) + + +def update_accepted_tokens_for_prefill_as_decode( + input_batch: GPUInputBatch, + prefill_as_decode_num_tokens: CpuGpuBuffer, + num_accepted_tokens_gpu: torch.Tensor, + scheduler_output: SchedulerOutput, + decode_qlen_threshold: int | None, + num_reqs: int, +): + """ + Adjusts num_accepted_tokens for prefill chunks processed via the decode path. + This ensures subsequent iterations read from the correct sequential state slot + instead of the default prefill slot 0. Not used by GDN attention, which manually + separates short prefills and short decodes when building the attention metadata. + """ + any_is_prefill = False + for i in range(num_reqs): + num_computed = input_batch.num_computed_tokens_cpu[i] + num_prompt = input_batch.num_prompt_tokens[i] + is_prefill = num_computed < num_prompt + req_id = input_batch.req_ids[i] + query_len = scheduler_output.num_scheduled_tokens[req_id] + + if is_prefill: + classified_as_decode = ( + decode_qlen_threshold is not None and query_len <= decode_qlen_threshold + ) + num_tokens = query_len if classified_as_decode else 1 + any_is_prefill = True + else: + num_tokens = -1 + prefill_as_decode_num_tokens.np[i] = num_tokens + + # We can skip the GPU transfer if there aren't any values to update + if any_is_prefill: + prefill_as_decode_num_tokens.copy_to_gpu(num_reqs) + num_accepted_tokens_gpu[:num_reqs] = torch.where( + prefill_as_decode_num_tokens.gpu[:num_reqs] != -1, + prefill_as_decode_num_tokens.gpu[:num_reqs], + num_accepted_tokens_gpu[:num_reqs], + ) -- GitLab From 24b4272a8ca6a793b80568486060547b5b392433 Mon Sep 17 00:00:00 2001 From: xiao-llm Date: Tue, 17 Mar 2026 03:19:15 -0400 Subject: [PATCH 023/223] Fix infinite recursive search issue in quark.py (#32779) Signed-off-by: Yanwen Lin Signed-off-by: Xiao Yu Signed-off-by: kimheesu Co-authored-by: Yanwen Lin Co-authored-by: Kim Hee Su --- .../layers/quantization/quark/quark.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 
deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 1ca28fbf0..78c64bac6 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -467,10 +467,17 @@ class QuarkConfig(QuantizationConfig): layer_name.replace(proj_name, shard_proj_name) for shard_proj_name in shard_proj_names ] - shard_configs = [ - self._find_matched_config(shard_name, module) - for shard_name in shard_names - ] + + shard_configs = [] + for shard_name in shard_names: + if shard_name == layer_name: + config = cast( + dict[str, Any], self.quant_config.get("global_quant_config") + ) + else: + config = self._find_matched_config(shard_name, module) + shard_configs.append(config) + if not all( deep_compare(q_config, shard_configs[0]) for q_config in shard_configs ): -- GitLab From 132bfd45b691fedc45a8d9851a25c7776144d9e0 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Tue, 17 Mar 2026 16:54:52 +0800 Subject: [PATCH 024/223] [Bugfix][ResponsesAPI] Fix crash when tool_choice=required exceeds max_output_tokens (#37258) Signed-off-by: chaunceyjiang --- .../openai/responses/test_function_call.py | 28 +++++++++++++++++++ vllm/parser/abstract_parser.py | 23 +++++++++------ 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/tests/entrypoints/openai/responses/test_function_call.py b/tests/entrypoints/openai/responses/test_function_call.py index 0b8a2e649..36627f92d 100644 --- a/tests/entrypoints/openai/responses/test_function_call.py +++ b/tests/entrypoints/openai/responses/test_function_call.py @@ -134,6 +134,34 @@ async def test_function_tool_use( assert reasoning.type == "reasoning" +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_max_tokens_with_tool_choice_required( + client: openai.AsyncOpenAI, model_name: str +): + prompt = [ + { + "role": "user", + "content": "Can you tell me what the 
current weather is in Berlin and the " + "forecast for the next 5 days, in fahrenheit?", + }, + ] + response = await client.responses.create( + model=model_name, + input=prompt, + tools=tools, + tool_choice="required", + max_output_tokens=10, + ) + assert len(response.output) >= 1 + for out in response.output: + # When `tool_choice="required"` and the tokens of `tools` + # exceed `max_output_tokens`,`function_call` should be empty. + # This behavior should be consistent with OpenAI + assert out.type != "function_call" + assert response.incomplete_details.reason == "max_output_tokens" + + @pytest.mark.asyncio async def test_named_tool_use(client: openai.AsyncOpenAI): def get_weather(latitude: float, longitude: float) -> str: diff --git a/vllm/parser/abstract_parser.py b/vllm/parser/abstract_parser.py index aa145bab2..0c1dda17b 100644 --- a/vllm/parser/abstract_parser.py +++ b/vllm/parser/abstract_parser.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import json from abc import abstractmethod from collections.abc import Sequence @@ -18,7 +19,7 @@ from openai.types.responses.response_output_text import Logprob from openai.types.responses.response_reasoning_item import ( Content as ResponseReasoningTextContent, ) -from pydantic import TypeAdapter +from pydantic import TypeAdapter, ValidationError from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.chat_completion.protocol import ( @@ -422,15 +423,19 @@ class DelegatingParser(Parser): if request.tool_choice == "required": # Required tool calls - parse JSON - assert content is not None - tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) - function_calls.extend( - FunctionCall( - name=tool_call.name, - arguments=json.dumps(tool_call.parameters, ensure_ascii=False), + tool_calls = [] + with contextlib.suppress(ValidationError): + content = content or "" + tool_calls = 
TypeAdapter(list[FunctionDefinition]).validate_json( + content + ) + for tool_call in tool_calls: + function_calls.append( + FunctionCall( + name=tool_call.name, + arguments=json.dumps(tool_call.parameters, ensure_ascii=False), + ) ) - for tool_call in tool_calls - ) return function_calls, None # Clear content since tool is called. if ( -- GitLab From 9c7cab5ebb0f8a15e632e7ea2cfeebcca1d3628f Mon Sep 17 00:00:00 2001 From: Augusto Yao Date: Tue, 17 Mar 2026 17:05:42 +0800 Subject: [PATCH 025/223] [Feature]: Support for multiple embedding types in a single inference call (#35829) Signed-off-by: augusto.yjh --- .../sparse_embeddings_processor.py | 124 +++++++++++++++--- .../bge_m3_sparse_processor/types.py | 35 ++++- ...test_bge_m3_sparse_io_processor_plugins.py | 25 +++- vllm/model_executor/layers/pooler/special.py | 40 +++++- vllm/model_executor/models/roberta.py | 26 ++-- vllm/pooling_params.py | 4 + vllm/tasks.py | 8 +- 7 files changed, 226 insertions(+), 36 deletions(-) diff --git a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py index 4749d3e81..b97f7de13 100644 --- a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py +++ b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py @@ -3,10 +3,10 @@ from collections.abc import Sequence -from vllm.config import VllmConfig +from vllm.config import ModelConfig, PoolerConfig, VllmConfig from vllm.entrypoints.openai.engine.protocol import UsageInfo +from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin from vllm.inputs.data import PromptType -from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput from vllm.plugins.io_processors.interface import ( IOProcessor, @@ -16,14 +16,13 @@ from vllm.renderers import BaseRenderer from vllm.tokenizers.detokenizer_utils import 
convert_ids_list_to_tokens from .types import ( + EMBED_TASKS, SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse, SparseEmbeddingResponseData, SparseEmbeddingTokenWeight, ) -logger = init_logger(__name__) - class BgeM3SparseEmbeddingsProcessor( IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse] @@ -33,6 +32,22 @@ class BgeM3SparseEmbeddingsProcessor( self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = [] self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {} self.renderer: BaseRenderer = renderer + self.default_pooling_params = {} + pooler_config: PoolerConfig = vllm_config.model_config.pooler_config + if pooler_config is not None: + for param in ["use_activation", "dimensions"]: + if getattr(pooler_config, param, None) is None: + continue + self.default_pooling_params[param] = getattr(pooler_config, param) + self.embed_dimensions = vllm_config.model_config.embedding_size + self.embed_request_queue: list[EmbedRequestMixin] = [] + + def __repr__(self) -> str: + return ( + f"BgeM3SparseEmbeddingsProcessor(" + f"embed_dimensions={self.embed_dimensions}, " + f"default_pooling_params={self.default_pooling_params})" + ) def merge_pooling_params( self, @@ -41,7 +56,57 @@ class BgeM3SparseEmbeddingsProcessor( if params is None: params = PoolingParams() # refer to PoolingCompletionRequest.to_pooling_params - params.task = "token_classify" + # set and verify pooling params + params.skip_reading_prefix_cache = True + + raw_embed_request = self.embed_request_queue.pop(0) + if raw_embed_request.embed_task not in EMBED_TASKS: + raise ValueError( + f"Unsupported task {raw_embed_request}, " + f"Supported tasks are {EMBED_TASKS}" + ) + has_dense_embed = True + if raw_embed_request.embed_task == "dense": + params.task = "embed" + params.skip_reading_prefix_cache = False + elif raw_embed_request.embed_task == "sparse": + params.task = "token_classify" + has_dense_embed = False + else: + params.task = 
"embed&token_classify" + params.use_activation = raw_embed_request.use_activation + if params.use_activation is None: + params.use_activation = True + if not has_dense_embed: + params.dimensions = None + return params + + params.dimensions = raw_embed_request.dimensions + + model_config: ModelConfig = self.vllm_config.model_config + for param in self.default_pooling_params: + if getattr(params, param, None) is None: + setattr(params, param, self.default_pooling_params[param]) + + if params.dimensions is not None: + if not model_config.is_matryoshka: + raise ValueError( + f'Model "{model_config.served_model_name}" does not ' + f"support matryoshka representation, " + f"changing output dimensions will lead to poor results." + ) + + mds = model_config.matryoshka_dimensions + if mds is not None: + if params.dimensions not in mds: + raise ValueError( + f"Model {model_config.served_model_name!r} " + f"only supports {str(mds)} matryoshka dimensions, " + f"use other output dimensions will " + f"lead to poor results." 
+ ) + elif params.dimensions < 1: + raise ValueError("Dimensions must be greater than 0") return params def parse_request( @@ -61,14 +126,16 @@ class BgeM3SparseEmbeddingsProcessor( if request_id is not None: assert request_id not in self.online_requests, "request_id duplicated" self.online_requests[request_id] = prompt + self.embed_request_queue.extend(prompt.to_embed_requests_online()) else: self.offline_requests.append(prompt) + self.embed_request_queue.extend(prompt.to_embed_requests_offline()) return prompt.input def _get_sparse_embedding_request(self, request_id: str | None = None): if request_id: return self.online_requests.pop(request_id, None) - return self.offline_requests.pop() + return self.offline_requests.pop(0) def _build_sparse_embedding_token_weights( self, @@ -100,26 +167,45 @@ class BgeM3SparseEmbeddingsProcessor( ) -> SparseEmbeddingResponse: num_prompt_tokens = 0 response_data = [] - return_tokens = self._get_sparse_embedding_request(request_id).return_tokens + raw_request = self._get_sparse_embedding_request(request_id) + has_dense_embed = raw_request.embed_task in ["dense", "dense&sparse"] + has_sparse_embed = raw_request.embed_task in ["sparse", "dense&sparse"] + embed_dimensions = 0 + if has_dense_embed: + embed_dimensions = ( + self.embed_dimensions + if raw_request.dimensions is None + else raw_request.dimensions + ) for idx in range(len(model_output)): mo = model_output[idx] - sparse_embedding: dict[int, float] = {} + sparse_embedding_dict: dict[int, float] = {} num_prompt_tokens += len(mo.prompt_token_ids) - if len(mo.prompt_token_ids) != len(mo.outputs.data): - # this is the case that add_special_tokens is True, - # which means first token and last token are special tokens - mo.prompt_token_ids = mo.prompt_token_ids[1:] - for token_id, weight in zip(mo.prompt_token_ids, mo.outputs.data.tolist()): - sparse_embedding[token_id] = max( - weight, sparse_embedding.get(token_id, 0.0) + dense_embedding: list[float] | None = None + 
sparse_embedding: list[SparseEmbeddingTokenWeight] | None = None + if has_dense_embed: + dense_embedding = mo.outputs.data[:embed_dimensions].tolist() + if has_sparse_embed: + sparse_weights = mo.outputs.data[embed_dimensions:].tolist() + if len(mo.prompt_token_ids) != len(sparse_weights): + # this is the case that add_special_tokens is True, + # which means first token and last token are special tokens + mo.prompt_token_ids = mo.prompt_token_ids[1:] + for token_id, weight in zip(mo.prompt_token_ids, sparse_weights): + sparse_embedding_dict[token_id] = max( + weight, sparse_embedding_dict.get(token_id, 0.0) + ) + sparse_embedding = self._build_sparse_embedding_token_weights( + sparse_embedding_dict, + raw_request.return_tokens, ) + response_data.append( SparseEmbeddingResponseData( index=idx, - sparse_embedding=self._build_sparse_embedding_token_weights( - sparse_embedding, - return_tokens, - ), + object=raw_request.embed_task, + sparse_embedding=sparse_embedding, + dense_embedding=dense_embedding, ) ) diff --git a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/types.py b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/types.py index 1dcf30a05..ba69932f4 100644 --- a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/types.py +++ b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/types.py @@ -1,18 +1,44 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Literal, get_args + from pydantic import BaseModel, Field from vllm.entrypoints.openai.engine.protocol import UsageInfo -from vllm.entrypoints.pooling.base.protocol import CompletionRequestMixin +from vllm.entrypoints.pooling.base.protocol import ( + CompletionRequestMixin, + EmbedRequestMixin, +) + +EmbedTask = Literal[ + "sparse", + "dense", + "dense&sparse", +] + +EMBED_TASKS: tuple[EmbedTask, ...] 
= get_args(EmbedTask) -class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin): +class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin, EmbedRequestMixin): return_tokens: bool | None = Field( default=None, description="Whether to return dict shows the mapping of token_id to text." "`None` or False means not return.", ) + embed_task: EmbedTask = Field( + default="dense&sparse", + description="embed task, can be one of 'sparse', 'dense' , 'dense&sparse', " + "default to 'dense&sparse'", + ) + + def to_embed_requests_offline(self) -> list[EmbedRequestMixin]: + if isinstance(self.input, list): + return [self] * len(self.input) + return [self] + + def to_embed_requests_online(self) -> list[EmbedRequestMixin]: + return [self] class SparseEmbeddingTokenWeight(BaseModel): @@ -23,8 +49,9 @@ class SparseEmbeddingTokenWeight(BaseModel): class SparseEmbeddingResponseData(BaseModel): index: int - object: str = "sparse-embedding" - sparse_embedding: list[SparseEmbeddingTokenWeight] + object: str = "dense&sparse" + sparse_embedding: list[SparseEmbeddingTokenWeight] | None + dense_embedding: list[float] | None class SparseEmbeddingResponse(BaseModel): diff --git a/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py b/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py index 20c400e59..85293e55c 100644 --- a/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py +++ b/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py @@ -19,6 +19,12 @@ model_config = { ), } +dense_embedding_sum = [ + -0.7214539647102356, # "What is the capital of France?" + -0.6926871538162231, # "What is the capital of Germany?" + -0.7129564881324768, # "What is the capital of Spain?" 
+] + def _float_close(expected: object, result: object): assert isinstance(expected, float) and isinstance(result, float), ( @@ -33,6 +39,12 @@ def _get_attr_or_val(obj: object | dict, key: str): return getattr(obj, key, None) +def _check_dense_embedding(data, index=0): + assert _float_close(sum(data), dense_embedding_sum[index]), ( + "dense-embedding result not match" + ) + + def _check_sparse_embedding(data, check_tokens=False): expected_weights = [ {"token_id": 32, "weight": 0.0552978515625, "token": "?"}, @@ -109,7 +121,7 @@ async def test_bge_m3_sparse_plugin_online( assert len(_get_attr_or_val(parsed_response, "data")) > 0 data_entry = _get_attr_or_val(parsed_response, "data")[0] - assert _get_attr_or_val(data_entry, "object") == "sparse-embedding" + assert _get_attr_or_val(data_entry, "object") == "dense&sparse" assert _get_attr_or_val(data_entry, "sparse_embedding") # Verify sparse embedding format @@ -117,6 +129,11 @@ async def test_bge_m3_sparse_plugin_online( assert isinstance(sparse_embedding, list) _check_sparse_embedding(sparse_embedding, return_tokens) + # Verify dense embedding format + dense_embedding = _get_attr_or_val(data_entry, "dense_embedding") + assert isinstance(dense_embedding, list) + _check_dense_embedding(dense_embedding) + # Verify usage information usage = _get_attr_or_val(parsed_response, "usage") assert usage, f"usage not found for {parsed_response}" @@ -164,6 +181,9 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool): sparse_embedding = output.sparse_embedding assert isinstance(sparse_embedding, list) _check_sparse_embedding(sparse_embedding, return_tokens) + dense_embedding = output.dense_embedding + assert isinstance(dense_embedding, list) + _check_dense_embedding(dense_embedding) # Verify usage assert response.usage.prompt_tokens > 0 @@ -206,6 +226,9 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner): # Each output should have sparse embeddings sparse_embedding = output.sparse_embedding 
assert isinstance(sparse_embedding, list) + dense_embedding = output.dense_embedding + assert isinstance(dense_embedding, list) + _check_dense_embedding(dense_embedding, i) # Verify usage assert response.usage.prompt_tokens > 0 diff --git a/vllm/model_executor/layers/pooler/special.py b/vllm/model_executor/layers/pooler/special.py index bafa191db..5e0f9ec75 100644 --- a/vllm/model_executor/layers/pooler/special.py +++ b/vllm/model_executor/layers/pooler/special.py @@ -170,4 +170,42 @@ class BOSEOSFilter(Pooler): return pooled_outputs -__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"] +class BgeM3Pooler(Pooler): + def __init__(self, token_classify_pooler: Pooler, embed_pooler: Pooler) -> None: + super().__init__() + self.token_classify_pooler = token_classify_pooler + self.embed_pooler = embed_pooler + + def forward( + self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata + ) -> PoolerOutput: + embed_outputs = self.embed_pooler(hidden_states, pooling_metadata) + token_classify_outputs = self.token_classify_pooler( + hidden_states, pooling_metadata + ) + pooler_outputs: list[torch.Tensor] = [] + for embed_output, token_classify_output in zip( + embed_outputs, token_classify_outputs + ): + pooler_outputs.append( + torch.cat( + [embed_output.view(-1), token_classify_output.view(-1)], dim=-1 + ) + ) + + return pooler_outputs + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"embed&token_classify"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.embed_pooler.get_pooling_updates( + "embed" + ) | self.token_classify_pooler.get_pooling_updates("token_classify") + + def extra_repr(self) -> str: + s = f"supported_task={self.get_supported_tasks()}" + return s + + +__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler", "BgeM3Pooler"] diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 5faa64654..46211e6ed 100644 --- 
a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -10,6 +10,7 @@ from transformers import RobertaConfig from vllm.config import ModelConfig, PoolerConfig, VllmConfig from vllm.model_executor.layers.pooler import ( + BgeM3Pooler, BOSEOSFilter, DispatchPooler, Pooler, @@ -216,24 +217,29 @@ class BgeM3EmbeddingModel(RobertaEmbeddingModel): self.colbert_linear = nn.Linear( self.hidden_size, self.hidden_size, dtype=self.head_dtype ) + embed_pooler = pooler_for_embed(pooler_config) + token_classify_pooler = BOSEOSFilter( + pooler_for_token_classify( + pooler_config, + pooling=AllPool(), + classifier=self.sparse_linear, + act_fn=torch.relu, + ), + self.bos_token_id, + self.eos_token_id, + ) return DispatchPooler( { - "embed": pooler_for_embed(pooler_config), + "embed": embed_pooler, "token_embed": BOSEOSFilter( pooler_for_token_embed(pooler_config, self.colbert_linear), self.bos_token_id, # for some reason m3 only filters the bos for colbert vectors ), - "token_classify": BOSEOSFilter( - pooler_for_token_classify( - pooler_config, - pooling=AllPool(), - classifier=self.sparse_linear, - act_fn=torch.relu, - ), - self.bos_token_id, - self.eos_token_id, + "token_classify": token_classify_pooler, + "embed&token_classify": BgeM3Pooler( + token_classify_pooler, embed_pooler ), } ) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 6b85506ab..e5e993b75 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -96,6 +96,10 @@ class PoolingParams( self.skip_reading_prefix_cache = True return + # skipping verify, let plugins configure and validate pooling params + if self.task not in self.valid_parameters: + return + # NOTE: Task validation needs to done against the model instance, # which is not available in model config. 
So, it's not included # in this method diff --git a/vllm/tasks.py b/vllm/tasks.py index 950993279..83dd7f85e 100644 --- a/vllm/tasks.py +++ b/vllm/tasks.py @@ -6,7 +6,13 @@ GenerationTask = Literal["generate", "transcription", "realtime"] GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask) PoolingTask = Literal[ - "embed", "classify", "score", "token_embed", "token_classify", "plugin" + "embed", + "classify", + "score", + "token_embed", + "token_classify", + "plugin", + "embed&token_classify", ] POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask) -- GitLab From 4af9ed21cba9e4bb85cd7cc124aa6f23cd0ae9a5 Mon Sep 17 00:00:00 2001 From: "zhao, zhenhui" Date: Tue, 17 Mar 2026 19:14:07 +0800 Subject: [PATCH 026/223] =?UTF-8?q?[Bugfix](xpu):=20prevent=20=E2=80=9Csel?= =?UTF-8?q?ected=20index=20k=20out=20of=20range=E2=80=9D=20in=20TP=20decod?= =?UTF-8?q?e=20path=20(#37259)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: zhenzhao --- vllm/_xpu_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index 91f5e0290..a2eb5ff3a 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -426,7 +426,8 @@ class xpu_ops: mask = positions <= index_end_pos # mask: [B * N, L] logits = logits.masked_fill(~mask, float("-inf")) - topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K] + real_topk = min(topk_tokens, logits.shape[-1]) + topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32) # [B * N, K] # ensure we don't set indices for the top k # that is out of range(masked already) # this will happen if context length is shorter than K -- GitLab From 00f8e0d2113098b5fd37c8c24ba594fa4268ccc3 Mon Sep 17 00:00:00 2001 From: Sage <80211083+sagearc@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:22:54 +0200 Subject: [PATCH 027/223] [Frontend] Delegate tokenization serving preprocessing to OpenAIServingRender 
(#37266) Signed-off-by: Sage Ahrac --- .../openai/chat_completion/test_chat_error.py | 2 +- vllm/entrypoints/openai/api_server.py | 19 +++++++++++++++++ .../entrypoints/openai/generate/api_router.py | 21 +------------------ vllm/entrypoints/serve/render/serving.py | 12 +++++------ vllm/entrypoints/serve/tokenize/serving.py | 9 +++++--- 5 files changed, 33 insertions(+), 30 deletions(-) diff --git a/tests/entrypoints/openai/chat_completion/test_chat_error.py b/tests/entrypoints/openai/chat_completion/test_chat_error.py index 073976563..5fd7bc09c 100644 --- a/tests/entrypoints/openai/chat_completion/test_chat_error.py +++ b/tests/entrypoints/openai/chat_completion/test_chat_error.py @@ -111,7 +111,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: [{"prompt_token_ids": [1, 2, 3]}], ) - serving_chat.openai_serving_render._preprocess_chat = AsyncMock( + serving_chat.openai_serving_render.preprocess_chat = AsyncMock( side_effect=_fake_preprocess_chat ) return serving_chat diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 126e2b402..39e9076a7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -46,6 +46,7 @@ from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap from vllm.entrypoints.serve.elastic_ep.middleware import ( ScalingMiddleware, ) +from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization from vllm.entrypoints.utils import ( cli_env_setup, @@ -365,9 +366,27 @@ async def init_app_state( lora_modules=lora_modules, ) await state.openai_serving_models.init_static_loras() + + state.openai_serving_render = OpenAIServingRender( + model_config=engine_client.model_config, + renderer=engine_client.renderer, + io_processor=engine_client.io_processor, + model_registry=state.openai_serving_models.registry, + request_logger=request_logger, + 
chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + enable_auto_tools=args.enable_auto_tool_choice, + exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, + tool_parser=args.tool_call_parser, + default_chat_template_kwargs=args.default_chat_template_kwargs, + log_error_stack=args.log_error_stack, + ) + state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, state.openai_serving_models, + state.openai_serving_render, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, diff --git a/vllm/entrypoints/openai/generate/api_router.py b/vllm/entrypoints/openai/generate/api_router.py index 88a059661..bda83fbe0 100644 --- a/vllm/entrypoints/openai/generate/api_router.py +++ b/vllm/entrypoints/openai/generate/api_router.py @@ -74,26 +74,7 @@ async def init_generate_state( # Render endpoints are always backed by OpenAIServingRender so that # /v1/chat/completions/render and /v1/completions/render work on both - # generate-mode and render-only servers. - # It is created first so that OpenAIServingChat and OpenAIServingCompletion - # can delegate their preprocessing logic to it. 
- from vllm.entrypoints.serve.render.serving import OpenAIServingRender - - state.openai_serving_render = OpenAIServingRender( - model_config=engine_client.model_config, - renderer=engine_client.renderer, - io_processor=engine_client.io_processor, - model_registry=state.openai_serving_models.registry, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - trust_request_chat_template=args.trust_request_chat_template, - enable_auto_tools=args.enable_auto_tool_choice, - exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, - tool_parser=args.tool_call_parser, - default_chat_template_kwargs=args.default_chat_template_kwargs, - log_error_stack=args.log_error_stack, - ) + # generate-mode and render-only servers. Created in init_app_state. state.openai_serving_responses = ( OpenAIServingResponses( diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 9dc410c9e..c54852fca 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -226,7 +226,7 @@ class OpenAIServingRender: if not self.use_harmony: # Common case. - error_check_ret = self._validate_chat_template( + error_check_ret = self.validate_chat_template( request_chat_template=request.chat_template, chat_template_kwargs=request.chat_template_kwargs, trust_request_chat_template=self.trust_request_chat_template, @@ -234,7 +234,7 @@ class OpenAIServingRender: if error_check_ret is not None: return error_check_ret - conversation, engine_prompts = await self._preprocess_chat( + conversation, engine_prompts = await self.preprocess_chat( request, request.messages, default_template=self.chat_template, @@ -328,7 +328,7 @@ class OpenAIServingRender: "prompt_logprobs is not compatible with prompt embeds." 
) - engine_prompts = await self._preprocess_completion( + engine_prompts = await self.preprocess_completion( request, prompt_input=request.prompt, prompt_embeds=request.prompt_embeds, @@ -426,7 +426,7 @@ class OpenAIServingRender: ) -> ErrorResponse | None: return await self.model_registry.check_model(request.model) - def _validate_chat_template( + def validate_chat_template( self, request_chat_template: str | None, chat_template_kwargs: dict[str, Any] | None, @@ -447,7 +447,7 @@ class OpenAIServingRender: ) return None - async def _preprocess_completion( + async def preprocess_completion( self, request: Any, prompt_input: str | list[str] | list[int] | list[list[int]] | None, @@ -490,7 +490,7 @@ class OpenAIServingRender: }, ) - async def _preprocess_chat( + async def preprocess_chat( self, request: Any, messages: list[Any], diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index 233674aff..d68651da8 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -11,6 +11,7 @@ from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.models.serving import OpenAIServingModels +from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.tokenize.protocol import ( DetokenizeRequest, DetokenizeResponse, @@ -31,6 +32,7 @@ class OpenAIServingTokenization(OpenAIServing): self, engine_client: EngineClient, models: OpenAIServingModels, + openai_serving_render: OpenAIServingRender, *, request_logger: RequestLogger | None, chat_template: str | None, @@ -44,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing): request_logger=request_logger, ) + self.openai_serving_render = openai_serving_render self.chat_template = chat_template self.chat_template_content_format: Final = 
chat_template_content_format self.default_chat_template_kwargs = default_chat_template_kwargs or {} @@ -68,7 +71,7 @@ class OpenAIServingTokenization(OpenAIServing): if request.tools is None else [tool.model_dump() for tool in request.tools] ) - error_check_ret = self._validate_chat_template( + error_check_ret = self.openai_serving_render.validate_chat_template( request_chat_template=request.chat_template, chat_template_kwargs=request.chat_template_kwargs, trust_request_chat_template=self.trust_request_chat_template, @@ -76,7 +79,7 @@ class OpenAIServingTokenization(OpenAIServing): if error_check_ret is not None: return error_check_ret - _, engine_prompts = await self._preprocess_chat( + _, engine_prompts = await self.openai_serving_render.preprocess_chat( request, request.messages, default_template=self.chat_template, @@ -85,7 +88,7 @@ class OpenAIServingTokenization(OpenAIServing): tool_dicts=tool_dicts, ) else: - engine_prompts = await self._preprocess_completion( + engine_prompts = await self.openai_serving_render.preprocess_completion( request, prompt_input=request.prompt, prompt_embeds=None, -- GitLab From 0fb142a454757ec2055000ca8a2607e797af3e71 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 17 Mar 2026 19:59:35 +0800 Subject: [PATCH 028/223] [perf][connector] optimize build_connector_meta when host buffer transfer is not used (#37165) Signed-off-by: youkaichao --- .../kv_connector/v1/nixl_connector.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7651bf988..9001e3181 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -815,20 +815,12 @@ class NixlConnectorScheduler: # Only trigger 1 KV transfer per request. 
params["do_remote_prefill"] = False - def build_connector_meta( + def _build_save_meta( self, + meta: NixlConnectorMetadata, scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: - meta = NixlConnectorMetadata() - - # Loop through scheduled reqs and convert to ReqMeta. - for req_id, (req, block_ids) in self._reqs_need_recv.items(): - assert req.kv_transfer_params is not None - meta.add_new_req_to_recv( - request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - ) + ) -> None: + # only called when use_host_buffer is True to build the save metadata # NOTE: For the prefill side, there might be a chance that an early added # request is a chunked prefill, so we need to check if new blocks are added @@ -858,6 +850,24 @@ class NixlConnectorScheduler: # Therefore, only pop if `not is_partial`. self._reqs_need_save.pop(req_id) + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. 
+ for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req_to_recv( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + if self.use_host_buffer: + self._build_save_meta(meta, scheduler_output) + meta.reqs_to_send = self._reqs_need_send meta.reqs_in_batch = self._reqs_in_batch meta.reqs_not_processed = self._reqs_not_processed -- GitLab From 293f036e6d83ba05236d948e9800bc6d4d58a727 Mon Sep 17 00:00:00 2001 From: Viacheslav Date: Tue, 17 Mar 2026 15:03:20 +0300 Subject: [PATCH 029/223] Add gigachat 3.1 tool parser + fix gigachat3 tool parser (#36664) Signed-off-by: Viacheslav Barinov --- .../test_gigachat3_tool_parser.py | 219 +++++++++++++++--- vllm/tool_parsers/gigachat3_tool_parser.py | 143 +++++++----- 2 files changed, 274 insertions(+), 88 deletions(-) diff --git a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py index 99ab1e497..f29f79f72 100644 --- a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py @@ -13,6 +13,13 @@ from vllm.entrypoints.openai.engine.protocol import FunctionCall from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser, ToolParserManager +MSG_SEP_TOKEN = "<|message_sep|>\n\n" +ROLE_SEP_TOKEN = "<|role_sep|>\n" +EOS_TOKEN = "" +TOOL_HEADER_GIGACHAT3 = f"function call{ROLE_SEP_TOKEN}" +TOOL_HEADER_GIGACHAT31 = "<|function_call|>" + + SIMPLE_ARGS_DICT = { "action": "create", "id": "preferences", @@ -24,7 +31,10 @@ SIMPLE_FUNCTION_JSON = json.dumps( }, ensure_ascii=False, ) -SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON +SIMPLE_FUNCTION_OUTPUT_GIGACHAT3 = ( + f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{SIMPLE_FUNCTION_JSON}" +) +SIMPLE_FUNCTION_OUTPUT_GIGACHAT31 = 
f"{TOOL_HEADER_GIGACHAT31}{SIMPLE_FUNCTION_JSON}" SIMPLE_FUNCTION_CALL = FunctionCall( name="manage_user_memory", arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False), @@ -38,7 +48,12 @@ PARAMETERLESS_FUNCTION_JSON = json.dumps( }, ensure_ascii=False, ) -PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON +PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3 = ( + f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{PARAMETERLESS_FUNCTION_JSON}" +) +PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31 = ( + f"{TOOL_HEADER_GIGACHAT31}{PARAMETERLESS_FUNCTION_JSON}" +) PARAMETERLESS_FUNCTION_CALL = FunctionCall( name="manage_user_memory", arguments=json.dumps({}, ensure_ascii=False), @@ -62,17 +77,38 @@ COMPLEX_FUNCTION_JSON = json.dumps( }, ensure_ascii=False, ) -COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON +COMPLEX_FUNCTION_OUTPUT_GIGACHAT3 = ( + f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{COMPLEX_FUNCTION_JSON}" +) +COMPLEX_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{COMPLEX_FUNCTION_JSON}" COMPLEX_FUNCTION_CALL = FunctionCall( name="manage_user_memory", arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False), ) +CONTENT_TEXT = "I'll check that for you." +MIXED_OUTPUT_GIGACHAT3 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT3}" +MIXED_OUTPUT_GIGACHAT31 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT31}" + + +@pytest.fixture(name="gigachat_tokenizer") +def fixture_gigachat_tokenizer(default_tokenizer: TokenizerLike): + default_tokenizer.add_tokens( + [ + MSG_SEP_TOKEN, + ROLE_SEP_TOKEN, + TOOL_HEADER_GIGACHAT31, + EOS_TOKEN, + ] + ) + return default_tokenizer + + @pytest.mark.parametrize("streaming", [True, False]) -def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike): +def test_no_tool_call(streaming: bool, gigachat_tokenizer: TokenizerLike): tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( - default_tokenizer + gigachat_tokenizer ) model_output = "How can I help you today?" 
content, tool_calls = run_tool_extraction( @@ -85,45 +121,143 @@ def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike): TEST_CASES = [ pytest.param( True, - SIMPLE_FUNCTION_OUTPUT, + SIMPLE_FUNCTION_OUTPUT_GIGACHAT3, + [SIMPLE_FUNCTION_CALL], + None, + id="simple_streaming_gigachat3", + ), + pytest.param( + False, + SIMPLE_FUNCTION_OUTPUT_GIGACHAT3, + [SIMPLE_FUNCTION_CALL], + None, + id="simple_nonstreaming_gigachat3", + ), + pytest.param( + True, + PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3, + [PARAMETERLESS_FUNCTION_CALL], + None, + id="parameterless_streaming_gigachat3", + ), + pytest.param( + False, + PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3, + [PARAMETERLESS_FUNCTION_CALL], + None, + id="parameterless_nonstreaming_gigachat3", + ), + pytest.param( + True, + COMPLEX_FUNCTION_OUTPUT_GIGACHAT3, + [COMPLEX_FUNCTION_CALL], + None, + id="complex_streaming_gigachat3", + ), + pytest.param( + False, + COMPLEX_FUNCTION_OUTPUT_GIGACHAT3, + [COMPLEX_FUNCTION_CALL], + None, + id="complex_nonstreaming_gigachat3", + ), + pytest.param( + True, + MIXED_OUTPUT_GIGACHAT3, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_streaming_gigachat3", + ), + pytest.param( + False, + MIXED_OUTPUT_GIGACHAT3, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_nonstreaming_gigachat3", + ), + pytest.param( + True, + MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_streaming_with_eos_gigachat3", + ), + pytest.param( + False, + MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_nonstreaming_with_eos_gigachat3", + ), + pytest.param( + True, + SIMPLE_FUNCTION_OUTPUT_GIGACHAT31, [SIMPLE_FUNCTION_CALL], None, - id="simple_streaming", + id="simple_streaming_gigachat31", ), pytest.param( False, - SIMPLE_FUNCTION_OUTPUT, + SIMPLE_FUNCTION_OUTPUT_GIGACHAT31, [SIMPLE_FUNCTION_CALL], None, - id="simple_nonstreaming", + id="simple_nonstreaming_gigachat31", ), pytest.param( 
True, - PARAMETERLESS_FUNCTION_OUTPUT, + PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31, [PARAMETERLESS_FUNCTION_CALL], None, - id="parameterless_streaming", + id="parameterless_streaming_gigachat31", ), pytest.param( False, - PARAMETERLESS_FUNCTION_OUTPUT, + PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31, [PARAMETERLESS_FUNCTION_CALL], None, - id="parameterless_nonstreaming", + id="parameterless_nonstreaming_gigachat31", ), pytest.param( True, - COMPLEX_FUNCTION_OUTPUT, + COMPLEX_FUNCTION_OUTPUT_GIGACHAT31, [COMPLEX_FUNCTION_CALL], None, - id="complex_streaming", + id="complex_streaming_gigachat31", ), pytest.param( False, - COMPLEX_FUNCTION_OUTPUT, + COMPLEX_FUNCTION_OUTPUT_GIGACHAT31, [COMPLEX_FUNCTION_CALL], None, - id="complex_nonstreaming", + id="complex_nonstreaming_gigachat31", + ), + pytest.param( + True, + MIXED_OUTPUT_GIGACHAT31, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_streaming_gigachat31", + ), + pytest.param( + False, + MIXED_OUTPUT_GIGACHAT31, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_nonstreaming_gigachat31", + ), + pytest.param( + True, + MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_streaming_with_eos_gigachat31", + ), + pytest.param( + False, + MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_nonstreaming_with_eos_gigachat31", ), ] @@ -136,14 +270,16 @@ def test_tool_call( model_output: str, expected_tool_calls: list[FunctionCall], expected_content: str | None, - default_tokenizer: TokenizerLike, + gigachat_tokenizer: TokenizerLike, ): tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( - default_tokenizer + gigachat_tokenizer ) content, tool_calls = run_tool_extraction( tool_parser, model_output, streaming=streaming ) + if content == "": + content = None assert content == expected_content assert len(tool_calls) == len(expected_tool_calls) for actual, expected in zip(tool_calls, 
expected_tool_calls): @@ -154,15 +290,46 @@ def test_tool_call( assert actual_args == expected_args -def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike): +@pytest.mark.parametrize( + "model_output_deltas", + [ + pytest.param( + [ + CONTENT_TEXT[:3], + CONTENT_TEXT[3:5], + CONTENT_TEXT[5:], + MSG_SEP_TOKEN, + TOOL_HEADER_GIGACHAT3, + COMPLEX_FUNCTION_JSON[:40], + COMPLEX_FUNCTION_JSON[40:-1], + COMPLEX_FUNCTION_JSON[-1], + ], + id="gigachat3", + ), + pytest.param( + [ + CONTENT_TEXT[:3], + CONTENT_TEXT[3:5], + CONTENT_TEXT[5:], + TOOL_HEADER_GIGACHAT31, + COMPLEX_FUNCTION_JSON[:40], + COMPLEX_FUNCTION_JSON[40:-1], + COMPLEX_FUNCTION_JSON[-1], + ], + id="gigachat31", + ), + ], +) +def test_streaming_tool_call_with_large_steps( + model_output_deltas: list[str], + gigachat_tokenizer: TokenizerLike, +): + """ + Test that the closing braces are streamed correctly. + """ tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( - default_tokenizer + gigachat_tokenizer ) - model_output_deltas = [ - "function call", - COMPLEX_FUNCTION_JSON[:40], - COMPLEX_FUNCTION_JSON[40:], - ] reconstructor = run_tool_extraction_streaming( tool_parser, model_output_deltas, diff --git a/vllm/tool_parsers/gigachat3_tool_parser.py b/vllm/tool_parsers/gigachat3_tool_parser.py index 02cdad9ed..90928f9ae 100644 --- a/vllm/tool_parsers/gigachat3_tool_parser.py +++ b/vllm/tool_parsers/gigachat3_tool_parser.py @@ -25,7 +25,12 @@ from vllm.tool_parsers.abstract_tool_parser import ToolParser logger = init_logger(__name__) REGEX_FUNCTION_CALL = re.compile( - r"function call(?:<\|role_sep\|>\n)?(\{.*)", + r"(?:function call<\|role_sep\|>\n|<\|function_call\|>)(.*)", + re.DOTALL, +) + +REGEX_CONTENT_PATTERN = re.compile( + r"^(.*?)(?:<\|message_sep\|>|<\|function_call\|>)", re.DOTALL, ) @@ -47,57 +52,67 @@ class GigaChat3ToolParser(ToolParser): self.tool_name_sent: bool = False self.tool_id: str | None = None self.prev_tool_call_arr: list[dict] = [] - 
self.content_buffer: str = "" - self.trigger_start = "function call{" + self.end_content: bool = False + self.streamed_args_for_tool: list[str] = [] + + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) + if request.tools and request.tool_choice != "none": + request.skip_special_tokens = False + return request def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - match = REGEX_FUNCTION_CALL.search(model_output) - if not match: - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output, - ) - json_candidate = match.group(1).strip() - try: - data = json.loads(json_candidate) - except json.JSONDecodeError: - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output, - ) - if not (isinstance(data, dict) and "name" in data and "arguments" in data): + function_call = None + content = None + if model_output.rstrip().endswith(""): + model_output = model_output[: model_output.rfind("")] + m_func = REGEX_FUNCTION_CALL.search(model_output) + if m_func: + try: + function_call = json.loads(m_func.group(1), strict=False) + if ( + isinstance(function_call, dict) + and "name" in function_call + and "arguments" in function_call + ): + if not isinstance(function_call["arguments"], dict): + function_call = None + else: + function_call = None + except json.JSONDecodeError: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + m_content = REGEX_CONTENT_PATTERN.search(model_output) + content = m_content.group(1) if m_content else model_output + if not function_call: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], - content=model_output, + content=content if content else None, ) - name = data["name"] - args = data["arguments"] + name = function_call["name"] + args = 
function_call["arguments"] if not isinstance(args, str): - args = json.dumps(args, ensure_ascii=False) - - tool_calls = [ - ToolCall( - type="function", - function=FunctionCall( - name=name, - arguments=args, - ), - ) - ] - prefix = model_output[: match.start()] - content = prefix.rstrip() if prefix and prefix.strip() else None - + args = json.dumps(function_call["arguments"], ensure_ascii=False) return ExtractedToolCallInformation( tools_called=True, - tool_calls=tool_calls, - content=content, + tool_calls=[ + ToolCall( + type="function", + function=FunctionCall( + name=name, + arguments=args, + ), + ) + ], + content=content if content else None, ) def extract_tool_calls_streaming( @@ -110,39 +125,37 @@ class GigaChat3ToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: + content = None func_name = None cur_args = None + m_func = REGEX_FUNCTION_CALL.search(current_text) if not self.tool_started: - match = REGEX_FUNCTION_CALL.search(current_text) - if match: - self.tool_started = True - self.content_buffer = "" + m_content = REGEX_CONTENT_PATTERN.search(delta_text) + if m_content: + content = m_content.group(1) + self.end_content = True else: - self.content_buffer += delta_text - clean_buffer = self.content_buffer.lstrip() - is_prefix = self.trigger_start.startswith(clean_buffer) - starts_with_trigger = clean_buffer.startswith(self.trigger_start) - if is_prefix or starts_with_trigger: - return None - else: - flush_text = self.content_buffer - self.content_buffer = "" - return DeltaMessage(content=flush_text) - - match = REGEX_FUNCTION_CALL.search(current_text) - if not match: + if not self.end_content: + content = delta_text + if m_func: + self.tool_started = True + if content: + return DeltaMessage(content=content) + if not m_func: return None - json_tail = match.group(1).strip() + json_tail = m_func.group(1).strip() name_match = NAME_REGEX.search(json_tail) if name_match: func_name = name_match.group(1) 
args_match = ARGS_REGEX.search(json_tail) if args_match: cur_args = args_match.group(1).strip() + if cur_args.endswith(""): + cur_args = cur_args[: -len("")] if cur_args.endswith("}"): # last '}' end of json try: candidate = cur_args[:-1].strip() - json.loads(candidate) + json.loads(candidate, strict=False) cur_args = candidate except json.JSONDecodeError: pass @@ -165,11 +178,10 @@ class GigaChat3ToolParser(ToolParser): ).model_dump(exclude_none=True), ) ], - content=None, ) if cur_args is None: return None - prev_args = self.prev_tool_call_arr[0].get("arguments", "") + prev_args = self.prev_tool_call_arr[0].get("arguments_str", "") if not prev_args: delta_args = cur_args elif cur_args.startswith(prev_args): @@ -178,7 +190,15 @@ class GigaChat3ToolParser(ToolParser): return None if not delta_args: return None - self.prev_tool_call_arr[0]["arguments"] = cur_args + self.prev_tool_call_arr[0]["arguments_str"] = cur_args + try: + args_dict = json.loads(cur_args, strict=False) + self.prev_tool_call_arr[0]["arguments"] = args_dict + except json.JSONDecodeError: + self.prev_tool_call_arr[0]["arguments"] = {} + if len(self.streamed_args_for_tool) <= 0: + self.streamed_args_for_tool.append("") + self.streamed_args_for_tool[0] = cur_args return DeltaMessage( tool_calls=[ DeltaToolCall( @@ -188,5 +208,4 @@ class GigaChat3ToolParser(ToolParser): ).model_dump(exclude_none=True), ) ], - content=None, ) -- GitLab From 2660b9289c1f9e26ae65a247ceac2b9add52fa90 Mon Sep 17 00:00:00 2001 From: sfbemerk Date: Tue, 17 Mar 2026 14:22:09 +0100 Subject: [PATCH 030/223] Bugfix for offloading+prefetch for GLM-4.7-FP8 (#37178) Signed-off-by: Benjamin Merkel Co-authored-by: Benjamin Merkel --- vllm/model_executor/offloader/prefetch.py | 43 ++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/offloader/prefetch.py b/vllm/model_executor/offloader/prefetch.py index b43cb8b7d..5bdde8c3a 100644 --- a/vllm/model_executor/offloader/prefetch.py +++ 
b/vllm/model_executor/offloader/prefetch.py @@ -431,10 +431,32 @@ class _ModuleOffloader: Called after process_weights_after_loading to ensure _cpu_storage contains the final processed weights, not stale pre-loading data. + + Parameters whose underlying nn.Parameter was deleted by + process_weights_after_loading (e.g. transient KV-cache scale params) + are pruned from self._param_offloaders so they do not participate in + buffer-pool allocation or prefetching. """ for param_offloader in self._param_offloaders.values(): param_offloader.sync_cpu_storage() + # Remove offloaders whose parameter was deleted during + # process_weights_after_loading (e.g. k_scale / v_scale). + deleted = [ + name + for name, offloader in self._param_offloaders.items() + if getattr(offloader, "_param_deleted", False) + ] + if deleted: + logger.debug( + "Pruning %d transient offloaded param(s) that were deleted " + "by process_weights_after_loading: %s", + len(deleted), + deleted, + ) + for name in deleted: + del self._param_offloaders[name] + def get_param_infos(self) -> list[ParamInfo]: """Get parameter metadata for buffer pool allocation. @@ -590,6 +612,11 @@ class _CpuParamOffloader(_BaseParamOffloader): super().__init__(module, param_name) self._cpu_storage: torch.Tensor | None = None self._gpu_buffer: torch.Tensor | None = None # Store reference to GPU buffer + # Set to True if the underlying nn.Parameter was deleted by + # process_weights_after_loading (e.g. transient KV-cache scale params + # such as k_scale/v_scale created by BaseKVCacheMethod.create_weights + # and deleted after copying into permanent _k_scale buffers). + self._param_deleted: bool = False # Offload to CPU immediately to free GPU memory during model loading self._offload_to_cpu_internal() @@ -696,8 +723,22 @@ class _CpuParamOffloader(_BaseParamOffloader): 1. process_weights_after_loading may transform weights (quantization) 2. device_loading_context creates NEW CPU tensors when moving back 3. 
Our old _cpu_storage would have pre-processed or stale data + + If the parameter no longer exists on the module (e.g. transient + KV-cache scale parameters such as k_scale/v_scale that are created + by BaseKVCacheMethod.create_weights() and then deleted by + process_weights_after_loading() after copying their values into + permanent _k_scale buffers), the offloader marks itself as deleted + and skips the sync. The caller (_ModuleOffloader.sync_cpu_storage) + is responsible for removing these stale entries. """ - self._update_cpu_storage_from_param() + try: + self._update_cpu_storage_from_param() + except AttributeError: + # The parameter was deleted by process_weights_after_loading. + # Drop the now-stale CPU storage so this offloader can be pruned. + self._param_deleted = True + self._cpu_storage = None def post_init(self): """No-op: offloading done in offload_to_cpu/assign_static_buffer.""" -- GitLab From f34032433573cda9bc495cf02e783c8b0d99d20d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 17 Mar 2026 21:50:56 +0800 Subject: [PATCH 031/223] [1/2] Move InternVL-based processors (#37260) Signed-off-by: DarkLight1337 --- .../multimodal/processing/test_h2ovl.py | 2 +- .../multimodal/processing/test_internvl.py | 2 +- .../multimodal/processing/test_nemotron_vl.py | 2 +- vllm/model_executor/models/eagle2_5_vl.py | 82 +- vllm/model_executor/models/h2ovl.py | 375 +----- vllm/model_executor/models/internvl.py | 585 +--------- .../model_executor/models/nano_nemotron_vl.py | 1033 +---------------- vllm/model_executor/models/nemotron_parse.py | 233 +--- vllm/model_executor/models/nemotron_vl.py | 408 +------ vllm/model_executor/models/nvlm_d.py | 34 +- vllm/model_executor/models/skyworkr1v.py | 379 +----- .../transformers_utils/processors/__init__.py | 18 + .../processors/eagle2_5_vl.py | 85 ++ vllm/transformers_utils/processors/h2ovl.py | 390 +++++++ .../transformers_utils/processors/internvl.py | 603 ++++++++++ .../processors/nano_nemotron_vl.py | 1032 
++++++++++++++++ .../processors/nemotron_parse.py | 245 ++++ .../processors/nemotron_vl.py | 410 +++++++ vllm/transformers_utils/processors/nvlm_d.py | 44 + .../processors/skyworkr1v.py | 389 +++++++ 20 files changed, 3252 insertions(+), 3099 deletions(-) create mode 100644 vllm/transformers_utils/processors/eagle2_5_vl.py create mode 100644 vllm/transformers_utils/processors/h2ovl.py create mode 100644 vllm/transformers_utils/processors/internvl.py create mode 100644 vllm/transformers_utils/processors/nano_nemotron_vl.py create mode 100644 vllm/transformers_utils/processors/nemotron_parse.py create mode 100644 vllm/transformers_utils/processors/nemotron_vl.py create mode 100644 vllm/transformers_utils/processors/nvlm_d.py create mode 100644 vllm/transformers_utils/processors/skyworkr1v.py diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 19e4cb896..3ba256f3c 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -23,7 +23,7 @@ def _get_expected_num_patches( min_num: int, max_num: int, ): - from vllm.model_executor.models.h2ovl import ( + from vllm.transformers_utils.processors.h2ovl import ( calculate_h2ovl_targets, get_h2ovl_target_ratios, ) diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index 437c7b682..7954dd6b5 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -23,7 +23,7 @@ def _get_expected_num_patches( min_num: int, max_num: int, ): - from vllm.model_executor.models.internvl import ( + from vllm.transformers_utils.processors.internvl import ( calculate_internvl_targets, get_internvl_target_ratios, ) diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py index d9e635dde..be5c222fd 100644 --- 
a/tests/models/multimodal/processing/test_nemotron_vl.py +++ b/tests/models/multimodal/processing/test_nemotron_vl.py @@ -23,7 +23,7 @@ def _get_expected_num_patches( min_num: int, max_num: int, ): - from vllm.model_executor.models.nemotron_vl import ( + from vllm.transformers_utils.processors.nemotron_vl import ( calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios, ) diff --git a/vllm/model_executor/models/eagle2_5_vl.py b/vllm/model_executor/models/eagle2_5_vl.py index 718e8bb54..3e6182db5 100644 --- a/vllm/model_executor/models/eagle2_5_vl.py +++ b/vllm/model_executor/models/eagle2_5_vl.py @@ -15,9 +15,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.siglip import SiglipVisionModel from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.processing import PromptUpdateDetails from vllm.sequence import IntermediateTensors -from vllm.tokenizers import TokenizerLike +from vllm.transformers_utils.processors.eagle2_5_vl import Eagle2_5_VLProcessor from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( @@ -27,13 +26,9 @@ from .interfaces import ( SupportsPP, ) from .internvl import ( - IMG_CONTEXT, - IMG_END, - IMG_START, BaseInternVLDummyInputsBuilder, BaseInternVLMultiModalProcessor, BaseInternVLProcessingInfo, - BaseInternVLProcessor, ) from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix @@ -70,81 +65,6 @@ Eagle2_5_VLImageInputs: TypeAlias = ( ) -class Eagle2_5_VLProcessor(BaseInternVLProcessor): - """ - Custom processor for Eagle2.5-VL model. - Extends BaseInternVLProcessor with Eagle-specific token handling. 
- """ - - def __init__( - self, - config: PretrainedConfig, - tokenizer: TokenizerLike, - *, - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - ) -> None: - # Skip super().__init__() to avoid config manipulation - # Directly initialize all required attributes - self.config = config - self.tokenizer = tokenizer - - # Image size with force_image_size override - image_size: int = config.vision_config.image_size - if hasattr(config, "force_image_size") and config.force_image_size: - image_size = config.force_image_size - - patch_size: int = config.vision_config.patch_size - downsample_ratio: float = getattr(config, "downsample_ratio", 0.5) - - # Compute num_image_token - self.num_image_token = int( - (image_size // patch_size) ** 2 * (downsample_ratio**2) - ) - self.image_size = image_size - - # Dynamic patch settings with defaults - self.min_dynamic_patch = ( - min_dynamic_patch - if min_dynamic_patch is not None - else getattr(config, "min_dynamic_patch", 1) - ) - self.max_dynamic_patch = ( - max_dynamic_patch - if max_dynamic_patch is not None - else getattr(config, "max_dynamic_patch", 12) - ) - self.dynamic_image_size = ( - dynamic_image_size - if dynamic_image_size is not None - else getattr(config, "dynamic_image_size", True) - ) - self.use_thumbnail: bool = getattr(config, "use_thumbnail", True) - - @property - def image_token_id(self) -> int: - """Get the image token ID from config or tokenizer.""" - if hasattr(self.config, "image_token_index"): - return self.config.image_token_index - # Fallback to tokenizer vocab - use (ID: 151667) - vocab = self.tokenizer.get_vocab() - if IMG_CONTEXT in vocab: - return vocab[IMG_CONTEXT] - raise ValueError(f"Cannot find image token '{IMG_CONTEXT}' in vocabulary") - - def get_image_repl( - self, - feature_size: int, - num_patches: int | None, - ) -> PromptUpdateDetails[str]: - """Get image replacement string for prompt.""" - repl_features = IMG_CONTEXT 
* feature_size - repl_full = IMG_START + repl_features + IMG_END - - return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) - - class Eagle2_5_VLProcessingInfo(BaseInternVLProcessingInfo): """Processing info for Eagle2.5-VL model.""" diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 0b61bd5a2..3b01985c4 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -11,7 +11,6 @@ from collections.abc import Mapping, Sequence import torch -from PIL import Image from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig @@ -27,391 +26,19 @@ from vllm.multimodal.processing.processor import ( ProcessorInputs, PromptReplacement, PromptUpdate, - PromptUpdateDetails, TimingContext, ) -from vllm.tokenizers import TokenizerLike +from vllm.transformers_utils.processors.h2ovl import H2OVLProcessor from .intern_vit import InternVisionModel from .internvl import ( - IMG_CONTEXT, - IMG_END, - IMG_START, BaseInternVLDummyInputsBuilder, BaseInternVLMultiModalProcessor, BaseInternVLProcessingInfo, - BaseInternVLProcessor, InternVLChatModel, - build_transform, - find_closest_aspect_ratio, - get_internvl_target_ratios, ) -def resolve_h2ovl_min_max_num( - *, - min_dynamic_patch: int, - max_dynamic_patch: int, - dynamic_image_size: bool, - use_thumbnail: bool, -) -> tuple[int, int]: - min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 - max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 - - if use_thumbnail and max_dynamic_patch != 1: - max_dynamic_patch += 1 - - return min_dynamic_patch, max_dynamic_patch - - -def get_h2ovl_target_ratios( - min_num: int, - max_num: int, - *, - prior_aspect_ratio: tuple[int, int] | None, -) -> list[tuple[int, int]]: - target_ratios = get_internvl_target_ratios(min_num, max_num) - - # if prior_aspect_ratio is provided, filter the target ratios - if prior_aspect_ratio is not None: - 
target_ratios = [ - ratio - for ratio in target_ratios - if prior_aspect_ratio[0] % ratio[0] != 0 - and prior_aspect_ratio[1] % ratio[1] != 0 - ] - - return target_ratios - - -# modified to include blocks generated in second pass -def calculate_h2ovl_targets( - *, - orig_width: int, - orig_height: int, - target_ratios: list[tuple[int, int]], - image_size: int, - use_thumbnail: bool, -) -> tuple[int, int, int, tuple[int, int]]: - aspect_ratio = orig_width / orig_height - - # find the closest aspect ratio to the target - target_aspect_ratio = find_closest_aspect_ratio( - aspect_ratio, - target_ratios, - width=orig_width, - height=orig_height, - image_size=image_size, - ) - - # calculate the target width and height - target_width = image_size * target_aspect_ratio[0] - target_height = image_size * target_aspect_ratio[1] - blocks = target_aspect_ratio[0] * target_aspect_ratio[1] - - # add thumbnail image if num_blocks != 1 - if use_thumbnail and blocks != 1: - blocks += 1 - - return blocks, target_width, target_height, target_aspect_ratio - - -# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B -# refactored to handle prior_aspect_ratio -def dynamic_preprocess_h2ovl( - image: Image.Image, - *, - target_ratios: list[tuple[int, int]], - image_size: int, - use_thumbnail: bool, -) -> tuple[list[Image.Image], tuple[int, int]]: - orig_width, orig_height = image.size - - # calculate the number of blocks without thumbnail - ( - blocks, - target_width, - target_height, - target_aspect_ratio, - ) = calculate_h2ovl_targets( - orig_width=orig_width, - orig_height=orig_height, - target_ratios=target_ratios, - image_size=image_size, - use_thumbnail=False, - ) - - # resize the image - resized_img = image.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // 
(target_width // image_size)) + 1) * image_size, - ) - # split the image - split_img = resized_img.crop(box) - processed_images.append(split_img) - - assert len(processed_images) == blocks - - if use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((image_size, image_size)) - processed_images.append(thumbnail_img) - - return processed_images, target_aspect_ratio - - -def _preprocess_image( - image: Image.Image, - *, - input_size: int, - min_num: int, - max_num: int, - use_thumbnail: bool, - prior_aspect_ratio: tuple[int, int] | None, -) -> tuple[torch.Tensor, tuple[int, int]]: - target_ratios = get_h2ovl_target_ratios( - min_num, - max_num, - prior_aspect_ratio=prior_aspect_ratio, - ) - - transform = build_transform(input_size=input_size) - images, target_aspect_ratio = dynamic_preprocess_h2ovl( - image, - image_size=input_size, - use_thumbnail=use_thumbnail, - target_ratios=target_ratios, - ) - - pixel_values = torch.stack([transform(image) for image in images]) - return pixel_values, target_aspect_ratio - - -# refactored to use the _preprocess_image function -def image_to_pixel_values_h2ovl( - image: Image.Image, - *, - input_size: int, - min_num: int, - max_num: int, - use_thumbnail: bool, - use_msac: bool, -) -> torch.Tensor: - # when MSAC is turned on, we need to process the image twice - if use_msac: - # first pass - pixel_values1, aspect_ratio1 = _preprocess_image( - image, - input_size=input_size, - min_num=1, - max_num=max_num, - use_thumbnail=True, - prior_aspect_ratio=None, - ) - # second pass - pixel_values2, _ = _preprocess_image( - image, - input_size=input_size, - min_num=3, - max_num=max_num, - use_thumbnail=True, - prior_aspect_ratio=aspect_ratio1, - ) - # combine pixel values - pixel_values = torch.cat( - [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0 - ) - - else: - pixel_values, _ = _preprocess_image( - image, - input_size=input_size, - min_num=min_num, - max_num=max_num, - use_thumbnail=use_thumbnail, - 
prior_aspect_ratio=None, - ) - - return pixel_values - - -class H2OVLProcessor(BaseInternVLProcessor): - def __init__( - self, - config: PretrainedConfig, - tokenizer: TokenizerLike, - *, - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - use_msac: bool | None = None, - ) -> None: - super().__init__( - config, - tokenizer, - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - ) - - if use_msac is None: - use_msac = config.use_msac - assert isinstance(use_msac, bool) - - self.use_msac = use_msac - - @property - def image_token_id(self) -> int: - return self.tokenizer.get_vocab()[IMG_CONTEXT] - - def get_image_repl( - self, - feature_size: int, - num_patches: int | None, - ) -> PromptUpdateDetails[str]: - repl_features = IMG_CONTEXT * feature_size - repl_full = IMG_START + repl_features + IMG_END - - return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) - - def resolve_min_max_num( - self, - *, - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - use_thumbnail: bool | None = None, - ) -> tuple[int, int]: - min_dynamic_patch = ( - self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch - ) - max_dynamic_patch = ( - self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch - ) - dynamic_image_size = ( - self.dynamic_image_size - if dynamic_image_size is None - else dynamic_image_size - ) - use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail - - return resolve_h2ovl_min_max_num( - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - use_thumbnail=use_thumbnail, - ) - - def resolve_target_ratios( - self, - *, - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, 
- use_thumbnail: bool | None = None, - prior_aspect_ratio: tuple[int, int] | None = None, - override_min_num: int | None = None, - ) -> list[tuple[int, int]]: - min_num, max_num = self.resolve_min_max_num( - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - use_thumbnail=use_thumbnail, - ) - if override_min_num is not None: - min_num = override_min_num - - return get_h2ovl_target_ratios( - min_num, - max_num, - prior_aspect_ratio=prior_aspect_ratio, - ) - - def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - use_msac: bool | None = None, - ) -> int: - use_msac = self.use_msac if use_msac is None else use_msac - - use_thumbnail = self.use_thumbnail - - if use_msac: - target_ratios_1 = self.resolve_target_ratios( - use_thumbnail=False, # Applied in calculate_targets - override_min_num=1, - ) - num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets( - orig_width=image_width, - orig_height=image_height, - image_size=self.image_size, - target_ratios=target_ratios_1, - use_thumbnail=True, - ) - - target_ratios_2 = self.resolve_target_ratios( - use_thumbnail=False, # Applied in calculate_targets - prior_aspect_ratio=aspect_ratio_1, - override_min_num=3, - ) - num_patches_2, _, _, _ = calculate_h2ovl_targets( - orig_width=image_width, - orig_height=image_height, - image_size=self.image_size, - target_ratios=target_ratios_2, - use_thumbnail=True, - ) - - num_patches = num_patches_1 + num_patches_2 - 1 - else: - target_ratios = self.resolve_target_ratios( - use_thumbnail=False, # Applied in calculate_targets - ) - num_patches, _, _, _ = calculate_h2ovl_targets( - orig_width=image_width, - orig_height=image_height, - image_size=self.image_size, - target_ratios=target_ratios, - use_thumbnail=use_thumbnail, - ) - - return num_patches * self.num_image_token - - def _images_to_pixel_values_lst( - self, - images: list[Image.Image], - min_dynamic_patch: int | None = None, - 
max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - ) -> list[torch.Tensor]: - use_msac = self.use_msac if len(images) == 1 else False - - min_num, max_num = self.resolve_min_max_num( - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - use_thumbnail=False, # Applied in image_to_pixel_values - ) - - return [ - image_to_pixel_values_h2ovl( - image, - input_size=self.image_size, - min_num=min_num, - max_num=max_num, - use_thumbnail=self.use_thumbnail, - use_msac=use_msac, - ) - for image in images - ] - - class H2OVLProcessingInfo(BaseInternVLProcessingInfo): def get_hf_processor(self, **kwargs: object) -> H2OVLProcessor: return self.ctx.init_processor( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index cdaa2b093..8126391b2 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -7,16 +7,13 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- -from abc import ABC, abstractmethod +from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, TypeAlias, TypeVar +from typing import Annotated, Literal, TypeAlias, TypeVar -import numpy.typing as npt import torch import torch.nn as nn -import torchvision.transforms as T -from PIL import Image -from transformers import BatchFeature, PretrainedConfig, TensorType +from transformers import BatchFeature, PretrainedConfig from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -28,7 +25,6 @@ from vllm.model_executor.models.intern_vit import ( ) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs 
import ( MultiModalDataDict, MultiModalFieldConfig, @@ -46,10 +42,12 @@ from vllm.multimodal.processing import ( BaseProcessingInfo, PromptReplacement, PromptUpdate, - PromptUpdateDetails, ) from vllm.sequence import IntermediateTensors -from vllm.tokenizers import TokenizerLike +from vllm.transformers_utils.processors.internvl import ( + BaseInternVLProcessor, + InternVLProcessor, +) from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( @@ -60,13 +58,6 @@ from .interfaces import ( ) from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix -IMG_START = "" -IMG_END = "" -IMG_CONTEXT = "" - -IMAGENET_MEAN = (0.485, 0.456, 0.406) -IMAGENET_STD = (0.229, 0.224, 0.225) - class InternVLImagePixelInputs(TensorSchema): """ @@ -128,568 +119,6 @@ class InternVLVideoEmbeddingInputs(TensorSchema): InternVLVideoInputs: TypeAlias = InternVLVideoPixelInputs | InternVLVideoEmbeddingInputs -# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B -def build_transform(input_size: int): - MEAN, STD = IMAGENET_MEAN, IMAGENET_STD - transform = T.Compose( - [ - T.Lambda(lambda img: convert_image_mode(img, "RGB")), - T.Resize( - (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC - ), - T.ToTensor(), - T.Normalize(mean=MEAN, std=STD), - ] - ) - return transform - - -# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B -def find_closest_aspect_ratio( - aspect_ratio: float, - target_ratios: list[tuple[int, int]], - *, - width: int, - height: int, - image_size: int, -) -> tuple[int, int]: - best_ratio_diff = float("inf") - best_ratio = (1, 1) - area = width * height - for ratio in target_ratios: - target_aspect_ratio = ratio[0] / ratio[1] - ratio_diff = abs(aspect_ratio - target_aspect_ratio) - if ratio_diff < best_ratio_diff: - best_ratio_diff = ratio_diff - best_ratio = ratio - elif ratio_diff == best_ratio_diff: - if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: - best_ratio = ratio - 
return best_ratio - - -def resolve_internvl_min_max_num( - *, - min_dynamic_patch: int, - max_dynamic_patch: int, - dynamic_image_size: bool, - use_thumbnail: bool, -) -> tuple[int, int]: - min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 - max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 - - if use_thumbnail and max_dynamic_patch != 1: - max_dynamic_patch += 1 - - return min_dynamic_patch, max_dynamic_patch - - -def get_internvl_target_ratios( - min_num: int, - max_num: int, -) -> list[tuple[int, int]]: - target_ratios = { - (i, j) - for n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) - if min_num <= i * j <= max_num - } - return sorted(target_ratios, key=lambda x: x[0] * x[1]) - - -def calculate_internvl_targets( - *, - orig_width: int, - orig_height: int, - target_ratios: list[tuple[int, int]], - image_size: int, - use_thumbnail: bool, -) -> tuple[int, int, int]: - aspect_ratio = orig_width / orig_height - - # find the closest aspect ratio to the target - target_aspect_ratio = find_closest_aspect_ratio( - aspect_ratio, - target_ratios, - width=orig_width, - height=orig_height, - image_size=image_size, - ) - - # calculate the target width and height - target_width = image_size * target_aspect_ratio[0] - target_height = image_size * target_aspect_ratio[1] - blocks = target_aspect_ratio[0] * target_aspect_ratio[1] - - # add thumbnail image if num_blocks != 1 - if use_thumbnail and blocks != 1: - blocks += 1 - - return blocks, target_width, target_height - - -# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B -def dynamic_preprocess_internvl( - image: Image.Image, - *, - target_ratios: list[tuple[int, int]], - image_size: int, - use_thumbnail: bool, -) -> list[Image.Image]: - orig_width, orig_height = image.size - - # calculate the number of blocks without thumbnail - blocks, target_width, target_height = calculate_internvl_targets( - orig_width=orig_width, - 
orig_height=orig_height, - target_ratios=target_ratios, - image_size=image_size, - use_thumbnail=False, - ) - - # resize the image - resized_img = image.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size, - ) - # split the image - split_img = resized_img.crop(box) - processed_images.append(split_img) - - assert len(processed_images) == blocks - - if use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((image_size, image_size)) - processed_images.append(thumbnail_img) - - return processed_images - - -# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B -def image_to_pixel_values_internvl( - image: Image.Image, - *, - input_size: int, - min_num: int, - max_num: int, - use_thumbnail: bool, -) -> torch.Tensor: - target_ratios = get_internvl_target_ratios(min_num, max_num) - - transform = build_transform(input_size=input_size) - images = dynamic_preprocess_internvl( - image, - target_ratios=target_ratios, - image_size=input_size, - use_thumbnail=use_thumbnail, - ) - - pixel_values = torch.stack([transform(image) for image in images]) - return pixel_values - - -# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B -def video_to_pixel_values_internvl( - video: npt.NDArray, - *, - input_size: int, - min_num: int, - max_num: int, - use_thumbnail: bool, -) -> torch.Tensor: - target_ratios = get_internvl_target_ratios(min_num, max_num) - - transform = build_transform(input_size=input_size) - frames_list = list[Image.Image]() - for frame in video: - pil_frame = dynamic_preprocess_internvl( - Image.fromarray(frame, mode="RGB"), - target_ratios=target_ratios, - image_size=input_size, - use_thumbnail=use_thumbnail, - ) - assert len(pil_frame) == 1 - frames_list.extend(pil_frame) - - 
pixel_values = torch.stack([transform(image) for image in frames_list]) - return pixel_values - - -class BaseInternVLProcessor(ABC): - """ - This model doesn't define its own HF processor, - so we implement our own one here. - - The code to insert image tokens is based on: - https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252 - """ - - def __init__( - self, - config: PretrainedConfig, - tokenizer: TokenizerLike, - *, - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - ) -> None: - super().__init__() - - self.config = config - self.tokenizer = tokenizer - - image_size: int = config.vision_config.image_size - patch_size: int = config.vision_config.patch_size - - if min_dynamic_patch is None: - min_dynamic_patch = config.min_dynamic_patch - assert isinstance(min_dynamic_patch, int) - - if max_dynamic_patch is None: - max_dynamic_patch = config.max_dynamic_patch - assert isinstance(max_dynamic_patch, int) - - if dynamic_image_size is None: - dynamic_image_size = config.dynamic_image_size - assert isinstance(dynamic_image_size, bool) - - self.num_image_token = int( - (image_size // patch_size) ** 2 * (config.downsample_ratio**2) - ) - self.image_size = image_size - self.min_dynamic_patch = min_dynamic_patch - self.max_dynamic_patch = max_dynamic_patch - self.dynamic_image_size = dynamic_image_size - self.use_thumbnail: bool = config.use_thumbnail - - @property - @abstractmethod - def image_token_id(self) -> int: - raise NotImplementedError - - @abstractmethod - def get_image_repl( - self, - feature_size: int, - num_patches: int | None, - ) -> PromptUpdateDetails[str]: - raise NotImplementedError - - def resolve_min_max_num( - self, - *, - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - use_thumbnail: bool | None = None, - ) -> tuple[int, int]: - min_dynamic_patch = ( - self.min_dynamic_patch 
if min_dynamic_patch is None else min_dynamic_patch - ) - max_dynamic_patch = ( - self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch - ) - dynamic_image_size = ( - self.dynamic_image_size - if dynamic_image_size is None - else dynamic_image_size - ) - use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail - - return resolve_internvl_min_max_num( - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - use_thumbnail=use_thumbnail, - ) - - def resolve_target_ratios( - self, - *, - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - use_thumbnail: bool | None = None, - ) -> list[tuple[int, int]]: - min_num, max_num = self.resolve_min_max_num( - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - use_thumbnail=use_thumbnail, - ) - - return get_internvl_target_ratios(min_num, max_num) - - def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - ) -> int: - target_ratios = self.resolve_target_ratios( - use_thumbnail=False, # Applied in calculate_targets - ) - - num_patches, _, _ = calculate_internvl_targets( - orig_width=image_width, - orig_height=image_height, - image_size=self.image_size, - target_ratios=target_ratios, - use_thumbnail=self.use_thumbnail, - ) - - return num_patches * self.num_image_token - - def _images_to_pixel_values_lst( - self, - images: list[Image.Image], - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - ) -> list[torch.Tensor]: - min_num, max_num = self.resolve_min_max_num( - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - use_thumbnail=False, # Applied in image_to_pixel_values - ) - - return [ - image_to_pixel_values_internvl( - image, 
- input_size=self.image_size, - min_num=min_num, - max_num=max_num, - use_thumbnail=self.use_thumbnail, - ) - for image in images - ] - - def _preprocess_image( - self, - text: list[str], - images: list[Image.Image], - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - ) -> tuple[list[str], dict[str, torch.Tensor]]: - if len(images) == 0: - image_inputs = {} - else: - pixel_values_lst = self._images_to_pixel_values_lst( - images, - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - ) - image_inputs = { - "pixel_values_flat": torch.cat(pixel_values_lst), - "image_num_patches": torch.tensor( - [len(item) for item in pixel_values_lst] - ), - } - - for pixel_values in pixel_values_lst: - num_patches = pixel_values.shape[0] - feature_size = num_patches * self.num_image_token - - image_repl = self.get_image_repl(feature_size, num_patches) - text = [t.replace("", image_repl.full, 1) for t in text] - return text, image_inputs - - def _make_batch_input(self, input_item: Any | list[Any] | None = None): - if input_item is None: - input_item = [] - if not isinstance(input_item, list): - input_item = [input_item] - return input_item - - def __call__( - self, - text: str | list[str] | None = None, - images: Image.Image | list[Image.Image] | None = None, - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - return_tensors: str | TensorType | None = None, - ) -> BatchFeature: - text, images = [self._make_batch_input(x) for x in (text, images)] - - text, image_inputs = self._preprocess_image( - text=text, - images=images, - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - ) - - text_inputs = self.tokenizer(text) - - combined_outputs = {**text_inputs, **image_inputs} - - return BatchFeature(combined_outputs, 
tensor_type=return_tensors) - - -class InternVLProcessor(BaseInternVLProcessor): - """ - HF Processor for InternVLChatModel with extended video processing logic. - - Code for video processing is adapted from video example: - https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers - """ - - def __init__( - self, - config: PretrainedConfig, - tokenizer: TokenizerLike, - *, - min_dynamic_patch: int | None = None, - max_dynamic_patch: int | None = None, - dynamic_image_size: bool | None = None, - video_token: str | None = None, - ) -> None: - super().__init__( - config=config, - tokenizer=tokenizer, - min_dynamic_patch=min_dynamic_patch, - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, - ) - # add extra video token for video processing - self.video_token = video_token - - @property - def image_token_id(self) -> int: - return self.tokenizer.get_vocab()[IMG_CONTEXT] - - @property - def video_token_id(self) -> int | None: - if self.video_token is None: - return None - return self.tokenizer.get_vocab().get(self.video_token, None) - - @property - def supports_video(self) -> bool: - return self.video_token_id is not None - - def _videos_to_pixel_values_lst( - self, - videos: list[npt.NDArray], - dynamic_image_size: bool | None = None, - ) -> list[torch.Tensor]: - min_num, max_num = self.resolve_min_max_num( - min_dynamic_patch=1, - max_dynamic_patch=1, - dynamic_image_size=dynamic_image_size, - use_thumbnail=False, # Applied in image_to_pixel_values - ) - - return [ - video_to_pixel_values_internvl( - video, - input_size=self.image_size, - min_num=min_num, - max_num=max_num, - use_thumbnail=False, - ) - for video in videos - ] - - def _preprocess_video( - self, - text: list[str], - videos: list[npt.NDArray], - dynamic_image_size: bool | None = None, - ): - if len(videos) == 0 or not self.supports_video: - video_inputs = {} - else: - pixel_values_lst_video = self._videos_to_pixel_values_lst( - videos, - 
dynamic_image_size=dynamic_image_size, - ) - video_inputs = { - "pixel_values_flat_video": torch.cat(pixel_values_lst_video), - "video_num_patches": torch.tensor( - [len(item) for item in pixel_values_lst_video] - ), - } - - for pixel_values in pixel_values_lst_video: - num_patches = pixel_values.shape[0] - - video_repl = self.get_video_repl( - self.num_image_token, num_patches, self.video_token - ) - text = [t.replace("