Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0d4ea3fb
Unverified
Commit
0d4ea3fb
authored
Nov 12, 2024
by
youkaichao
Committed by
GitHub
Nov 12, 2024
Browse files
[core][distributed] use tcp store directly (#10275)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
112fa0bb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
25 deletions
+29
-25
tests/distributed/test_utils.py
tests/distributed/test_utils.py
+16
-10
vllm/distributed/utils.py
vllm/distributed/utils.py
+13
-15
No files found.
tests/distributed/test_utils.py
View file @
0d4ea3fb
...
...
@@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():
def
cpu_worker
(
rank
,
WORLD_SIZE
,
port1
,
port2
):
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port1
}
"
,
pg1
=
StatelessProcessGroup
.
create
(
host
=
"127.0.0.1"
,
port
=
port1
,
rank
=
rank
,
world_size
=
WORLD_SIZE
)
if
rank
<=
2
:
pg2
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port2
}
"
,
rank
=
rank
,
world_size
=
3
)
pg2
=
StatelessProcessGroup
.
create
(
host
=
"127.0.0.1"
,
port
=
port2
,
rank
=
rank
,
world_size
=
3
)
data
=
torch
.
tensor
([
rank
])
data
=
pg1
.
broadcast_obj
(
data
,
src
=
2
)
assert
data
.
item
()
==
2
...
...
@@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
def
gpu_worker
(
rank
,
WORLD_SIZE
,
port1
,
port2
):
torch
.
cuda
.
set_device
(
rank
)
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port1
}
"
,
pg1
=
StatelessProcessGroup
.
create
(
host
=
"127.0.0.1"
,
port
=
port1
,
rank
=
rank
,
world_size
=
WORLD_SIZE
)
pynccl1
=
PyNcclCommunicator
(
pg1
,
device
=
rank
)
pynccl1
.
disabled
=
False
if
rank
<=
2
:
pg2
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port2
}
"
,
rank
=
rank
,
world_size
=
3
)
pg2
=
StatelessProcessGroup
.
create
(
host
=
"127.0.0.1"
,
port
=
port2
,
rank
=
rank
,
world_size
=
3
)
pynccl2
=
PyNcclCommunicator
(
pg2
,
device
=
rank
)
pynccl2
.
disabled
=
False
data
=
torch
.
tensor
([
rank
]).
cuda
()
...
...
@@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
def
broadcast_worker
(
rank
,
WORLD_SIZE
,
port1
,
port2
):
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port1
}
"
,
pg1
=
StatelessProcessGroup
.
create
(
host
=
"127.0.0.1"
,
port
=
port1
,
rank
=
rank
,
world_size
=
WORLD_SIZE
)
if
rank
==
2
:
...
...
@@ -101,7 +108,8 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
def
allgather_worker
(
rank
,
WORLD_SIZE
,
port1
,
port2
):
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port1
}
"
,
pg1
=
StatelessProcessGroup
.
create
(
host
=
"127.0.0.1"
,
port
=
port1
,
rank
=
rank
,
world_size
=
WORLD_SIZE
)
data
=
pg1
.
all_gather_obj
(
rank
)
...
...
@@ -109,8 +117,6 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1
.
barrier
()
# TODO: investigate why this test is flaky. It hangs during initialization.
@
pytest
.
mark
.
skip
(
"Skip the test because it is flaky."
)
@
multi_gpu_test
(
num_gpus
=
4
)
@
pytest
.
mark
.
parametrize
(
"worker"
,
[
cpu_worker
,
gpu_worker
,
broadcast_worker
,
allgather_worker
])
...
...
vllm/distributed/utils.py
View file @
0d4ea3fb
...
...
@@ -9,7 +9,7 @@ from collections import deque
from
typing
import
Any
,
Deque
,
Dict
,
Optional
,
Sequence
,
Tuple
import
torch
from
torch.distributed
.rendezvous
import
rendezvous
from
torch.distributed
import
TCPStore
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
...
...
@@ -97,7 +97,6 @@ class StatelessProcessGroup:
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
"""
prefix
:
str
rank
:
int
world_size
:
int
store
:
torch
.
_C
.
_distributed_c10d
.
Store
...
...
@@ -127,7 +126,7 @@ class StatelessProcessGroup:
def
send_obj
(
self
,
obj
:
Any
,
dst
:
int
):
"""Send an object to a destination rank."""
self
.
expire_data
()
key
=
f
"
{
self
.
prefix
}
/
send_to/
{
dst
}
/
{
self
.
send_dst_counter
[
dst
]
}
"
key
=
f
"send_to/
{
dst
}
/
{
self
.
send_dst_counter
[
dst
]
}
"
self
.
store
.
set
(
key
,
pickle
.
dumps
(
obj
))
self
.
send_dst_counter
[
dst
]
+=
1
self
.
entries
.
append
((
key
,
time
.
time
()))
...
...
@@ -147,8 +146,7 @@ class StatelessProcessGroup:
"""Receive an object from a source rank."""
obj
=
pickle
.
loads
(
self
.
store
.
get
(
f
"
{
self
.
prefix
}
/send_to/
{
self
.
rank
}
/
{
self
.
recv_src_counter
[
src
]
}
"
))
f
"send_to/
{
self
.
rank
}
/
{
self
.
recv_src_counter
[
src
]
}
"
))
self
.
recv_src_counter
[
src
]
+=
1
return
obj
...
...
@@ -159,14 +157,14 @@ class StatelessProcessGroup:
"""
if
self
.
rank
==
src
:
self
.
expire_data
()
key
=
(
f
"
{
self
.
prefix
}
/
broadcast_from/
{
src
}
/"
key
=
(
f
"broadcast_from/
{
src
}
/"
f
"
{
self
.
broadcast_send_counter
}
"
)
self
.
store
.
set
(
key
,
pickle
.
dumps
(
obj
))
self
.
broadcast_send_counter
+=
1
self
.
entries
.
append
((
key
,
time
.
time
()))
return
obj
else
:
key
=
(
f
"
{
self
.
prefix
}
/
broadcast_from/
{
src
}
/"
key
=
(
f
"broadcast_from/
{
src
}
/"
f
"
{
self
.
broadcast_recv_src_counter
[
src
]
}
"
)
recv_obj
=
pickle
.
loads
(
self
.
store
.
get
(
key
))
self
.
broadcast_recv_src_counter
[
src
]
+=
1
...
...
@@ -194,7 +192,8 @@ class StatelessProcessGroup:
@
staticmethod
def
create
(
init_method
:
str
,
host
:
str
,
port
:
int
,
rank
:
int
,
world_size
:
int
,
data_expiration_seconds
:
int
=
3600
,
...
...
@@ -214,15 +213,14 @@ class StatelessProcessGroup:
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
"""
# noqa
from
torch._C._distributed_c10d
import
_DEFAULT_PG_TIMEOUT
timeout
=
_DEFAULT_PG_TIMEOUT
store
,
rank
,
world_size
=
next
(
rendezvous
(
init_method
,
rank
,
world_size
,
timeout
=
timeout
))
store
.
set_timeout
(
timeout
)
store
=
TCPStore
(
host_name
=
host
,
port
=
port
,
world_size
=
world_size
,
is_master
=
(
rank
==
0
),
)
return
StatelessProcessGroup
(
prefix
=
init_method
,
rank
=
rank
,
world_size
=
world_size
,
store
=
store
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment