Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
84e23d10
Unverified
Commit
84e23d10
authored
Dec 14, 2025
by
Wenqi Glantz
Committed by
GitHub
Dec 15, 2025
Browse files
additional protection for CVE-2025-62164 (#30649)
Signed-off-by:
Wenqi Glantz
<
wglantz@nvidia.com
>
parent
738648fb
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
510 additions
and
15 deletions
+510
-15
tests/entrypoints/openai/test_sparse_tensor_validation.py
tests/entrypoints/openai/test_sparse_tensor_validation.py
+342
-0
tests/multimodal/test_sparse_tensor_validation_unit.py
tests/multimodal/test_sparse_tensor_validation_unit.py
+134
-0
vllm/entrypoints/renderer.py
vllm/entrypoints/renderer.py
+14
-11
vllm/multimodal/audio.py
vllm/multimodal/audio.py
+10
-2
vllm/multimodal/image.py
vllm/multimodal/image.py
+10
-2
No files found.
tests/entrypoints/openai/test_sparse_tensor_validation.py
0 → 100644
View file @
84e23d10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Sparse tensor validation in embedding APIs.
Tests verify that malicious sparse tensors are rejected before they can trigger
out-of-bounds memory writes during to_dense() operations.
"""
import
base64
import
io
import
pytest
import
torch
from
vllm.entrypoints.renderer
import
CompletionRenderer
from
vllm.multimodal.audio
import
AudioEmbeddingMediaIO
from
vllm.multimodal.image
import
ImageEmbeddingMediaIO
def _encode_tensor(tensor: torch.Tensor) -> bytes:
    """Serialise ``tensor`` with ``torch.save`` and return it base64-encoded.

    The result is ``bytes``; callers that need text decode it as UTF-8.
    """
    stream = io.BytesIO()
    torch.save(tensor, stream)
    # getvalue() yields the whole buffer regardless of the stream position
    serialized = stream.getvalue()
    return base64.b64encode(serialized)
def _create_malicious_sparse_tensor() -> torch.Tensor:
    """
    Build a sparse COO tensor whose single index lies outside its shape.

    The tensor declares a 3x3 shape but stores its only value at (10, 10).
    Converting such a tensor to dense format without validation would write
    outside the declared bounds.
    """
    declared_shape = (3, 3)
    # One entry addressed at row 10, column 10: out of bounds for 3x3
    out_of_range_indices = torch.tensor([[10], [10]])
    entry_values = torch.tensor([1.0])

    # The resulting sparse tensor is deliberately invalid
    return torch.sparse_coo_tensor(
        out_of_range_indices,
        entry_values,
        declared_shape,
        dtype=torch.float32,
    )
def _create_valid_sparse_tensor() -> torch.Tensor:
    """Build a well-formed 3x3 sparse COO tensor used as a baseline.

    The values 1.0, 2.0 and 3.0 sit on the main diagonal.
    """
    diagonal = [0, 1, 2]
    coordinates = torch.tensor([diagonal, diagonal])
    entries = torch.tensor([1.0, 2.0, 3.0])
    return torch.sparse_coo_tensor(
        coordinates,
        entries,
        (3, 3),
        dtype=torch.float32,
    )
def _create_valid_dense_tensor() -> torch.Tensor:
    """Return a random dense float32 tensor used as a baseline input."""
    # Laid out as (seq_len, hidden_size)
    seq_len, hidden_size = 10, 768
    return torch.randn(seq_len, hidden_size, dtype=torch.float32)
class TestPromptEmbedsValidation:
    """Test sparse tensor validation in prompt embeddings (Completions API)."""

    def test_valid_dense_tensor_accepted(self, model_config):
        """Baseline: Valid dense tensors should work normally."""
        renderer = CompletionRenderer(model_config)

        valid_tensor = _create_valid_dense_tensor()
        encoded = _encode_tensor(valid_tensor)

        # Should not raise any exception
        result = renderer.load_prompt_embeds(encoded)
        assert len(result) == 1
        assert result[0]["prompt_embeds"].shape == valid_tensor.shape

    def test_valid_sparse_tensor_accepted(self):
        """Baseline: Valid sparse tensors should load and come back dense."""
        # NOTE(review): this goes through ImageEmbeddingMediaIO rather than
        # CompletionRenderer -- confirm whether the renderer path should be
        # the one exercised in this class.
        io_handler = ImageEmbeddingMediaIO()

        valid_sparse = _create_valid_sparse_tensor()
        encoded = _encode_tensor(valid_sparse)

        # Should not raise any exception
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.shape == valid_sparse.shape
        # The loader densifies what it reads, so a shape check alone would not
        # notice a wrong layout or wrong values.
        assert not result.is_sparse
        assert torch.equal(result, valid_sparse.to_dense())

    def test_malicious_sparse_tensor_rejected(self, model_config):
        """Security: Malicious sparse tensors should be rejected."""
        renderer = CompletionRenderer(model_config)

        malicious_tensor = _create_malicious_sparse_tensor()
        encoded = _encode_tensor(malicious_tensor)

        # Either RuntimeError or ValueError is accepted as a rejection
        with pytest.raises((RuntimeError, ValueError)) as exc_info:
            renderer.load_prompt_embeds(encoded)

        # Error should indicate sparse tensor validation failure
        error_msg = str(exc_info.value).lower()
        assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg

    def test_extremely_large_indices_rejected(self, model_config):
        """Security: Sparse tensors with extremely large indices should be rejected."""
        renderer = CompletionRenderer(model_config)

        # Create tensor with indices far beyond reasonable bounds
        indices = torch.tensor([[999999], [999999]])
        values = torch.tensor([1.0])
        shape = (10, 10)

        malicious_tensor = torch.sparse_coo_tensor(
            indices, values, shape, dtype=torch.float32
        )
        encoded = _encode_tensor(malicious_tensor)

        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(encoded)

    def test_negative_indices_rejected(self, model_config):
        """Security: Sparse tensors with negative indices should be rejected."""
        renderer = CompletionRenderer(model_config)

        # Create tensor with negative indices
        indices = torch.tensor([[-1], [-1]])
        values = torch.tensor([1.0])
        shape = (10, 10)

        malicious_tensor = torch.sparse_coo_tensor(
            indices, values, shape, dtype=torch.float32
        )
        encoded = _encode_tensor(malicious_tensor)

        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(encoded)
class TestImageEmbedsValidation:
    """Test sparse tensor validation in image embeddings (Chat API)."""

    def test_valid_dense_tensor_accepted(self):
        """Baseline: Valid dense tensors should work normally."""
        io_handler = ImageEmbeddingMediaIO()

        valid_tensor = _create_valid_dense_tensor()
        encoded = _encode_tensor(valid_tensor)

        # Should not raise any exception
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.shape == valid_tensor.shape

    def test_valid_sparse_tensor_accepted(self):
        """Baseline: Valid sparse tensors should load and come back dense."""
        # This class covers the image path, so the image handler is the one
        # under test here (the audio handler has its own test class).
        io_handler = ImageEmbeddingMediaIO()

        valid_sparse = _create_valid_sparse_tensor()
        encoded = _encode_tensor(valid_sparse)

        # Should not raise any exception
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.shape == valid_sparse.shape
        # The loader densifies what it reads, so also check layout and values
        assert not result.is_sparse
        assert torch.equal(result, valid_sparse.to_dense())

    def test_malicious_sparse_tensor_rejected(self):
        """Security: Malicious sparse tensors should be rejected."""
        io_handler = ImageEmbeddingMediaIO()

        malicious_tensor = _create_malicious_sparse_tensor()
        encoded = _encode_tensor(malicious_tensor)

        # Either RuntimeError or ValueError is accepted as a rejection
        with pytest.raises((RuntimeError, ValueError)) as exc_info:
            io_handler.load_base64("", encoded.decode("utf-8"))

        error_msg = str(exc_info.value).lower()
        assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg

    def test_load_bytes_validates(self):
        """Security: Validation should also work for load_bytes method."""
        io_handler = ImageEmbeddingMediaIO()

        malicious_tensor = _create_malicious_sparse_tensor()
        buffer = io.BytesIO()
        torch.save(malicious_tensor, buffer)
        buffer.seek(0)

        with pytest.raises((RuntimeError, ValueError)):
            io_handler.load_bytes(buffer.read())
class TestAudioEmbedsValidation:
    """Sparse tensor validation for audio embeddings (Chat API)."""

    def test_valid_dense_tensor_accepted(self):
        """Baseline: a well-formed dense tensor loads with its shape intact."""
        handler = AudioEmbeddingMediaIO()
        expected = _create_valid_dense_tensor()
        payload = _encode_tensor(expected).decode("utf-8")

        # Loading must succeed without raising
        loaded = handler.load_base64("", payload)

        assert loaded.shape == expected.shape

    def test_valid_sparse_tensor_accepted(self):
        """Baseline: a well-formed sparse tensor loads and is not sparse."""
        handler = AudioEmbeddingMediaIO()
        payload = _encode_tensor(_create_valid_sparse_tensor()).decode("utf-8")

        # Loading must succeed without raising
        loaded = handler.load_base64("", payload)

        assert loaded.is_sparse is False

    def test_malicious_sparse_tensor_rejected(self):
        """Security: an out-of-bounds sparse tensor must be rejected."""
        handler = AudioEmbeddingMediaIO()
        payload = _encode_tensor(_create_malicious_sparse_tensor()).decode("utf-8")

        # Either RuntimeError or ValueError counts as a rejection
        with pytest.raises((RuntimeError, ValueError)) as caught:
            handler.load_base64("", payload)

        message = str(caught.value).lower()
        assert any(keyword in message for keyword in ("sparse", "index", "bounds"))

    def test_load_bytes_validates(self):
        """Security: load_bytes must reject the same malicious payload."""
        handler = AudioEmbeddingMediaIO()

        stream = io.BytesIO()
        torch.save(_create_malicious_sparse_tensor(), stream)
        raw = stream.getvalue()

        with pytest.raises((RuntimeError, ValueError)):
            handler.load_bytes(raw)
class TestSparseTensorValidationIntegration:
    """
    Checks that the attack payload is rejected at every entry point.
    """

    def test_attack_scenario_completions_api(self, model_config):
        """
        Replay the attack against the Completions API loader.

        Attack scenario:
        1. Attacker crafts malicious sparse tensor
        2. Encodes it as base64
        3. Sends to /v1/completions with prompt_embeds parameter
        4. Server should reject before memory corruption occurs
        """
        renderer = CompletionRenderer(model_config)

        # Steps 1-2: the attacker's payload
        malicious = _create_malicious_sparse_tensor()
        attack_payload = _encode_tensor(malicious)

        # Steps 3-4: the server side must refuse it
        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(attack_payload)

    def test_attack_scenario_chat_api_image(self):
        """
        Replay the attack through the image_embeds loader.

        Verifies the image embeddings path is protected.
        """
        handler = ImageEmbeddingMediaIO()
        payload_text = _encode_tensor(_create_malicious_sparse_tensor()).decode("utf-8")

        with pytest.raises((RuntimeError, ValueError)):
            handler.load_base64("", payload_text)

    def test_attack_scenario_chat_api_audio(self):
        """
        Replay the attack through the audio_embeds loader.

        Verifies the audio embeddings path is protected.
        """
        handler = AudioEmbeddingMediaIO()
        payload_text = _encode_tensor(_create_malicious_sparse_tensor()).decode("utf-8")

        with pytest.raises((RuntimeError, ValueError)):
            handler.load_base64("", payload_text)

    def test_multiple_valid_embeddings_in_batch(self, model_config):
        """
        Regression test: a batch of valid embeddings must still load.

        Ensures the fix doesn't break legitimate batch processing.
        """
        renderer = CompletionRenderer(model_config)

        batch_size = 3
        batch = [
            _encode_tensor(_create_valid_dense_tensor()) for _ in range(batch_size)
        ]

        # Every item should be processed without error
        loaded = renderer.load_prompt_embeds(batch)
        assert len(loaded) == 3

    def test_mixed_valid_and_malicious_rejected(self, model_config):
        """
        Security: one malicious tensor must cause the whole batch to fail.

        Even if most tensors are valid, a single malicious one should
        cause rejection of the entire batch.
        """
        renderer = CompletionRenderer(model_config)

        first_valid = _encode_tensor(_create_valid_dense_tensor())
        malicious = _encode_tensor(_create_malicious_sparse_tensor())
        second_valid = _encode_tensor(_create_valid_dense_tensor())
        # The malicious entry sits between two valid ones
        mixed_batch = [first_valid, malicious, second_valid]

        # Loading should fail on the malicious tensor
        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(mixed_batch)
# Pytest fixtures
@pytest.fixture
def model_config():
    """Build a ``ModelConfig`` for ``facebook/opt-125m`` with prompt embeds on."""
    # Imported inside the fixture body rather than at module level
    from vllm.config import ModelConfig

    settings = {
        "model": "facebook/opt-125m",
        "tokenizer": "facebook/opt-125m",
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "dtype": "float32",
        "seed": 0,
        # Required for prompt embeds tests
        "enable_prompt_embeds": True,
    }
    return ModelConfig(**settings)
tests/multimodal/test_sparse_tensor_validation_unit.py
0 → 100644
View file @
84e23d10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for sparse tensor validation.
Simple, fast unit tests that can run without server fixtures.
Run with: pytest tests/multimodal/test_sparse_tensor_validation_unit.py -v
"""
import
io
import
pytest
import
torch
class TestSparseTensorValidationContextManager:
    """Checks how torch.sparse.check_sparse_tensor_invariants() behaves."""

    def test_valid_sparse_tensor_passes(self):
        """A well-formed sparse tensor is built and densified without error."""
        coords = torch.tensor([[0, 1], [0, 1]])
        entries = torch.tensor([1.0, 2.0])
        size = (2, 2)

        with torch.sparse.check_sparse_tensor_invariants():
            sparse = torch.sparse_coo_tensor(coords, entries, size)
            densified = sparse.to_dense()

        assert densified.shape == size

    def test_out_of_bounds_indices_rejected(self):
        """Indices beyond the declared shape raise RuntimeError."""
        # (5, 5) is out of bounds for a 2x2 tensor
        coords = torch.tensor([[5], [5]])
        entries = torch.tensor([1.0])
        size = (2, 2)

        with pytest.raises(
            RuntimeError
        ) as caught, torch.sparse.check_sparse_tensor_invariants():
            sparse = torch.sparse_coo_tensor(coords, entries, size)
            sparse.to_dense()

        message = str(caught.value).lower()
        assert "index" in message or "bound" in message

    def test_negative_indices_rejected(self):
        """Negative indices raise RuntimeError."""
        coords = torch.tensor([[-1], [0]])
        entries = torch.tensor([1.0])
        size = (2, 2)

        with pytest.raises(
            RuntimeError
        ), torch.sparse.check_sparse_tensor_invariants():
            sparse = torch.sparse_coo_tensor(coords, entries, size)
            sparse.to_dense()

    def test_without_context_manager_allows_invalid(self):
        """
        WITHOUT validation, invalid tensors may not immediately error.

        This demonstrates the vulnerability: PyTorch 2.8.0+ doesn't validate
        by default, which can lead to memory corruption.
        """
        # (100, 100) is far outside a 2x2 tensor
        coords = torch.tensor([[100], [100]])
        entries = torch.tensor([1.0])
        size = (2, 2)

        # No validation context here, so this might create an invalid tensor
        # (actual behavior depends on PyTorch version)
        sparse = torch.sparse_coo_tensor(coords, entries, size)

        # The tensor object exists even though its contents are invalid
        assert sparse.is_sparse
class TestTorchLoadWithValidation:
    """Checks torch.load() under sparse tensor validation."""

    @staticmethod
    def _saved(tensor):
        """Return a BytesIO holding ``torch.save(tensor)``, rewound to 0."""
        stream = io.BytesIO()
        torch.save(tensor, stream)
        stream.seek(0)
        return stream

    def test_load_valid_sparse_tensor_with_validation(self):
        """A well-formed sparse tensor loads and densifies under validation."""
        # Build and serialise a valid sparse tensor
        coords = torch.tensor([[0, 1], [0, 1]])
        entries = torch.tensor([1.0, 2.0])
        stream = self._saved(torch.sparse_coo_tensor(coords, entries, (2, 2)))

        # Load it back with validation switched on
        with torch.sparse.check_sparse_tensor_invariants():
            restored = torch.load(stream, weights_only=True)
            densified = restored.to_dense()

        assert densified.shape == (2, 2)

    def test_load_invalid_sparse_tensor_rejected(self):
        """An out-of-bounds sparse tensor is caught when loaded with validation."""
        # Build and serialise an invalid sparse tensor (index out of bounds)
        coords = torch.tensor([[10], [10]])
        entries = torch.tensor([1.0])
        stream = self._saved(torch.sparse_coo_tensor(coords, entries, (2, 2)))

        # Loading and densifying under validation is expected to raise
        with pytest.raises(
            RuntimeError
        ), torch.sparse.check_sparse_tensor_invariants():
            restored = torch.load(stream, weights_only=True)
            restored.to_dense()

    def test_load_dense_tensor_unaffected(self):
        """A dense tensor loads normally inside the validation context."""
        # Build and serialise a dense tensor
        stream = self._saved(torch.randn(10, 20))

        # Validation should make no difference for dense tensors
        with torch.sparse.check_sparse_tensor_invariants():
            restored = torch.load(stream, weights_only=True)

        assert restored.shape == (10, 20)
        assert not restored.is_sparse
if __name__ == "__main__":
    # Allow running directly for quick testing. pytest.main() returns the
    # run's exit code; raising SystemExit with it makes a direct run exit
    # non-zero when tests fail instead of always reporting success.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
vllm/entrypoints/renderer.py
View file @
84e23d10
...
...
@@ -167,6 +167,9 @@ class BaseRenderer(ABC):
)
def
_load_and_validate_embed
(
embed
:
bytes
)
->
EmbedsPrompt
:
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with
torch
.
sparse
.
check_sparse_tensor_invariants
():
tensor
=
torch
.
load
(
io
.
BytesIO
(
pybase64
.
b64decode
(
embed
,
validate
=
True
)),
weights_only
=
True
,
...
...
vllm/multimodal/audio.py
View file @
84e23d10
...
...
@@ -127,13 +127,21 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
def load_bytes(self, data: bytes) -> torch.Tensor:
    """Load a tensor from ``torch.save`` bytes and return it dense.

    Loading happens inside ``torch.sparse.check_sparse_tensor_invariants()``
    so that the sparse tensor integrity checks are enabled for this payload.

    Args:
        data: Serialised tensor bytes.

    Returns:
        The loaded tensor after ``to_dense()``.
    """
    buffer = BytesIO(data)
    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors. The earlier unconditional
    # `return torch.load(...)` is gone: it made this check unreachable.
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(buffer, weights_only=True)
        return tensor.to_dense()
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
    """Decode a base64 string and load the tensor it carries.

    ``media_type`` is accepted but not used by this method. Decoding runs
    with ``validate=True`` and the decoded bytes go to ``load_bytes``.
    """
    raw = pybase64.b64decode(data, validate=True)
    return self.load_bytes(raw)
def load_file(self, filepath: Path) -> torch.Tensor:
    """Load a tensor from a ``torch.save`` file and return it dense.

    Loading happens inside ``torch.sparse.check_sparse_tensor_invariants()``
    so that the sparse tensor integrity checks are enabled for this file.

    Args:
        filepath: Location of the serialised tensor.

    Returns:
        The loaded tensor after ``to_dense()``.
    """
    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors. The earlier unconditional
    # `return torch.load(...)` is gone: it made this check unreachable.
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(filepath, weights_only=True)
        return tensor.to_dense()
def encode_base64(self, media: torch.Tensor) -> str:
    """Return ``media`` encoded as a string by ``tensor2base64``."""
    encoded = tensor2base64(media)
    return encoded
vllm/multimodal/image.py
View file @
84e23d10
...
...
@@ -122,13 +122,21 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
def load_bytes(self, data: bytes) -> torch.Tensor:
    """Load a tensor from ``torch.save`` bytes and return it dense.

    Loading happens inside ``torch.sparse.check_sparse_tensor_invariants()``
    so that the sparse tensor integrity checks are enabled for this payload.

    Args:
        data: Serialised tensor bytes.

    Returns:
        The loaded tensor after ``to_dense()``.
    """
    buffer = BytesIO(data)
    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors. The earlier unconditional
    # `return torch.load(...)` is gone: it made this check unreachable.
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(buffer, weights_only=True)
        return tensor.to_dense()
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
    """Decode a base64 string and load the tensor it carries.

    ``media_type`` is accepted but not used by this method. Decoding runs
    with ``validate=True`` and the decoded bytes go to ``load_bytes``.
    """
    decoded = pybase64.b64decode(data, validate=True)
    return self.load_bytes(decoded)
def load_file(self, filepath: Path) -> torch.Tensor:
    """Load a tensor from a ``torch.save`` file and return it dense.

    Loading happens inside ``torch.sparse.check_sparse_tensor_invariants()``
    so that the sparse tensor integrity checks are enabled for this file.

    Args:
        filepath: Location of the serialised tensor.

    Returns:
        The loaded tensor after ``to_dense()``.
    """
    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors. The earlier unconditional
    # `return torch.load(...)` is gone: it made this check unreachable.
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(filepath, weights_only=True)
        return tensor.to_dense()
def encode_base64(self, media: torch.Tensor) -> str:
    """Base64-encode the tensor's NumPy view and return it as UTF-8 text."""
    # NOTE(review): load_bytes() reads data with torch.load(), while this
    # encodes the raw array buffer -- confirm whether a round trip through
    # load_base64() is expected to work.
    array = media.numpy()
    encoded = pybase64.b64encode(array)
    return encoded.decode("utf-8")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment