Unverified Commit 20d0699d authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[Fix] Fix comm test (#1691)

parent 686f5e32
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
Run `pytest tests/distributed/test_comm_ops.py --forked`. Run `pytest tests/distributed/test_comm_ops.py --forked`.
""" """
from multiprocessing import Process from multiprocessing import Process, set_start_method
import pytest import pytest
import torch import torch
...@@ -70,6 +70,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, ...@@ -70,6 +70,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@pytest.mark.parametrize("test_target", @pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker]) [all_reduce_test_worker, all_gather_test_worker])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
set_start_method("spawn", force=True)
distributed_init_port = get_open_port() distributed_init_port = get_open_port()
processes = [] processes = []
for rank in range(tensor_parallel_size): for rank in range(tensor_parallel_size):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment