"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2e209c30cf6f2ba42001d0629dc6b7ce354b9a9d"
Unverified Commit c5e8481c authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[benchmark] fix multi process issue (#5608)

parent ae8cbde5
...@@ -27,14 +27,6 @@ from torch.utils.data import DataLoader ...@@ -27,14 +27,6 @@ from torch.utils.data import DataLoader
from .. import utils from .. import utils
# import sys
# import os
# dir_path = Path(os.path.dirname(__file__))
# sys.path.insert(0, dir_path.parent)
# import utils
class EntityClassify(nn.Module): class EntityClassify(nn.Module):
def __init__( def __init__(
...@@ -475,7 +467,7 @@ def track_time(data, dgl_sparse): ...@@ -475,7 +467,7 @@ def track_time(data, dgl_sparse):
n_gpus = len(devices) n_gpus = len(devices)
n_cpus = mp.cpu_count() n_cpus = mp.cpu_count()
ctx = mp.get_context("fork") ctx = mp.get_context("spawn")
queue = ctx.Queue() queue = ctx.Queue()
procs = [] procs = []
num_train_seeds = train_idx.shape[0] num_train_seeds = train_idx.shape[0]
......
...@@ -53,8 +53,7 @@ def load_subtensor(nfeat, labels, seeds, input_nodes, dev_id): ...@@ -53,8 +53,7 @@ def load_subtensor(nfeat, labels, seeds, input_nodes, dev_id):
# Entry point # Entry point
@utils.thread_wrapped_func
def run(result_queue, proc_id, n_gpus, args, devices, data): def run(result_queue, proc_id, n_gpus, args, devices, data):
dev_id = devices[proc_id] dev_id = devices[proc_id]
timing_records = [] timing_records = []
...@@ -174,11 +173,12 @@ def track_time(data): ...@@ -174,11 +173,12 @@ def track_time(data):
# Pack data # Pack data
data = n_classes, train_g, val_g, test_g data = n_classes, train_g, val_g, test_g
result_queue = mp.Queue() ctx = mp.get_context("spawn")
result_queue = ctx.Queue()
procs = [] procs = []
for proc_id in range(n_gpus): for proc_id in range(n_gpus):
p = mp.Process( p = ctx.Process(
target=utils.thread_wrapped_func(run), target=run,
args=(result_queue, proc_id, n_gpus, args, devices, data), args=(result_queue, proc_id, n_gpus, args, devices, data),
) )
p.start() p.start()
......
...@@ -19,6 +19,8 @@ echo "DGL_BENCH_DEVICE=$DGL_BENCH_DEVICE" ...@@ -19,6 +19,8 @@ echo "DGL_BENCH_DEVICE=$DGL_BENCH_DEVICE"
pushd $ROOT/benchmarks pushd $ROOT/benchmarks
cat asv.conf.json cat asv.conf.json
asv machine --yes asv machine --yes
asv run --launch-method=spawn -e -v # If --launch-method is specified as 'spawn', multigpu tests will crash with
# "No module named 'benchmarks' is found".
asv run -e -v
asv publish asv publish
popd popd
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment