Unverified commit b111f8a6, authored by Juan Pérez de Algaba, committed by GitHub
Browse files

fix(security): Add VLLM_MAX_N_SEQUENCES environment variable and enforce limit (#37952)


Signed-off-by: jperezde <jperezde@redhat.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
parent 497e234d
...@@ -231,6 +231,18 @@ The most effective approach is to deploy vLLM behind a reverse proxy (such as ng ...@@ -231,6 +231,18 @@ The most effective approach is to deploy vLLM behind a reverse proxy (such as ng
- Blocks all other endpoints, including the unauthenticated inference and operational control endpoints - Blocks all other endpoints, including the unauthenticated inference and operational control endpoints
- Implements additional authentication, rate limiting, and logging at the proxy layer - Implements additional authentication, rate limiting, and logging at the proxy layer
## Request Parameter Resource Limits
Certain API request parameters can have a large impact on resource consumption and may be abused to exhaust server resources. The `n` parameter in the `/v1/completions` and `/v1/chat/completions` endpoints controls how many independent output sequences are generated per request. A very large value causes the engine to consume memory, CPU time, and GPU time in proportion to `n`, which can lead to out-of-memory conditions on the host and block the server from processing other requests.
To mitigate this, vLLM enforces a configurable upper bound on the `n` parameter via the `VLLM_MAX_N_SEQUENCES` environment variable (default: **16384**). Requests exceeding this limit are rejected before reaching the engine.
### Recommendations
- **Public-facing deployments:** Consider setting `VLLM_MAX_N_SEQUENCES` to a value appropriate for your workload (e.g., `64` or `128`) to limit the blast radius of a single request.
- **Reverse proxy layer:** In addition to vLLM's built-in limit, consider enforcing request body validation and rate limiting at your reverse proxy to further constrain abusive payloads.
- **Monitoring:** Monitor per-request resource consumption to detect anomalous patterns that may indicate abuse.
## Tool Server and MCP Security ## Tool Server and MCP Security
vLLM supports connecting to external tool servers via the `--tool-server` argument. This enables models to call tools through the Responses API (`/v1/responses`). Tool server support works with all models — it is not limited to specific model architectures. vLLM supports connecting to external tool servers via the `--tool-server` argument. This enables models to call tools through the Responses API (`/v1/responses`). Tool server support works with all models — it is not limited to specific model architectures.
......
...@@ -1020,3 +1020,114 @@ def test_chat_completion_request_n_parameter_various_values(): ...@@ -1020,3 +1020,114 @@ def test_chat_completion_request_n_parameter_various_values():
assert sampling_params.n == n_value, ( assert sampling_params.n == n_value, (
f"Expected n={n_value}, got n={sampling_params.n}" f"Expected n={n_value}, got n={sampling_params.n}"
) )
def test_chat_completion_request_n_parameter_exceeds_default_limit(
    monkeypatch: pytest.MonkeyPatch,
):
    """Reject a chat request whose ``n`` is one above the default limit.

    With ``VLLM_MAX_N_SEQUENCES`` removed from the environment, the limit is
    read from ``vllm.envs`` and converting the request to sampling params
    must raise ``ValueError``.
    """
    import vllm.envs as envs

    # Ensure no override is present so the default limit applies.
    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
    # Clear the env accessor's cache, if it exposes one, before reading.
    env_lookup = envs.__getattr__
    if hasattr(env_lookup, "cache_clear"):
        env_lookup.cache_clear()

    over_limit = envs.VLLM_MAX_N_SEQUENCES + 1
    payload = {
        "model": "test-model",
        "messages": [{"role": "user", "content": "Test"}],
        "n": over_limit,
        "max_tokens": 10,
    }
    chat_request = ChatCompletionRequest(**payload)

    with pytest.raises(ValueError, match="n must be at most"):
        chat_request.to_sampling_params(max_tokens=10, default_sampling_params={})
def test_chat_completion_request_n_parameter_at_limit(
    monkeypatch: pytest.MonkeyPatch,
):
    """Accept a chat request whose ``n`` equals the default limit exactly.

    The limit is read from ``vllm.envs`` after ``VLLM_MAX_N_SEQUENCES`` has
    been removed from the environment; the resulting sampling params must
    carry the same ``n``.
    """
    import vllm.envs as envs

    # Ensure no override is present so the default limit applies.
    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
    # Clear the env accessor's cache, if it exposes one, before reading.
    env_lookup = envs.__getattr__
    if hasattr(env_lookup, "cache_clear"):
        env_lookup.cache_clear()

    limit = envs.VLLM_MAX_N_SEQUENCES
    payload = {
        "model": "test-model",
        "messages": [{"role": "user", "content": "Test"}],
        "n": limit,
        "max_tokens": 10,
    }
    chat_request = ChatCompletionRequest(**payload)

    params = chat_request.to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )
    assert params.n == limit
def test_chat_completion_request_n_parameter_custom_limit(
    monkeypatch: pytest.MonkeyPatch,
):
    """Check that setting ``VLLM_MAX_N_SEQUENCES`` changes the enforced limit.

    With the variable set to ``128``, a request with ``n=128`` converts to
    sampling params successfully, while ``n=129`` raises ``ValueError``
    naming the limit of 128.
    """
    import vllm.envs as envs

    monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
    # Clear the env accessor's cache, if it exposes one, so the override
    # set above is the value that gets read.
    env_lookup = envs.__getattr__
    if hasattr(env_lookup, "cache_clear"):
        env_lookup.cache_clear()

    def build_request(n_value):
        # Both requests differ only in ``n``.
        return ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Test"}],
            n=n_value,
            max_tokens=10,
        )

    # Exactly at the custom limit: accepted.
    accepted = build_request(128)
    params = accepted.to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )
    assert params.n == 128

    # One above the custom limit: rejected with the limit in the message.
    rejected = build_request(129)
    with pytest.raises(ValueError, match="n must be at most 128"):
        rejected.to_sampling_params(max_tokens=10, default_sampling_params={})
def test_chat_completion_request_n_parameter_massive_value(
    monkeypatch: pytest.MonkeyPatch,
):
    """Reject an extremely large ``n`` (regression test for the CVE fix).

    With ``VLLM_MAX_N_SEQUENCES`` removed from the environment, a request
    asking for 100,000,000 sequences must raise ``ValueError`` when it is
    converted to sampling params.
    """
    import vllm.envs as envs

    # Ensure no override is present so the default limit applies.
    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
    # Clear the env accessor's cache, if it exposes one.
    env_lookup = envs.__getattr__
    if hasattr(env_lookup, "cache_clear"):
        env_lookup.cache_clear()

    payload = {
        "model": "test-model",
        "messages": [{"role": "user", "content": "Test"}],
        "n": 100_000_000,
        "max_tokens": 1,
    }
    chat_request = ChatCompletionRequest(**payload)

    with pytest.raises(ValueError, match="n must be at most"):
        chat_request.to_sampling_params(max_tokens=1, default_sampling_params={})
...@@ -454,3 +454,55 @@ class TestVllmConfigureLogging: ...@@ -454,3 +454,55 @@ class TestVllmConfigureLogging:
with pytest.raises(ValueError, match="invalid literal for int"): with pytest.raises(ValueError, match="invalid literal for int"):
_ = envs.VLLM_CONFIGURE_LOGGING _ = envs.VLLM_CONFIGURE_LOGGING
class TestVllmMaxNSequences:
    """Tests for the ``VLLM_MAX_N_SEQUENCES`` environment variable.

    Covers the default and overridden values exposed through ``envs`` and
    the upper bound on ``n`` that ``SamplingParams`` enforces from it.
    """

    def test_default_value(self):
        """Test that VLLM_MAX_N_SEQUENCES defaults to 16384."""
        # patch.dict restores os.environ on exit, so removing the variable
        # here does not leak into other tests. The assertion stays inside the
        # block so the value is read while the variable is still removed.
        with patch.dict(os.environ, {}, clear=False):
            os.environ.pop("VLLM_MAX_N_SEQUENCES", None)
            # Clear the env accessor's cache, if it exposes one.
            if hasattr(envs.__getattr__, "cache_clear"):
                envs.__getattr__.cache_clear()
            assert envs.VLLM_MAX_N_SEQUENCES == 16384

    def test_custom_value(self, monkeypatch: pytest.MonkeyPatch):
        """Test that VLLM_MAX_N_SEQUENCES can be overridden."""
        monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
        # Clear the env accessor's cache, if it exposes one, so the override
        # is the value that gets read.
        if hasattr(envs.__getattr__, "cache_clear"):
            envs.__getattr__.cache_clear()
        assert envs.VLLM_MAX_N_SEQUENCES == 128

    def test_sampling_params_respects_limit(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ):
        """Test that SamplingParams rejects n above the limit."""
        from vllm.sampling_params import SamplingParams

        monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
        if hasattr(envs.__getattr__, "cache_clear"):
            envs.__getattr__.cache_clear()

        max_n = envs.VLLM_MAX_N_SEQUENCES
        # Exactly at the limit must construct without raising.
        SamplingParams(n=max_n)
        with pytest.raises(ValueError, match="n must be at most"):
            SamplingParams(n=max_n + 1)

    def test_sampling_params_respects_custom_limit(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ):
        """Test that SamplingParams uses the overridden env var limit."""
        from vllm.sampling_params import SamplingParams

        monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
        if hasattr(envs.__getattr__, "cache_clear"):
            envs.__getattr__.cache_clear()

        # Exactly at the custom limit must construct without raising.
        SamplingParams(n=128)
        with pytest.raises(ValueError, match="n must be at most 128"):
            SamplingParams(n=129)
...@@ -86,6 +86,7 @@ if TYPE_CHECKING: ...@@ -86,6 +86,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_RPC_TIMEOUT: int = 10000 # ms
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_MAX_N_SEQUENCES: int = 16384
VLLM_PLUGINS: list[str] | None = None VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None
...@@ -870,6 +871,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -870,6 +871,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_HTTP_TIMEOUT_KEEP_ALIVE": lambda: int( "VLLM_HTTP_TIMEOUT_KEEP_ALIVE": lambda: int(
os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5") os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5")
), ),
# Maximum allowed value for the `n` sampling parameter (number of output
# sequences per request). Limits resource consumption to prevent
# denial-of-service via excessively large fan-out. Default: 16384.
"VLLM_MAX_N_SEQUENCES": lambda: int(
os.environ.get("VLLM_MAX_N_SEQUENCES", "16384")
),
# a list of plugin names to load, separated by commas. # a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded # if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded # if this is set to an empty string, no plugins will be loaded
......
...@@ -12,6 +12,7 @@ from typing import Any ...@@ -12,6 +12,7 @@ from typing import Any
import msgspec import msgspec
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -169,6 +170,9 @@ class SamplingParams( ...@@ -169,6 +170,9 @@ class SamplingParams(
n: int = 1 n: int = 1
"""Number of outputs to return for the given prompt request. """Number of outputs to return for the given prompt request.
The maximum allowed value is controlled by the ``VLLM_MAX_N_SEQUENCES``
environment variable (default: 16384).
NOTE: NOTE:
`AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
are generated and streamed cumulatively per request. To see all `n` are generated and streamed cumulatively per request. To see all `n`
...@@ -425,6 +429,13 @@ class SamplingParams( ...@@ -425,6 +429,13 @@ class SamplingParams(
raise ValueError(f"n must be an int, but is of type {type(self.n)}") raise ValueError(f"n must be an int, but is of type {type(self.n)}")
if self.n < 1: if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.") raise ValueError(f"n must be at least 1, got {self.n}.")
max_n = envs.VLLM_MAX_N_SEQUENCES
if self.n > max_n:
raise ValueError(
f"n must be at most {max_n}, got {self.n}. "
"To increase this limit, set the VLLM_MAX_N_SEQUENCES "
"environment variable."
)
if not -2.0 <= self.presence_penalty <= 2.0: if not -2.0 <= self.presence_penalty <= 2.0:
raise ValueError( raise ValueError(
f"presence_penalty must be in [-2, 2], got {self.presence_penalty}." f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment