Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2a85f930
Unverified
Commit
2a85f930
authored
May 01, 2024
by
youkaichao
Committed by
GitHub
May 02, 2024
Browse files
[Core][Distributed] enable multiple tp group (#4512)
Co-authored-by:
Zhuohan Li
<
zhuohan123@gmail.com
>
parent
cf8cac8c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
4 deletions
+43
-4
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+8
-3
.buildkite/test-template.j2
.buildkite/test-template.j2
+3
-0
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+28
-0
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+4
-1
No files found.
.buildkite/test-pipeline.yaml
View file @
2a85f930
...
...
@@ -25,19 +25,24 @@ steps:
-
label
:
Distributed Comm Ops Test
command
:
pytest -v -s test_comm_ops.py
working_dir
:
"
/vllm-workspace/tests/distributed"
num_gpus
:
2
# only support 1 or 2 for now.
num_gpus
:
2
-
label
:
Distributed Tests
working_dir
:
"
/vllm-workspace/tests/distributed"
num_gpus
:
2
# only support 1 or 2 for now.
num_gpus
:
2
commands
:
-
pytest -v -s test_pynccl.py
-
pytest -v -s test_pynccl_library.py
-
TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
-
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
-
label
:
Distributed Tests (Multiple Groups)
working_dir
:
"
/vllm-workspace/tests/distributed"
num_gpus
:
4
commands
:
-
pytest -v -s test_pynccl.py
-
label
:
Engine Test
command
:
pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
...
...
.buildkite/test-template.j2
View file @
2a85f930
...
...
@@ -45,6 +45,9 @@ steps:
plugins:
- kubernetes:
podSpec:
{% if step.num_gpus %}
priorityClassName: gpu-priority-cls-{{ step.num_gpus }}
{% endif %}
volumes:
- name: dshm
emptyDir:
...
...
tests/distributed/test_pynccl.py
View file @
2a85f930
...
...
@@ -58,6 +58,34 @@ def test_pynccl():
distributed_run
(
worker_fn
,
2
)
@worker_fn_wrapper
def multiple_tp_worker_fn():
    """Exercise two disjoint 2-rank communicators running side by side.

    Global ranks 0-1 and 2-3 each form their own gloo process group, and
    every rank attaches an ``NCCLCommunicator`` to the group it belongs to.
    The first group all-reduces twice and the second group once, so the two
    groups end with different values, showing they do not interfere.
    """
    global_rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{global_rank}")

    # Every rank creates both groups (in the same order), including the
    # group it is not a member of.
    groups = [
        torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
        torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
    ]
    in_first_group = global_rank in [0, 1]
    group = groups[0] if in_first_group else groups[1]
    comm = NCCLCommunicator(group=group, device=device)

    # Allocate on the communicator's device. ``comm.rank`` is the rank
    # *within the group* (0 or 1), so using it as a CUDA index would place
    # the tensors of global ranks 2 and 3 on GPUs 0 and 1 instead of on
    # their own GPUs.
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)

    # two groups can communicate independently
    # NOTE(review): the expected means assume all_reduce sums in place
    # across the 2 ranks of the group -- confirm against NCCLCommunicator.
    if in_first_group:
        comm.all_reduce(tensor)
        comm.all_reduce(tensor)
        result = tensor.mean().cpu().item()
        assert result == 4
    else:
        comm.all_reduce(tensor)
        result = tensor.mean().cpu().item()
        assert result == 2
@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_tp():
    """Run the multi-group worker on 4 processes.

    Skipped when fewer than 4 CUDA devices are visible, because the worker
    builds two groups of two ranks each.
    """
    # Launch the multi-group worker; the plain ``worker_fn`` would only
    # repeat the single-group check already covered by ``test_pynccl``.
    distributed_run(multiple_tp_worker_fn, 4)
@
worker_fn_wrapper
def
worker_fn_with_cudagraph
():
with
torch
.
no_grad
():
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
2a85f930
...
...
@@ -232,6 +232,7 @@ class NCCLCommunicator:
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"NCCLCommunicator should be attached to a non-NCCL group."
)
self
.
group
=
group
# note: this rank is the rank in the group
self
.
rank
=
dist
.
get_rank
(
group
)
self
.
world_size
=
dist
.
get_world_size
(
group
)
if
self
.
rank
==
0
:
...
...
@@ -239,7 +240,9 @@ class NCCLCommunicator:
else
:
self
.
unique_id
=
NcclUniqueId
()
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
dist
.
broadcast
(
tensor
,
src
=
0
,
group
=
group
)
ranks
=
dist
.
get_process_group_ranks
(
group
)
# arg `src` in `broadcast` is the global rank
dist
.
broadcast
(
tensor
,
src
=
ranks
[
0
],
group
=
group
)
byte_list
=
tensor
.
tolist
()
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment