Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ab81379e
Unverified
Commit
ab81379e
authored
Oct 16, 2025
by
Nick Hill
Committed by
GitHub
Oct 16, 2025
Browse files
[Perf] Exploit out-of-band buffers in shm_broadcast (#26961)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
4ffd6e89
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
54 additions
and
16 deletions
+54
-16
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+54
-16
No files found.
vllm/distributed/device_communicators/shm_broadcast.py
View file @
ab81379e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
pickle
import
pickle
import
time
import
time
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
multiprocessing
import
shared_memory
from
multiprocessing
import
shared_memory
from
pickle
import
PickleBuffer
from
threading
import
Event
from
threading
import
Event
from
typing
import
Any
from
typing
import
TYPE_CHECKING
,
Any
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
...
@@ -33,8 +34,18 @@ from vllm.utils import (
...
@@ -33,8 +34,18 @@ from vllm.utils import (
is_valid_ipv6_address
,
is_valid_ipv6_address
,
)
)
if
TYPE_CHECKING
:
from
_typeshed
import
SizedBuffer
VLLM_RINGBUFFER_WARNING_INTERVAL
=
envs
.
VLLM_RINGBUFFER_WARNING_INTERVAL
VLLM_RINGBUFFER_WARNING_INTERVAL
=
envs
.
VLLM_RINGBUFFER_WARNING_INTERVAL
# Decode a big-endian integer from a bytes-like object (e.g. a slice of the
# shared-memory chunk). Counterpart of ``to_bytes_big``; ``int.from_bytes``
# is called with its default ``signed=False``.
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
def to_bytes_big(value: int, size: int) -> bytes:
    """Encode ``value`` as exactly ``size`` bytes in big-endian order.

    Thin wrapper over ``int.to_bytes``; any error raised by ``int.to_bytes``
    (for example when ``value`` does not fit in ``size`` bytes) propagates
    unchanged.

    Args:
        value: Integer to encode.
        size: Number of bytes in the result.

    Returns:
        The ``size``-byte big-endian representation of ``value``.
    """
    return value.to_bytes(size, byteorder="big")
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -225,7 +236,7 @@ class MessageQueue:
...
@@ -225,7 +236,7 @@ class MessageQueue:
n_reader
,
# number of all readers
n_reader
,
# number of all readers
n_local_reader
,
# number of local readers through shared memory
n_local_reader
,
# number of local readers through shared memory
local_reader_ranks
:
list
[
int
]
|
None
=
None
,
local_reader_ranks
:
list
[
int
]
|
None
=
None
,
max_chunk_bytes
:
int
=
1024
*
1024
*
10
,
max_chunk_bytes
:
int
=
1024
*
1024
*
24
,
# 24MiB
max_chunks
:
int
=
10
,
max_chunks
:
int
=
10
,
connect_ip
:
str
|
None
=
None
,
connect_ip
:
str
|
None
=
None
,
):
):
...
@@ -505,18 +516,41 @@ class MessageQueue:
...
@@ -505,18 +516,41 @@ class MessageQueue:
def enqueue(self, obj, timeout: float | None = None):
    """Write to message queue with optional timeout (in seconds).

    ``obj`` is pickled with ``pickle.HIGHEST_PROTOCOL``. Buffers that pickle
    offers for out-of-band transfer are kept in-band when smaller than 1MiB
    and collected separately otherwise, so the message becomes a list of
    buffers: the main pickle stream first, then each out-of-band buffer.

    Local readers (``self.n_local_reader > 0``):
      * If the framed message would not fit below
        ``self.buffer.max_chunk_bytes``, the chunk only gets its overflow
        flag set (``buf[0] = 1``) and the buffers are sent as one multipart
        message on ``self.local_socket``.
      * Otherwise the chunk is filled as:
          byte 0      overflow flag (0)
          bytes 1-2   number of buffers, big-endian (main pickle included)
          then, per buffer: 4-byte big-endian length followed by its bytes.

    Remote readers (``self.n_remote_reader > 0``) always receive the buffers
    as one multipart message on ``self.remote_socket``.

    Args:
        obj: Object to broadcast.
        timeout: Forwarded to ``self.acquire_write``.
    """
    assert self._is_writer, "Only writers can enqueue"

    # Slot 0 is reserved for the main pickle stream; it is filled in once
    # pickle.dumps returns, after the callback has collected the OOB buffers.
    all_buffers: list[SizedBuffer] = [b""]
    total_bytes = 6  # 2 bytes for the buffer count, 4 for main buffer size

    def oob_callback(buf: PickleBuffer) -> bool:
        nonlocal total_bytes
        raw_buf = buf.raw()
        if len(raw_buf) < 1024 * 1024:
            # In-line buffers smaller than 1MiB.
            return True
        all_buffers.append(raw_buf)
        # Each out-of-band buffer costs its payload plus a 4-byte length.
        total_bytes += len(raw_buf) + 4
        # A false return tells pickle to keep this buffer out-of-band.
        return False

    all_buffers[0] = pickle.dumps(
        obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
    )

    if self.n_local_reader > 0:
        if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
            with self.acquire_write(timeout) as buf:
                buf[0] = 1  # overflow
            self.local_socket.send_multipart(all_buffers, copy=False)
        else:
            with self.acquire_write(timeout) as buf:
                buf[0] = 0  # not overflow
                offset = 3
                # Number of buffers that follow (main pickle + OOB buffers).
                buf[1:offset] = to_bytes_big(len(all_buffers), 2)
                for part in all_buffers:
                    part_len = len(part)
                    # prepend each buffer with 4 bytes containing its size.
                    start = offset + 4
                    buf[offset:start] = to_bytes_big(part_len, 4)
                    offset = start + part_len
                    buf[start:offset] = part

    if self.n_remote_reader > 0:
        self.remote_socket.send_multipart(all_buffers, copy=False)
def
dequeue
(
def
dequeue
(
self
,
self
,
...
@@ -529,10 +563,15 @@ class MessageQueue:
...
@@ -529,10 +563,15 @@ class MessageQueue:
with
self
.
acquire_read
(
timeout
,
cancel
,
indefinite
)
as
buf
:
with
self
.
acquire_read
(
timeout
,
cancel
,
indefinite
)
as
buf
:
overflow
=
buf
[
0
]
==
1
overflow
=
buf
[
0
]
==
1
if
not
overflow
:
if
not
overflow
:
# no need to know the size of serialized object
offset
=
3
# pickle format contains the size information internally
buf_count
=
from_bytes_big
(
buf
[
1
:
offset
])
# see https://docs.python.org/3/library/pickle.html
all_buffers
=
[]
obj
=
pickle
.
loads
(
buf
[
1
:])
for
i
in
range
(
buf_count
):
buf_offset
=
offset
+
4
buf_len
=
from_bytes_big
(
buf
[
offset
:
buf_offset
])
offset
=
buf_offset
+
buf_len
all_buffers
.
append
(
buf
[
buf_offset
:
offset
])
obj
=
pickle
.
loads
(
all_buffers
[
0
],
buffers
=
all_buffers
[
1
:])
if
overflow
:
if
overflow
:
obj
=
MessageQueue
.
recv
(
self
.
local_socket
,
timeout
)
obj
=
MessageQueue
.
recv
(
self
.
local_socket
,
timeout
)
elif
self
.
_is_remote_reader
:
elif
self
.
_is_remote_reader
:
...
@@ -546,14 +585,13 @@ class MessageQueue:
...
@@ -546,14 +585,13 @@ class MessageQueue:
timeout_ms
=
None
if
timeout
is
None
else
int
(
timeout
*
1000
)
timeout_ms
=
None
if
timeout
is
None
else
int
(
timeout
*
1000
)
if
not
socket
.
poll
(
timeout
=
timeout_ms
):
if
not
socket
.
poll
(
timeout
=
timeout_ms
):
raise
TimeoutError
raise
TimeoutError
recv
=
socket
.
recv
(
copy
=
False
)
recv
,
*
recv_oob
=
socket
.
recv
_multipart
(
copy
=
False
)
return
pickle
.
loads
(
recv
.
buffer
)
return
pickle
.
loads
(
recv
,
buffer
s
=
recv_oob
)
def broadcast_object(self, obj=None):
    """Send ``obj`` if this endpoint is the writer, otherwise receive one.

    Args:
        obj: Object to broadcast; only used when ``self._is_writer`` is true.

    Returns:
        ``obj`` itself on the writer (after ``self.enqueue(obj)``), or the
        result of ``self.dequeue()`` on any other endpoint.
    """
    if self._is_writer:
        self.enqueue(obj)
        return obj
    return self.dequeue()
@
staticmethod
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment