Unverified commit 55c21c88, authored by Micah Williamson and committed by GitHub
Browse files

[ROCm][CI] Fix "Cannot re-initialize CUDA in forked subprocess" in test_pynccl.py (#29119)


Signed-off-by: Micah Williamson <micah.williamson@amd.com>
parent 3999442f
...@@ -40,5 +40,8 @@ mteb[bm25s]>=1.38.11, <2 ...@@ -40,5 +40,8 @@ mteb[bm25s]>=1.38.11, <2
# Required for eval tests # Required for eval tests
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
# Required for multiprocessed tests that use spawn method
multiprocess==0.70.16
# Plugins test # Plugins test
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import os import os
import multiprocess as mp
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
...@@ -20,10 +20,12 @@ from vllm.distributed.parallel_state import ( ...@@ -20,10 +20,12 @@ from vllm.distributed.parallel_state import (
) )
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
mp.set_start_method("spawn", force=True)
def distributed_run(fn, world_size): def distributed_run(fn, world_size):
number_of_processes = world_size number_of_processes = world_size
processes: list[multiprocessing.Process] = [] processes: list[mp.Process] = []
for i in range(number_of_processes): for i in range(number_of_processes):
env: dict[str, str] = {} env: dict[str, str] = {}
env["RANK"] = str(i) env["RANK"] = str(i)
...@@ -32,7 +34,7 @@ def distributed_run(fn, world_size): ...@@ -32,7 +34,7 @@ def distributed_run(fn, world_size):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes) env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost" env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345" env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env,)) p = mp.Process(target=fn, args=(env,))
processes.append(p) processes.append(p)
p.start() p.start()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment