Unverified Commit caa6d607 authored by Quan (Andy) Gan, committed by GitHub

[Feature] Replacing thread_wrapped_func with minimal mp.Process wrapper (#2905)

* standardizing thread_wrapped_func

* lints

* Update __init__.py
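
In short, the example scripts no longer define thread_wrapped_func and wrap their
process entry points by hand; importing dgl.multiprocessing in place of
torch.multiprocessing is enough. A minimal before/after sketch (run and its
arguments are placeholders modeled on the example scripts below):

    import dgl.multiprocessing as mp  # previously: import torch.multiprocessing as mp

    def run(proc_id, n_gpus):
        pass  # training entry point goes here

    if __name__ == '__main__':
        n_gpus = 2
        # Before: p = mp.Process(target=thread_wrapped_func(run), args=(proc_id, n_gpus))
        # After: the OpenMP workaround lives inside dgl.multiprocessing.Process itself.
        procs = [mp.Process(target=run, args=(i, n_gpus)) for i in range(n_gpus)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()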
parent a90296aa
@@ -15,39 +15,6 @@ from ogb.nodeproppred import DglNodePropPredDataset
from functools import partial, reduce, wraps
import torch.multiprocessing as mp
from _thread import start_new_thread
import traceback
def thread_wrapped_func(func):
    """
    Wraps a process entry point to make it work with OpenMP.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = mp.Queue()
        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))
        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            raise exception.__class__(trace)
    return decorated_function

def _download(url, path, filename):
    fn = os.path.join(path, filename)
@@ -520,7 +487,7 @@ def benchmark(track_type, timeout=60):
        if not filter.check(func):
            # skip if not enabled
            func.benchmark_name = "skip_" + func.__name__
        return thread_wrapped_func(func)
        return func
    return _wrapper
#####################################
.. _apimultiprocessing:
dgl.multiprocessing
===================
This is a minimal wrapper around Python's native :mod:`multiprocessing` module.
It modifies the :class:`multiprocessing.Process` class so that forking works
with OpenMP in the DGL core library.
The API usage is exactly the same as the native module, so DGL does not provide
additional documentation.
In addition, if your backend is PyTorch, this module is also compatible with
the :mod:`torch.multiprocessing` module.
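
For example, spawning trainer processes looks exactly like the native API.
A minimal sketch (``worker`` and its arguments are placeholders, not part of
this module)::

    import dgl.multiprocessing as mp

    def worker(rank):
        # DGL operators may use multiple OpenMP threads inside the child
        # process; the patched Process class avoids the fork()+OpenMP hang.
        print('worker', rank)

    if __name__ == '__main__':
        procs = [mp.Process(target=worker, args=(rank,)) for rank in range(4)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()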
@@ -44,6 +44,7 @@ Welcome to Deep Graph Library Tutorials and Documentation
   api/python/dgl.ops
   api/python/dgl.optim
   api/python/dgl.sampling
   api/python/dgl.multiprocessing
   api/python/udf
.. toctree::
@@ -9,13 +9,13 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from tqdm.auto import tqdm
from numpy import random
from torch.nn.parameter import Parameter
import dgl
import dgl.function as fn
import dgl.multiprocessing as mp
from utils import *
@@ -481,10 +481,7 @@ def train_model(network_data):
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(
                target=thread_wrapped_func(run),
                args=(proc_id, n_gpus, args, devices, data),
            )
            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
@@ -12,39 +12,6 @@ import time
import multiprocessing
from functools import partial, reduce, wraps
import torch.multiprocessing as mp
from _thread import start_new_thread
import traceback
def thread_wrapped_func(func):
    """
    Wraps a process entry point to make it work with OpenMP.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = mp.Queue()
        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))
        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            raise exception.__class__(trace)
    return decorated_function

def parse_args():
    parser = argparse.ArgumentParser()
@@ -14,16 +14,14 @@ import numpy as np
import tqdm
import torch as th
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel
from _thread import start_new_thread
from functools import wraps
from data import MovieLens
from model import GCMCLayer, DenseBiDecoder, BiDecoder
from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger, to_etype_name
import dgl
import dgl.multiprocessing as mp
from dgl.multiprocessing import Queue
class Net(nn.Module):
    def __init__(self, args, dev_id):
@@ -136,33 +134,6 @@ def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
    rmse = np.sqrt(rmse)
    return rmse

# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
def thread_wrapped_func(func):
    """
    Wraps a process entry point to make it work with OpenMP.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = Queue()
        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))
        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            raise exception.__class__(trace)
    return decorated_function
def config():
    parser = argparse.ArgumentParser(description='GCMC')
    parser.add_argument('--seed', default=123, type=int)
@@ -409,7 +380,7 @@ if __name__ == '__main__':
    dataset.train_dec_graph.create_formats_()
    procs = []
    for proc_id in range(n_gpus):
        p = mp.Process(target=thread_wrapped_func(run), args=(proc_id, n_gpus, args, devices, dataset))
        p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, dataset))
        p.start()
        procs.append(p)
    for p in procs:
@@ -4,14 +4,12 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import dgl.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
@@ -4,7 +4,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import dgl.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
@@ -12,8 +12,6 @@ import argparse
import tqdm
import traceback
import math
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
@@ -148,34 +146,6 @@ class NeighborSampler(object):
            hist_blocks.insert(0, hist_block)
        return blocks, hist_blocks

# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
#
# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
# to standardize worker process creation since our operators are implemented with
# OpenMP.
def thread_wrapped_func(func):
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = mp.Queue()
        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))
        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            raise exception.__class__(trace)
    return decorated_function
def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.
@@ -245,7 +215,6 @@ def update_history(g, blocks):
        h_new = block.dstdata['h_new'].cpu()
        g.ndata[hist_col][ids] = h_new

@thread_wrapped_func
def run(proc_id, n_gpus, args, devices, data):
    dropout = 0.2
@@ -4,7 +4,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import dgl.multiprocessing as mp
import dgl.nn.pytorch as dglnn
import time
import math
@@ -13,7 +13,6 @@ from torch.nn.parallel import DistributedDataParallel
import tqdm
from model import SAGE
from utils import thread_wrapped_func
from load_graph import load_reddit, inductive_split
def compute_acc(pred, labels):
@@ -217,8 +216,7 @@ if __name__ == '__main__':
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=thread_wrapped_func(run),
                           args=(proc_id, n_gpus, args, devices, data))
            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
@@ -4,7 +4,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import dgl.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
@@ -13,7 +13,6 @@ from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm
from utils import thread_wrapped_func
from model import SAGE, compute_acc_unsupervised as compute_acc
from negative_sampler import NegativeSampler
@@ -191,8 +190,7 @@ def main(args, devices):
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=thread_wrapped_func(run),
                           args=(proc_id, n_gpus, args, devices, data))
            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
@@ -10,8 +10,6 @@ import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
import tqdm
import traceback
import torch
import argparse
import dgl
import torch.multiprocessing as mp
import dgl.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
@@ -10,7 +10,7 @@ import numpy as np
from reading_data import DeepwalkDataset
from model import SkipGramModel
from utils import thread_wrapped_func, shuffle_walks, sum_up_params
from utils import shuffle_walks, sum_up_params
class DeepwalkTrainer:
    def __init__(self, args):
@@ -110,7 +110,6 @@ class DeepwalkTrainer:
        else:
            self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)

    @thread_wrapped_func
    def fast_train_sp(self, rank, gpu_id):
        """ a subprocess for fast_train_mp """
        if self.args.mix:
@@ -4,10 +4,9 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
import dgl.multiprocessing as mp
from dgl.multiprocessing import Queue
from utils import thread_wrapped_func
def init_emb2pos_index(walk_length, window_size, batch_size):
    ''' select embedding of positive nodes from a batch of node embeddings
@@ -110,7 +109,6 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu):
    return grad

@thread_wrapped_func
def async_update(num_threads, model, queue):
    """ asynchronous embedding update """
    torch.set_num_threads(num_threads)
import torch
import traceback
from functools import wraps
from _thread import start_new_thread
import torch.multiprocessing as mp

def thread_wrapped_func(func):
    """Wrapped func for torch.multiprocessing.Process.

    With this wrapper we can use OMP threads in subprocesses;
    otherwise, OMP_NUM_THREADS=1 is mandatory.

    How to use:
    @thread_wrapped_func
    def func_to_wrap(args ...):
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = mp.Queue()
        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))
        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            raise exception.__class__(trace)
    return decorated_function

def shuffle_walks(walks):
    seeds = torch.randperm(walks.size()[0])
import torch
import argparse
import dgl
import torch.multiprocessing as mp
import dgl.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
@@ -10,7 +10,7 @@ import numpy as np
from reading_data import LineDataset
from model import SkipGramModel
from utils import thread_wrapped_func, sum_up_params, check_args
from utils import sum_up_params, check_args
class LineTrainer:
    def __init__(self, args):
@@ -102,7 +102,6 @@ class LineTrainer:
        else:
            self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)

    @thread_wrapped_func
    def fast_train_sp(self, rank, gpu_id):
        """ a subprocess for fast_train_mp """
        if self.args.mix:
@@ -4,10 +4,8 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from utils import thread_wrapped_func
import dgl.multiprocessing as mp
from dgl.multiprocessing import Queue
def init_emb2neg_index(negative, batch_size):
    '''select embedding of negative nodes from a batch of node embeddings
@@ -44,7 +42,6 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu):
    return grad

@thread_wrapped_func
def async_update(num_threads, model, queue):
    """ Asynchronous embedding update for entity embeddings.
    """
import torch
import traceback
from functools import wraps
from _thread import start_new_thread
import torch.multiprocessing as mp

def thread_wrapped_func(func):
    """Wrapped func for torch.multiprocessing.Process.

    With this wrapper we can use OMP threads in subprocesses;
    otherwise, OMP_NUM_THREADS=1 is mandatory.

    How to use:
    @thread_wrapped_func
    def func_to_wrap(args ...):
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = mp.Queue()
        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))
        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            raise exception.__class__(trace)
    return decorated_function

def check_args(args):
    flag = sum([args.only_1st, args.only_2nd])
@@ -11,11 +11,10 @@ import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import argparse
import torch.multiprocessing as mp
import dgl.multiprocessing as mp
import sys
from torch.nn.parallel import DistributedDataParallel
from collections import OrderedDict
from utils import thread_wrapped_func
class RGAT(nn.Module):
@@ -285,7 +284,7 @@ if __name__ == '__main__':
    procs = []
    for proc_id in range(n_gpus):
        p = mp.Process(target=thread_wrapped_func(train), args=(proc_id, n_gpus, args, dataset, g, feats, paper_offset))
        p = mp.Process(target=train, args=(proc_id, n_gpus, args, dataset, g, feats, paper_offset))
        p.start()
        procs.append(p)
#### Miscellaneous functions
# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
#
# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
# to standardize worker process creation since our operators are implemented with
# OpenMP.
import torch.multiprocessing as mp
from _thread import start_new_thread
from functools import wraps
import traceback
def thread_wrapped_func(func):
    """
    Wraps a process entry point to make it work with OpenMP.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = mp.Queue()
        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))
        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            raise exception.__class__(trace)
    return decorated_function
\ No newline at end of file
@@ -12,8 +12,8 @@ import time
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
import dgl.multiprocessing as mp
from dgl.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dgl
@@ -23,7 +23,6 @@ from functools import partial
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from model import RelGraphEmbedLayer
from dgl.nn import RelGraphConv
from utils import thread_wrapped_func
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset
@@ -195,7 +194,6 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
    return eval_logits, eval_seeds

@thread_wrapped_func
def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
    dev_id = devices[proc_id] if devices[proc_id] != 'cpu' else -1
    g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, \