"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "352ca3198cb25e6098f795568547075ff28e3133"
Unverified commit 62e23bd5, authored by Quan (Andy) Gan, committed by GitHub.

[Bugs] Fix distributed example error and import error (#3783)

* fix

* raise an error

* fix docserver crash
parent 7ab10348
@@ -356,17 +356,13 @@ IdArray VecToIdArray(const std::vector<T>& vec,
 }
 /*!
- * \brief Get the context of the first non-null array, and check if the non-null arrays'
+ * \brief Get the context of the first array, and check if the non-null arrays'
  * contexts are the same.
- *
- * Throws an error if all the arrays are null arrays.
  */
 inline DLContext GetContextOf(const std::vector<IdArray>& arrays) {
   bool first = true;
   DLContext result;
   for (auto& array : arrays) {
-    if (IsNullArray(array))
-      continue;
     if (first) {
       first = false;
       result = array->ctx;
@@ -374,7 +370,6 @@ inline DLContext GetContextOf(const std::vector<IdArray>& arrays) {
       CHECK_EQ(array->ctx, result) << "Context of the input arrays are different";
     }
   }
-  CHECK(!first) << "All input arrays are empty.";
   return result;
 }
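For reference, the helper above now takes the context of the first array it sees and requires every following array to match it; the special case that skipped null arrays (and the check that at least one array was non-null) is removed, since the Python caller below now guarantees non-null inputs on a single device. A rough Python analogue of that invariant, purely as an illustration (not DGL code; `device` stands in for the C++ `DLContext`):

```python
# Illustration only: a Python analogue of the C++ GetContextOf above.
from collections import namedtuple

Arr = namedtuple("Arr", "device data")

def get_context_of(arrays):
    ctx = None
    for arr in arrays:
        if ctx is None:
            ctx = arr.device            # context of the first array
        elif arr.device != ctx:         # every later array must match it
            raise ValueError("Context of the input arrays are different")
    return ctx

print(get_context_of([Arr("cuda:0", [1, 2]), Arr("cuda:0", [])]))  # cuda:0
```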
@@ -49,7 +49,7 @@ from .random import *
 from .data.utils import save_graphs, load_graphs
 from . import optim
 from .frame import LazyFeature
-from .utils import recursive_apply
+from .utils import apply_each
 from ._deprecate.graph import DGLGraph as DGLGraphStale
 from ._deprecate.nodeflow import *
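This hunk fixes the import error from the commit title: the top-level package now re-exports `apply_each` instead of `recursive_apply`. As a hedged sketch of what a helper of this kind is typically used for — applying one function to every value of a per-type dict, or to a single object directly — here is an illustrative re-implementation (not DGL's actual code; the name `apply_each_sketch` is made up):

```python
# Illustrative re-implementation only; not DGL's apply_each.
def apply_each_sketch(data, fn, *args, **kwargs):
    """Apply `fn` to every value of a dict, or to `data` itself otherwise."""
    if isinstance(data, dict):
        return {key: fn(value, *args, **kwargs) for key, value in data.items()}
    return fn(data, *args, **kwargs)

# Example: the same call works for a per-node-type dict and for a single list.
feats = {'user': [0.1, 0.2, 0.3], 'item': [0.4]}
print(apply_each_sketch(feats, len))        # {'user': 3, 'item': 1}
print(apply_each_sketch([0.5, 0.6], len))   # 2
```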
@@ -309,12 +309,17 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
         nodes = {g.ntypes[0] : nodes}
     nodes = utils.prepare_tensor_dict(g, nodes, 'nodes')
+    if len(nodes) == 0:
+        raise ValueError(
+            "Got an empty dictionary in the nodes argument. "
+            "Please pass in a dictionary with empty tensors as values instead.")
+    ctx = utils.to_dgl_context(F.context(next(iter(nodes.values()))))
     nodes_all_types = []
     for ntype in g.ntypes:
         if ntype in nodes:
             nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
         else:
-            nodes_all_types.append(nd.array([], ctx=nd.cpu()))
+            nodes_all_types.append(nd.array([], ctx=ctx))
     if isinstance(fanout, nd.NDArray):
         fanout_array = fanout
@@ -354,7 +359,7 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
         if etype in exclude_edges:
             excluded_edges_all_t.append(F.to_dgl_nd(exclude_edges[etype]))
         else:
-            excluded_edges_all_t.append(nd.array([], ctx=nd.cpu()))
+            excluded_edges_all_t.append(nd.array([], ctx=ctx))
     subgidx = _CAPI_DGLSampleNeighbors(g._graph, nodes_all_types, fanout_array,
                                        edge_dir, prob_arrays, excluded_edges_all_t, replace)
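The sampling change does two things: an empty `nodes` dict is now rejected up front with a clear `ValueError`, and the placeholder empty arrays for node and edge types absent from the input are allocated on the same context as the seed nodes rather than always on CPU, so the contexts seen by the C++ side stay consistent. A minimal standalone sketch of the new validation behaviour (assuming PyTorch for the tensors; this is not the real `_sample_neighbors`):

```python
# Standalone sketch of the validation added in the diff above; not DGL code.
import torch

def check_seed_nodes(nodes):
    """Reject an empty dict, mirroring the error message introduced above."""
    if isinstance(nodes, dict) and len(nodes) == 0:
        raise ValueError(
            "Got an empty dictionary in the nodes argument. "
            "Please pass in a dictionary with empty tensors as values instead.")
    return nodes

# A dict with empty tensors per type is fine; an empty dict is not.
check_seed_nodes({'user': torch.tensor([], dtype=torch.int64)})
try:
    check_seed_nodes({})
except ValueError as err:
    print(err)
```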
@@ -227,17 +227,18 @@ def main(rank, world_size, dataset, seed=0):
 ###############################################################################
 # Finally we load the dataset and launch the processes.
 #
-if __name__ == '__main__':
-    import torch.multiprocessing as mp
-
-    from dgl.data import GINDataset
-
-    num_gpus = 4
-    procs = []
-    dataset = GINDataset(name='IMDBBINARY', self_loop=False)
-    mp.spawn(main, args=(num_gpus, dataset), nprocs=num_gpus)
+# .. code:: python
+#
+#    if __name__ == '__main__':
+#        import torch.multiprocessing as mp
+#
+#        from dgl.data import GINDataset
+#
+#        num_gpus = 4
+#        procs = []
+#        dataset = GINDataset(name='IMDBBINARY', self_loop=False)
+#        mp.spawn(main, args=(num_gpus, dataset), nprocs=num_gpus)
 
 # Thumbnail credits: DGL
 # sphinx_gallery_thumbnail_path = '_static/blitz_5_graph_classification.png'
@@ -245,12 +245,13 @@ graph.create_formats_()
 # Python’s built-in ``multiprocessing`` except that it handles the
 # subtleties between forking and multithreading in Python.
 #
-# Say you have four GPUs.
-if __name__ == '__main__':
-    num_gpus = 4
-    import torch.multiprocessing as mp
-    mp.spawn(run, args=(list(range(num_gpus)),), nprocs=num_gpus)
+# .. code:: python
+#
+#    # Say you have four GPUs.
+#    if __name__ == '__main__':
+#        num_gpus = 4
+#        import torch.multiprocessing as mp
+#        mp.spawn(run, args=(list(range(num_gpus)),), nprocs=num_gpus)
 
 # Thumbnail credits: Stanford CS224W Notes
 # sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'
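Both tutorial hunks turn the live multiprocessing launch into a commented `.. code:: python` block, so building the docs (which executes the tutorial scripts) no longer spawns worker processes — the docserver crash mentioned in the commit message. The pattern the snippets demonstrate is the standard guard for spawn-based multiprocessing; a minimal runnable sketch with a placeholder worker:

```python
# Minimal sketch of the launch pattern shown in the commented-out snippets.
# torch.multiprocessing re-imports the main module in each child process, so
# the spawn call must sit under the __main__ guard. The worker is a placeholder.
import torch.multiprocessing as mp

def run(rank, world_size):
    print(f"worker {rank} of {world_size} started")

if __name__ == '__main__':
    world_size = 2  # e.g. one process per GPU
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```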