Unverified Commit caa6d607 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Replacing thread_wrapped_func with minimal mp.Process wrapper (#2905)

* standardizing thread_wrapped_func

* lints

* Update __init__.py
parent a90296aa
...@@ -4,10 +4,6 @@ Most code is adapted from authors' implementation of RGCN link prediction: ...@@ -4,10 +4,6 @@ Most code is adapted from authors' implementation of RGCN link prediction:
https://github.com/MichSchli/RelationPrediction https://github.com/MichSchli/RelationPrediction
""" """
import traceback
from _thread import start_new_thread
from functools import wraps
import numpy as np import numpy as np
import torch import torch
from torch.multiprocessing import Queue from torch.multiprocessing import Queue
...@@ -338,37 +334,3 @@ def calc_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=[ ...@@ -338,37 +334,3 @@ def calc_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=[
else: else:
mrr = calc_raw_mrr(embedding, w, test_triplets, hits, eval_bz) mrr = calc_raw_mrr(embedding, w, test_triplets, hits, eval_bz)
return mrr return mrr
#######################################################################
#
# Multithread wrapper
#
#######################################################################
# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
def thread_wrapped_func(func):
"""
Wraps a process entry point to make it work with OpenMP.
"""
@wraps(func)
def decorated_function(*args, **kwargs):
queue = Queue()
def _queue_result():
exception, trace, res = None, None, None
try:
res = func(*args, **kwargs)
except Exception as e:
exception = e
trace = traceback.format_exc()
queue.put((res, exception, trace))
start_new_thread(_queue_result, ())
result, exception, trace = queue.get()
if exception is None:
return result
else:
assert isinstance(exception, Exception)
raise exception.__class__(trace)
return decorated_function
"""Wrapper of the multiprocessing module for multi-GPU training."""
# To avoid duplicating the graph structure for node classification or link prediction
# training we recommend using fork() rather than spawn() for multiple GPU training.
# However, we need to work around https://github.com/pytorch/pytorch/issues/17199 to
# make fork() and openmp work together.
from .. import backend as F
if F.get_preferred_backend() == 'pytorch':
# Wrap around torch.multiprocessing...
from torch.multiprocessing import *
# ... and override the Process initializer
from .pytorch import Process
else:
# Just import multiprocessing module.
from multiprocessing import * # pylint: disable=redefined-builtin
#### Miscellaneous functions """PyTorch multiprocessing wrapper."""
# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
#
# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
# to standardize worker process creation since our operators are implemented with
# OpenMP.
import torch.multiprocessing as mp
from _thread import start_new_thread
from functools import wraps from functools import wraps
import traceback import traceback
from _thread import start_new_thread
import torch.multiprocessing as mp
def thread_wrapped_func(func): def thread_wrapped_func(func):
""" """
...@@ -23,7 +15,7 @@ def thread_wrapped_func(func): ...@@ -23,7 +15,7 @@ def thread_wrapped_func(func):
exception, trace, res = None, None, None exception, trace, res = None, None, None
try: try:
res = func(*args, **kwargs) res = func(*args, **kwargs)
except Exception as e: except Exception as e: # pylint: disable=broad-except
exception = e exception = e
trace = traceback.format_exc() trace = traceback.format_exc()
queue.put((res, exception, trace)) queue.put((res, exception, trace))
...@@ -35,4 +27,11 @@ def thread_wrapped_func(func): ...@@ -35,4 +27,11 @@ def thread_wrapped_func(func):
else: else:
assert isinstance(exception, Exception) assert isinstance(exception, Exception)
raise exception.__class__(trace) raise exception.__class__(trace)
return decorated_function return decorated_function
\ No newline at end of file
# pylint: disable=missing-docstring
class Process(mp.Process):
# pylint: disable=dangerous-default-value
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None):
target = thread_wrapped_func(target)
super().__init__(group, target, name, args, kwargs, daemon=daemon)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment