Unverified commit caa6d607 authored by Quan (Andy) Gan, committed by GitHub

[Feature] Replacing thread_wrapped_func with minimal mp.Process wrapper (#2905)

* standardizing thread_wrapped_func

* lints

* Update __init__.py
parent a90296aa
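Across the touched examples the call-site change is mechanical: swap the import and drop the decorator. A minimal sketch of the pattern (the names `run`, `proc_id`, `n_gpus`, `args`, `devices`, and `data` come from the training scripts diffed below):

# Before: every process entry point had to be wrapped by hand.
import torch.multiprocessing as mp
p = mp.Process(target=thread_wrapped_func(run),
               args=(proc_id, n_gpus, args, devices, data))

# After: dgl.multiprocessing.Process applies the OpenMP workaround internally.
import dgl.multiprocessing as mp
p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))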
@@ -15,39 +15,6 @@ from ogb.nodeproppred import DglNodePropPredDataset
from functools import partial, reduce, wraps
-import torch.multiprocessing as mp
-from _thread import start_new_thread
-import traceback
-def thread_wrapped_func(func):
-    """
-    Wraps a process entry point to make it work with OpenMP.
-    """
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
def _download(url, path, filename):
    fn = os.path.join(path, filename)
@@ -520,7 +487,7 @@ def benchmark(track_type, timeout=60):
        if not filter.check(func):
            # skip if not enabled
            func.benchmark_name = "skip_" + func.__name__
-        return thread_wrapped_func(func)
+        return func
    return _wrapper
#####################################
......
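For context on the block deleted above: calling fork() in a process that has already initialized OpenMP can deadlock (see https://github.com/pytorch/pytorch/issues/17199), and thread_wrapped_func dodged this by running the entry point on a freshly spawned native thread, which does not inherit the forked OpenMP state. The commit presumably folds the same trick into a Process subclass; a minimal sketch of such a wrapper, assuming this design rather than DGL's actual implementation:

import queue
import traceback
from _thread import start_new_thread

import torch.multiprocessing as mp

class Process(mp.Process):
    # Hypothetical sketch: run the entry point on a fresh native thread
    # inside the child so the forked OpenMP state is never touched.
    def run(self):
        q = queue.Queue()  # in-process handoff; the old decorator used mp.Queue
        def _target():
            try:
                super(Process, self).run()  # invokes self._target(*args, **kwargs)
                q.put(None)
            except Exception:
                q.put(traceback.format_exc())
        start_new_thread(_target, ())
        trace = q.get()
        if trace is not None:
            raise RuntimeError(trace)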
+.. _apimultiprocessing:
+
+dgl.multiprocessing
+===================
+
+This is a minimal wrapper of Python's native :mod:`multiprocessing` module.
+It modifies the :class:`multiprocessing.Process` class to make forking
+work with OpenMP in the DGL core library.
+The API usage is exactly the same as the native module, so DGL does not provide
+additional documentation.
+In addition, if your backend is PyTorch, this module will also be compatible with
+the :mod:`torch.multiprocessing` module.
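As the added docs note, usage is a drop-in for the native module; a minimal self-contained sketch (the worker body is hypothetical):

import dgl.multiprocessing as mp  # drop-in replacement for torch.multiprocessing

def run(proc_id, n_procs):
    # Worker entry point: safe to call OpenMP-backed DGL ops here without
    # OMP_NUM_THREADS=1 or a thread_wrapped_func decorator.
    print('worker %d of %d started' % (proc_id, n_procs))

if __name__ == '__main__':
    procs = [mp.Process(target=run, args=(i, 4)) for i in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()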
@@ -44,6 +44,7 @@ Welcome to Deep Graph Library Tutorials and Documentation
   api/python/dgl.ops
   api/python/dgl.optim
   api/python/dgl.sampling
+   api/python/dgl.multiprocessing
   api/python/udf

.. toctree::
......
@@ -9,13 +9,13 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from tqdm.auto import tqdm
from numpy import random
from torch.nn.parameter import Parameter
import dgl
import dgl.function as fn
+import dgl.multiprocessing as mp
from utils import *
@@ -481,10 +481,7 @@ def train_model(network_data):
    else:
        procs = []
        for proc_id in range(n_gpus):
-            p = mp.Process(
-                target=thread_wrapped_func(run),
-                args=(proc_id, n_gpus, args, devices, data),
-            )
+            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
......
@@ -12,39 +12,6 @@ import time
import multiprocessing
from functools import partial, reduce, wraps
-import torch.multiprocessing as mp
-from _thread import start_new_thread
-import traceback
-def thread_wrapped_func(func):
-    """
-    Wraps a process entry point to make it work with OpenMP.
-    """
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
def parse_args():
    parser = argparse.ArgumentParser()
......
@@ -14,16 +14,14 @@ import numpy as np
import tqdm
import torch as th
import torch.nn as nn
-import torch.multiprocessing as mp
from torch.utils.data import DataLoader
-from torch.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel
-from _thread import start_new_thread
-from functools import wraps
from data import MovieLens
from model import GCMCLayer, DenseBiDecoder, BiDecoder
from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger, to_etype_name
import dgl
+import dgl.multiprocessing as mp
+from dgl.multiprocessing import Queue

class Net(nn.Module):
    def __init__(self, args, dev_id):
@@ -136,33 +134,6 @@ def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
    rmse = np.sqrt(rmse)
    return rmse
-# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
-# is necessary to make fork() and openmp work together.
-def thread_wrapped_func(func):
-    """
-    Wraps a process entry point to make it work with OpenMP.
-    """
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
def config():
    parser = argparse.ArgumentParser(description='GCMC')
    parser.add_argument('--seed', default=123, type=int)
@@ -409,7 +380,7 @@ if __name__ == '__main__':
    dataset.train_dec_graph.create_formats_()
    procs = []
    for proc_id in range(n_gpus):
-        p = mp.Process(target=thread_wrapped_func(run), args=(proc_id, n_gpus, args, devices, dataset))
+        p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, dataset))
        p.start()
        procs.append(p)
    for p in procs:
......
@@ -4,14 +4,12 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm
-from _thread import start_new_thread
-from functools import wraps
from dgl.data import RedditDataset
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
......
@@ -4,7 +4,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
@@ -12,8 +12,6 @@ import argparse
import tqdm
import traceback
import math
-from _thread import start_new_thread
-from functools import wraps
from dgl.data import RedditDataset
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
@@ -148,34 +146,6 @@ class NeighborSampler(object):
        hist_blocks.insert(0, hist_block)
        return blocks, hist_blocks
-# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
-# is necessary to make fork() and openmp work together.
-#
-# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
-# to standardize worker process creation since our operators are implemented with
-# OpenMP.
-def thread_wrapped_func(func):
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.
@@ -245,7 +215,6 @@ def update_history(g, blocks):
        h_new = block.dstdata['h_new'].cpu()
        g.ndata[hist_col][ids] = h_new
-@thread_wrapped_func
def run(proc_id, n_gpus, args, devices, data):
    dropout = 0.2
......
@@ -4,7 +4,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
import dgl.nn.pytorch as dglnn
import time
import math
@@ -13,7 +13,6 @@ from torch.nn.parallel import DistributedDataParallel
import tqdm
from model import SAGE
-from utils import thread_wrapped_func
from load_graph import load_reddit, inductive_split

def compute_acc(pred, labels):
@@ -217,8 +216,7 @@ if __name__ == '__main__':
    else:
        procs = []
        for proc_id in range(n_gpus):
-            p = mp.Process(target=thread_wrapped_func(run),
-                           args=(proc_id, n_gpus, args, devices, data))
+            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
......
@@ -4,7 +4,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
@@ -13,7 +13,6 @@ from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm
-from utils import thread_wrapped_func
from model import SAGE, compute_acc_unsupervised as compute_acc
from negative_sampler import NegativeSampler
@@ -191,8 +190,7 @@ def main(args, devices):
    else:
        procs = []
        for proc_id in range(n_gpus):
-            p = mp.Process(target=thread_wrapped_func(run),
-                           args=(proc_id, n_gpus, args, devices, data))
+            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
......
@@ -10,8 +10,6 @@ import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
-from _thread import start_new_thread
-from functools import wraps
from dgl.data import RedditDataset
import tqdm
import traceback
......
import torch
import argparse
import dgl
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
@@ -10,7 +10,7 @@ import numpy as np
from reading_data import DeepwalkDataset
from model import SkipGramModel
-from utils import thread_wrapped_func, shuffle_walks, sum_up_params
+from utils import shuffle_walks, sum_up_params

class DeepwalkTrainer:
    def __init__(self, args):
@@ -110,7 +110,6 @@ class DeepwalkTrainer:
        else:
            self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
-    @thread_wrapped_func
    def fast_train_sp(self, rank, gpu_id):
        """ a subprocess for fast_train_mp """
        if self.args.mix:
......
@@ -4,10 +4,9 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
-import torch.multiprocessing as mp
-from torch.multiprocessing import Queue
-from utils import thread_wrapped_func
+import dgl.multiprocessing as mp
+from dgl.multiprocessing import Queue

def init_emb2pos_index(walk_length, window_size, batch_size):
    ''' select embedding of positive nodes from a batch of node embeddings
@@ -110,7 +109,6 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu):
    return grad
-@thread_wrapped_func
def async_update(num_threads, model, queue):
    """ asynchronous embedding update """
    torch.set_num_threads(num_threads)
......
import torch
-from functools import wraps
-from _thread import start_new_thread
-import torch.multiprocessing as mp
-def thread_wrapped_func(func):
-    """Wrapped func for torch.multiprocessing.Process.
-    With this wrapper we can use OMP threads in subprocesses
-    otherwise, OMP_NUM_THREADS=1 is mandatory.
-    How to use:
-    @thread_wrapped_func
-    def func_to_wrap(args ...):
-    """
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
def shuffle_walks(walks):
    seeds = torch.randperm(walks.size()[0])
......
import torch
import argparse
import dgl
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
@@ -10,7 +10,7 @@ import numpy as np
from reading_data import LineDataset
from model import SkipGramModel
-from utils import thread_wrapped_func, sum_up_params, check_args
+from utils import sum_up_params, check_args

class LineTrainer:
    def __init__(self, args):
@@ -102,7 +102,6 @@ class LineTrainer:
        else:
            self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
-    @thread_wrapped_func
    def fast_train_sp(self, rank, gpu_id):
        """ a subprocess for fast_train_mp """
        if self.args.mix:
......
@@ -4,10 +4,8 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
-import torch.multiprocessing as mp
-from torch.multiprocessing import Queue
-from utils import thread_wrapped_func
+import dgl.multiprocessing as mp
+from dgl.multiprocessing import Queue

def init_emb2neg_index(negative, batch_size):
    '''select embedding of negative nodes from a batch of node embeddings
@@ -44,7 +42,6 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu):
    return grad
-@thread_wrapped_func
def async_update(num_threads, model, queue):
    """ Asynchronous embedding update for entity embeddings.
    """
......
import torch
-from functools import wraps
-from _thread import start_new_thread
-import torch.multiprocessing as mp
-def thread_wrapped_func(func):
-    """Wrapped func for torch.multiprocessing.Process.
-    With this wrapper we can use OMP threads in subprocesses
-    otherwise, OMP_NUM_THREADS=1 is mandatory.
-    How to use:
-    @thread_wrapped_func
-    def func_to_wrap(args ...):
-    """
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
def check_args(args):
    flag = sum([args.only_1st, args.only_2nd])
......
@@ -11,11 +11,10 @@ import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import argparse
-import torch.multiprocessing as mp
+import dgl.multiprocessing as mp
import sys
from torch.nn.parallel import DistributedDataParallel
from collections import OrderedDict
-from utils import thread_wrapped_func

class RGAT(nn.Module):
@@ -285,7 +284,7 @@ if __name__ == '__main__':
    procs = []
    for proc_id in range(n_gpus):
-        p = mp.Process(target=thread_wrapped_func(train), args=(proc_id, n_gpus, args, dataset, g, feats, paper_offset))
+        p = mp.Process(target=train, args=(proc_id, n_gpus, args, dataset, g, feats, paper_offset))
        p.start()
        procs.append(p)
......
-#### Miscellaneous functions
-# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
-# is necessary to make fork() and openmp work together.
-#
-# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
-# to standardize worker process creation since our operators are implemented with
-# OpenMP.
-import torch.multiprocessing as mp
-from _thread import start_new_thread
-from functools import wraps
-import traceback
-def thread_wrapped_func(func):
-    """
-    Wraps a process entry point to make it work with OpenMP.
-    """
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
\ No newline at end of file
@@ -12,8 +12,8 @@ import time
import torch as th
import torch.nn as nn
import torch.nn.functional as F
-import torch.multiprocessing as mp
-from torch.multiprocessing import Queue
+import dgl.multiprocessing as mp
+from dgl.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dgl
@@ -23,7 +23,6 @@ from functools import partial
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from model import RelGraphEmbedLayer
from dgl.nn import RelGraphConv
-from utils import thread_wrapped_func
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset
@@ -195,7 +194,6 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
    return eval_logits, eval_seeds
-@thread_wrapped_func
def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
    dev_id = devices[proc_id] if devices[proc_id] != 'cpu' else -1
    g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, \
......