Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
20d0699d
Unverified
Commit
20d0699d
authored
Nov 16, 2023
by
Zhuohan Li
Committed by
GitHub
Nov 16, 2023
Browse files
[Fix] Fix comm test (#1691)
parent
686f5e32
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
1 deletion
+2
-1
tests/distributed/test_comm_ops.py
tests/distributed/test_comm_ops.py
+2
-1
No files found.
tests/distributed/test_comm_ops.py
View file @
20d0699d
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
Run `pytest tests/distributed/test_comm_ops.py --forked`.
Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
"""
from
multiprocessing
import
Process
from
multiprocessing
import
Process
,
set_start_method
import
pytest
import
pytest
import
torch
import
torch
...
@@ -70,6 +70,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
...
@@ -70,6 +70,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@
pytest
.
mark
.
parametrize
(
"test_target"
,
@
pytest
.
mark
.
parametrize
(
"test_target"
,
[
all_reduce_test_worker
,
all_gather_test_worker
])
[
all_reduce_test_worker
,
all_gather_test_worker
])
def
test_multi_process_tensor_parallel
(
tensor_parallel_size
,
test_target
):
def
test_multi_process_tensor_parallel
(
tensor_parallel_size
,
test_target
):
set_start_method
(
"spawn"
,
force
=
True
)
distributed_init_port
=
get_open_port
()
distributed_init_port
=
get_open_port
()
processes
=
[]
processes
=
[]
for
rank
in
range
(
tensor_parallel_size
):
for
rank
in
range
(
tensor_parallel_size
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment