Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
740374d4
Unverified
Commit
740374d4
authored
Jul 24, 2024
by
youkaichao
Committed by
GitHub
Jul 24, 2024
Browse files
[core][distributed] fix zmq hang (#6759)
parent
d88c458f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
41 deletions
+23
-41
vllm/connections.py
vllm/connections.py
+2
-2
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+21
-39
No files found.
vllm/connections.py
View file @
740374d4
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Mapping
,
Optional
from
typing
import
Mapping
,
MutableMapping
,
Optional
from
urllib.parse
import
urlparse
from
urllib.parse
import
urlparse
import
aiohttp
import
aiohttp
...
@@ -40,7 +40,7 @@ class HTTPConnection:
...
@@ -40,7 +40,7 @@ class HTTPConnection:
raise
ValueError
(
"Invalid HTTP URL: A valid HTTP URL "
raise
ValueError
(
"Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'."
)
"must have scheme 'http' or 'https'."
)
def
_headers
(
self
,
**
extras
:
str
)
->
Mapping
[
str
,
str
]:
def
_headers
(
self
,
**
extras
:
str
)
->
Mutable
Mapping
[
str
,
str
]:
return
{
"User-Agent"
:
f
"vLLM/
{
VLLM_VERSION
}
"
,
**
extras
}
return
{
"User-Agent"
:
f
"vLLM/
{
VLLM_VERSION
}
"
,
**
extras
}
def
get_response
(
def
get_response
(
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
740374d4
...
@@ -9,7 +9,7 @@ from unittest.mock import patch
...
@@ -9,7 +9,7 @@ from unittest.mock import patch
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
zmq
import
P
UB
,
REP
,
REQ
,
SUB
,
SUBSCRIB
E
,
Context
# type: ignore
from
zmq
import
S
UB
,
SUBSCRIBE
,
XPUB
,
XPUB_VERBOS
E
,
Context
# type: ignore
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -153,9 +153,7 @@ class Handle:
...
@@ -153,9 +153,7 @@ class Handle:
buffer
:
Optional
[
ShmRingBuffer
]
=
None
buffer
:
Optional
[
ShmRingBuffer
]
=
None
local_subscribe_port
:
Optional
[
int
]
=
None
local_subscribe_port
:
Optional
[
int
]
=
None
local_sync_port
:
Optional
[
int
]
=
None
remote_subscribe_port
:
Optional
[
int
]
=
None
remote_subscribe_port
:
Optional
[
int
]
=
None
remote_sync_port
:
Optional
[
int
]
=
None
class
MessageQueue
:
class
MessageQueue
:
...
@@ -189,38 +187,36 @@ class MessageQueue:
...
@@ -189,38 +187,36 @@ class MessageQueue:
self
.
buffer
=
ShmRingBuffer
(
n_local_reader
,
max_chunk_bytes
,
self
.
buffer
=
ShmRingBuffer
(
n_local_reader
,
max_chunk_bytes
,
max_chunks
)
max_chunks
)
self
.
local_socket
=
context
.
socket
(
PUB
)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self
.
local_socket
=
context
.
socket
(
XPUB
)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self
.
local_socket
.
setsockopt
(
XPUB_VERBOSE
,
True
)
local_subscribe_port
=
get_open_port
()
local_subscribe_port
=
get_open_port
()
self
.
local_socket
.
bind
(
f
"tcp://*:
{
local_subscribe_port
}
"
)
self
.
local_socket
.
bind
(
f
"tcp://*:
{
local_subscribe_port
}
"
)
self
.
local_sync_socket
=
context
.
socket
(
REP
)
local_sync_port
=
get_open_port
()
self
.
local_sync_socket
.
bind
(
f
"tcp://*:
{
local_sync_port
}
"
)
self
.
current_idx
=
0
self
.
current_idx
=
0
else
:
else
:
self
.
buffer
=
None
# type: ignore
self
.
buffer
=
None
# type: ignore
local_subscribe_port
=
None
local_subscribe_port
=
None
local_sync_port
=
None
self
.
local_socket
=
None
self
.
local_socket
=
None
self
.
local_sync_socket
=
None
self
.
current_idx
=
-
1
self
.
current_idx
=
-
1
if
n_remote_reader
>
0
:
if
n_remote_reader
>
0
:
# for remote readers, we will:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
# create a publish-subscribe socket to communicate large data
self
.
remote_socket
=
context
.
socket
(
PUB
)
self
.
remote_socket
=
context
.
socket
(
XPUB
)
self
.
remote_socket
.
setsockopt
(
XPUB_VERBOSE
,
True
)
remote_subscribe_port
=
get_open_port
()
remote_subscribe_port
=
get_open_port
()
self
.
remote_socket
.
bind
(
f
"tcp://*:
{
remote_subscribe_port
}
"
)
self
.
remote_socket
.
bind
(
f
"tcp://*:
{
remote_subscribe_port
}
"
)
self
.
remote_sync_socket
=
context
.
socket
(
REP
)
remote_sync_port
=
get_open_port
()
self
.
remote_sync_socket
.
bind
(
f
"tcp://*:
{
remote_sync_port
}
"
)
else
:
else
:
remote_subscribe_port
=
None
remote_subscribe_port
=
None
remote_sync_port
=
None
self
.
remote_socket
=
None
self
.
remote_socket
=
None
self
.
remote_sync_socket
=
None
self
.
_is_writer
=
True
self
.
_is_writer
=
True
self
.
_is_local_reader
=
False
self
.
_is_local_reader
=
False
...
@@ -233,9 +229,7 @@ class MessageQueue:
...
@@ -233,9 +229,7 @@ class MessageQueue:
local_reader_ranks
=
local_reader_ranks
,
local_reader_ranks
=
local_reader_ranks
,
buffer
=
self
.
buffer
,
buffer
=
self
.
buffer
,
local_subscribe_port
=
local_subscribe_port
,
local_subscribe_port
=
local_subscribe_port
,
local_sync_port
=
local_sync_port
,
remote_subscribe_port
=
remote_subscribe_port
,
remote_subscribe_port
=
remote_subscribe_port
,
remote_sync_port
=
remote_sync_port
,
)
)
logger
.
info
(
"vLLM message queue communication handle: %s"
,
self
.
handle
)
logger
.
info
(
"vLLM message queue communication handle: %s"
,
self
.
handle
)
...
@@ -264,12 +258,7 @@ class MessageQueue:
...
@@ -264,12 +258,7 @@ class MessageQueue:
self
.
local_socket
.
connect
(
self
.
local_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
local_subscribe_port
}
"
)
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
local_subscribe_port
}
"
)
self
.
local_sync_socket
=
context
.
socket
(
REQ
)
self
.
local_sync_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
local_sync_port
}
"
)
self
.
remote_socket
=
None
self
.
remote_socket
=
None
self
.
remote_sync_socket
=
None
else
:
else
:
self
.
buffer
=
None
# type: ignore
self
.
buffer
=
None
# type: ignore
self
.
current_idx
=
-
1
self
.
current_idx
=
-
1
...
@@ -278,17 +267,12 @@ class MessageQueue:
...
@@ -278,17 +267,12 @@ class MessageQueue:
self
.
_is_remote_reader
=
True
self
.
_is_remote_reader
=
True
self
.
local_socket
=
None
self
.
local_socket
=
None
self
.
local_sync_socket
=
None
self
.
remote_socket
=
context
.
socket
(
SUB
)
self
.
remote_socket
=
context
.
socket
(
SUB
)
self
.
remote_socket
.
setsockopt_string
(
SUBSCRIBE
,
""
)
self
.
remote_socket
.
setsockopt_string
(
SUBSCRIBE
,
""
)
self
.
remote_socket
.
connect
(
self
.
remote_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
remote_subscribe_port
}
"
)
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
remote_subscribe_port
}
"
)
self
.
remote_sync_socket
=
context
.
socket
(
REQ
)
self
.
remote_sync_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
remote_sync_port
}
"
)
return
self
return
self
def
wait_until_ready
(
self
):
def
wait_until_ready
(
self
):
...
@@ -300,29 +284,27 @@ class MessageQueue:
...
@@ -300,29 +284,27 @@ class MessageQueue:
# local readers
# local readers
for
i
in
range
(
self
.
n_local_reader
):
for
i
in
range
(
self
.
n_local_reader
):
recv
=
self
.
local_sync_socket
.
recv
()
# wait for subscription messages from all local readers
assert
recv
==
b
"READY"
self
.
local_socket
.
recv
()
self
.
local_sync_socket
.
send
(
b
"READY"
)
if
self
.
n_local_reader
>
0
:
if
self
.
n_local_reader
>
0
:
# send a message to all local readers
# to make sure the publish channel is working
self
.
local_socket
.
send
(
b
"READY"
)
self
.
local_socket
.
send
(
b
"READY"
)
# remote readers
# remote readers
for
i
in
range
(
self
.
n_remote_reader
):
for
i
in
range
(
self
.
n_remote_reader
):
recv
=
self
.
remote_sync_socket
.
recv
()
# wait for subscription messages from all remote readers
assert
recv
==
b
"READY"
self
.
remote_socket
.
recv
()
self
.
remote_sync_socket
.
send
(
b
"READY"
)
if
self
.
n_remote_reader
>
0
:
if
self
.
n_remote_reader
>
0
:
# send a message to all remote readers
# to make sure the publish channel is working
self
.
remote_socket
.
send
(
b
"READY"
)
self
.
remote_socket
.
send
(
b
"READY"
)
elif
self
.
_is_local_reader
:
elif
self
.
_is_local_reader
:
self
.
local_sync_socket
.
send
(
b
"READY"
)
# wait for the writer to send a message
recv
=
self
.
local_sync_socket
.
recv
()
assert
recv
==
b
"READY"
recv
=
self
.
local_socket
.
recv
()
recv
=
self
.
local_socket
.
recv
()
assert
recv
==
b
"READY"
assert
recv
==
b
"READY"
elif
self
.
_is_remote_reader
:
elif
self
.
_is_remote_reader
:
self
.
remote_sync_socket
.
send
(
b
"READY"
)
# wait for the writer to send a message
recv
=
self
.
remote_sync_socket
.
recv
()
assert
recv
==
b
"READY"
recv
=
self
.
remote_socket
.
recv
()
recv
=
self
.
remote_socket
.
recv
()
assert
recv
==
b
"READY"
assert
recv
==
b
"READY"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment