"vscode:/vscode.git/clone" did not exist on "9d35f141e8c931ea4eaa1120a2b5e740b9726ad2"
Unverified Commit e9b624fe authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Merge branch 'master' into dist_part

parents 8086d1ed a88e7f7e
...@@ -44,6 +44,12 @@ class DeviceAPI { ...@@ -44,6 +44,12 @@ class DeviceAPI {
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~DeviceAPI() {} virtual ~DeviceAPI() {}
/*!
* \brief Check whether the device is available.
*/
virtual bool IsAvailable() {
return true;
}
/*! /*!
* \brief Set the environment device id to ctx * \brief Set the environment device id to ctx
* \param ctx The context to be set. * \param ctx The context to be set.
......
...@@ -69,8 +69,12 @@ class TensorDispatcher { ...@@ -69,8 +69,12 @@ class TensorDispatcher {
/*! /*!
* \brief Allocate an empty tensor. * \brief Allocate an empty tensor.
*
* Used in NDArray::Empty(). * Used in NDArray::Empty().
* \param shape The shape
* \param dtype The data type
* \param ctx The device
* \return An empty NDArray.
*/ */
inline NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) const { inline NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) const {
auto entry = entrypoints_[Op::kEmpty]; auto entry = entrypoints_[Op::kEmpty];
...@@ -78,6 +82,36 @@ class TensorDispatcher { ...@@ -78,6 +82,36 @@ class TensorDispatcher {
return NDArray::FromDLPack(result); return NDArray::FromDLPack(result);
} }
#ifdef DGL_USE_CUDA
/*!
* \brief Allocate a piece of GPU memory via
* PyTorch's THCCachingAllocator.
* Used in CUDADeviceAPI::AllocWorkspace().
*
* \note THCCachingAllocator specify the device to allocate on
* via cudaGetDevice(). Make sure to call cudaSetDevice()
* before invoking this function.
*
* \param nbytes The size to be allocated.
* \return Pointer to the allocated memory.
*/
inline void* AllocWorkspace(size_t nbytes) {
auto entry = entrypoints_[Op::kRawAlloc];
return FUNCCAST(tensoradapter::RawAlloc, entry)(nbytes);
}
/*!
* \brief Free the GPU memory.
* Used in CUDADeviceAPI::FreeWorkspace().
*
* \param ptr Pointer to the memory to be freed.
*/
inline void FreeWorkspace(void* ptr) {
auto entry = entrypoints_[Op::kRawDelete];
FUNCCAST(tensoradapter::RawDelete, entry)(ptr);
}
#endif // DGL_USE_CUDA
private: private:
/*! \brief ctor */ /*! \brief ctor */
TensorDispatcher() = default; TensorDispatcher() = default;
...@@ -91,19 +125,33 @@ class TensorDispatcher { ...@@ -91,19 +125,33 @@ class TensorDispatcher {
*/ */
static constexpr const char *names_[] = { static constexpr const char *names_[] = {
"TAempty", "TAempty",
#ifdef DGL_USE_CUDA
"RawAlloc",
"RawDelete",
#endif // DGL_USE_CUDA
}; };
/*! \brief Index of each function to the symbol list */ /*! \brief Index of each function to the symbol list */
class Op { class Op {
public: public:
static constexpr int kEmpty = 0; static constexpr int kEmpty = 0;
#ifdef DGL_USE_CUDA
static constexpr int kRawAlloc = 1;
static constexpr int kRawDelete = 2;
#endif // DGL_USE_CUDA
}; };
/*! \brief Number of functions */ /*! \brief Number of functions */
static constexpr int num_entries_ = sizeof(names_) / sizeof(names_[0]); static constexpr int num_entries_ = sizeof(names_) / sizeof(names_[0]);
/*! \brief Entrypoints of each function */ /*! \brief Entrypoints of each function */
void* entrypoints_[num_entries_] = {nullptr}; void* entrypoints_[num_entries_] = {
nullptr,
#ifdef DGL_USE_CUDA
nullptr,
nullptr,
#endif // DGL_USE_CUDA
};
bool available_ = false; bool available_ = false;
#if defined(WIN32) || defined(_WIN32) #if defined(WIN32) || defined(_WIN32)
......
...@@ -127,7 +127,6 @@ class GSpMM(th.autograd.Function): ...@@ -127,7 +127,6 @@ class GSpMM(th.autograd.Function):
@custom_bwd @custom_bwd
def backward(ctx, dZ): def backward(ctx, dZ):
gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
ctx.backward_cache = None
X, Y, argX, argY = ctx.saved_tensors X, Y, argX, argY = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[3]: if op != 'copy_rhs' and ctx.needs_input_grad[3]:
g_rev = gidx.reverse() g_rev = gidx.reverse()
...@@ -207,7 +206,6 @@ class GSpMM_hetero(th.autograd.Function): ...@@ -207,7 +206,6 @@ class GSpMM_hetero(th.autograd.Function):
@custom_bwd @custom_bwd
def backward(ctx, *dZ): def backward(ctx, *dZ):
gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
ctx.backward_cache = None
num_ntypes = gidx.number_of_ntypes() num_ntypes = gidx.number_of_ntypes()
feats = ctx.saved_tensors[:-(4 * num_ntypes)] feats = ctx.saved_tensors[:-(4 * num_ntypes)]
argX = ctx.saved_tensors[-(4 * num_ntypes):-(3 * num_ntypes)] argX = ctx.saved_tensors[-(4 * num_ntypes):-(3 * num_ntypes)]
...@@ -305,7 +303,6 @@ class GSDDMM(th.autograd.Function): ...@@ -305,7 +303,6 @@ class GSDDMM(th.autograd.Function):
@custom_bwd @custom_bwd
def backward(ctx, dZ): def backward(ctx, dZ):
gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
ctx.backward_cache = None
X, Y = ctx.saved_tensors X, Y = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[2]: if op != 'copy_rhs' and ctx.needs_input_grad[2]:
if lhs_target in ['u', 'v']: if lhs_target in ['u', 'v']:
...@@ -373,7 +370,6 @@ class GSDDMM_hetero(th.autograd.Function): ...@@ -373,7 +370,6 @@ class GSDDMM_hetero(th.autograd.Function):
# TODO(Israt): Implement the complete backward operator # TODO(Israt): Implement the complete backward operator
def backward(ctx, *dZ): def backward(ctx, *dZ):
gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache
ctx.backward_cache = None
feats = ctx.saved_tensors feats = ctx.saved_tensors
X, Y = feats[:X_len], feats[X_len:] X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]): if op != 'copy_rhs' and any([x is not None for x in X]):
...@@ -484,8 +480,6 @@ class EdgeSoftmax(th.autograd.Function): ...@@ -484,8 +480,6 @@ class EdgeSoftmax(th.autograd.Function):
return grad_score.data return grad_score.data
""" """
gidx = ctx.backward_cache gidx = ctx.backward_cache
# See https://github.com/dmlc/dgl/pull/3386
ctx.backward_cache = None
out, = ctx.saved_tensors out, = ctx.saved_tensors
sds = out * grad_out sds = out * grad_out
#Note: Now _edge_softmax_backward op only supports CPU #Note: Now _edge_softmax_backward op only supports CPU
...@@ -554,8 +548,6 @@ class EdgeSoftmax_hetero(th.autograd.Function): ...@@ -554,8 +548,6 @@ class EdgeSoftmax_hetero(th.autograd.Function):
return grad_score.data return grad_score.data
""" """
gidx = ctx.backward_cache gidx = ctx.backward_cache
# See https://github.com/dmlc/dgl/pull/3386
ctx.backward_cache = None
u_len = gidx.number_of_ntypes() u_len = gidx.number_of_ntypes()
e_len = gidx.number_of_etypes() e_len = gidx.number_of_etypes()
lhs = [None] * u_len lhs = [None] * u_len
...@@ -582,8 +574,6 @@ class SegmentReduce(th.autograd.Function): ...@@ -582,8 +574,6 @@ class SegmentReduce(th.autograd.Function):
@custom_bwd @custom_bwd
def backward(ctx, dy): def backward(ctx, dy):
op = ctx.backward_cache op = ctx.backward_cache
# See https://github.com/dmlc/dgl/pull/3386
ctx.backward_cache = None
arg, offsets = ctx.saved_tensors arg, offsets = ctx.saved_tensors
m = offsets[-1].item() m = offsets[-1].item()
if op == 'sum': if op == 'sum':
...@@ -630,7 +620,6 @@ class CSRMM(th.autograd.Function): ...@@ -630,7 +620,6 @@ class CSRMM(th.autograd.Function):
def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful. # Only the last argument is meaningful.
gidxA, gidxB, gidxC = ctx.backward_cache gidxA, gidxB, gidxC = ctx.backward_cache
ctx.backward_cache = None
A_weights, B_weights = ctx.saved_tensors A_weights, B_weights = ctx.saved_tensors
dgidxA, dA_weights = csrmm( dgidxA, dA_weights = csrmm(
gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes()) gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes())
...@@ -657,7 +646,6 @@ class CSRSum(th.autograd.Function): ...@@ -657,7 +646,6 @@ class CSRSum(th.autograd.Function):
def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful. # Only the last argument is meaningful.
gidxs, gidxC = ctx.backward_cache gidxs, gidxC = ctx.backward_cache
ctx.backward_cache = None
return (None,) + tuple(csrmask(gidxC, dC_weights, gidx) for gidx in gidxs) return (None,) + tuple(csrmask(gidxC, dC_weights, gidx) for gidx in gidxs)
...@@ -670,7 +658,6 @@ class CSRMask(th.autograd.Function): ...@@ -670,7 +658,6 @@ class CSRMask(th.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, dB_weights): def backward(ctx, dB_weights):
gidxA, gidxB = ctx.backward_cache gidxA, gidxB = ctx.backward_cache
ctx.backward_cache = None
return None, csrmask(gidxB, dB_weights, gidxA), None return None, csrmask(gidxB, dB_weights, gidxA), None
......
...@@ -418,8 +418,6 @@ class BinaryReduce(th.autograd.Function): ...@@ -418,8 +418,6 @@ class BinaryReduce(th.autograd.Function):
def backward(ctx, grad_out): def backward(ctx, grad_out):
reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \ reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
feat_shape, degs = ctx.backward_cache feat_shape, degs = ctx.backward_cache
# See https://github.com/dmlc/dgl/pull/3386
ctx.backward_cache = None
lhs_data, rhs_data, out_data = ctx.saved_tensors lhs_data, rhs_data, out_data = ctx.saved_tensors
lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data) lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data) rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
...@@ -497,8 +495,6 @@ class CopyReduce(th.autograd.Function): ...@@ -497,8 +495,6 @@ class CopyReduce(th.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
reducer, graph, target, in_map, out_map, degs = ctx.backward_cache reducer, graph, target, in_map, out_map, degs = ctx.backward_cache
# See https://github.com/dmlc/dgl/pull/3386
ctx.backward_cache = None
in_data, out_data = ctx.saved_tensors in_data, out_data = ctx.saved_tensors
in_data_nd = zerocopy_to_dgl_ndarray(in_data) in_data_nd = zerocopy_to_dgl_ndarray(in_data)
out_data_nd = zerocopy_to_dgl_ndarray(out_data) out_data_nd = zerocopy_to_dgl_ndarray(out_data)
......
...@@ -29,6 +29,8 @@ __all__ = [ ...@@ -29,6 +29,8 @@ __all__ = [
'from_networkx', 'from_networkx',
'bipartite_from_networkx', 'bipartite_from_networkx',
'to_networkx', 'to_networkx',
'from_cugraph',
'to_cugraph'
] ]
def graph(data, def graph(data,
...@@ -1620,6 +1622,110 @@ def to_networkx(g, node_attrs=None, edge_attrs=None): ...@@ -1620,6 +1622,110 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
DGLHeteroGraph.to_networkx = to_networkx DGLHeteroGraph.to_networkx = to_networkx
def to_cugraph(g):
"""Convert a DGL graph to a :class:`cugraph.Graph` and return.
Parameters
----------
g : DGLGraph
A homogeneous graph.
Returns
-------
cugraph.Graph
The converted cugraph graph.
Notes
-----
The function only supports GPU graph input.
Examples
--------
The following example uses PyTorch backend.
>>> import dgl
>>> import cugraph
>>> import torch
>>> g = dgl.graph((torch.tensor([1, 2]), torch.tensor([1, 3]))).to('cuda')
>>> cugraph_g = g.to_cugraph()
>>> cugraph_g.edges()
src dst
0 2 3
1 1 1
"""
if g.device.type != 'cuda':
raise DGLError(f"Cannot convert a {g.device.type} graph to cugraph." +
"Call g.to('cuda') first.")
if not g.is_homogeneous:
raise DGLError("dgl.to_cugraph only supports homogeneous graphs.")
try:
import cugraph
import cudf
except ModuleNotFoundError:
raise ModuleNotFoundError("to_cugraph requires cugraph which could not be imported")
edgelist = g.edges()
src_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(edgelist[0]))
dst_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(edgelist[1]))
cudf_data = cudf.DataFrame({'source':src_ser, 'destination':dst_ser})
g_cugraph = cugraph.Graph(directed=True)
g_cugraph.from_cudf_edgelist(cudf_data,
source='source',
destination='destination')
return g_cugraph
DGLHeteroGraph.to_cugraph = to_cugraph
def from_cugraph(cugraph_graph):
"""Create a graph from a :class:`cugraph.Graph` object.
Parameters
----------
cugraph_graph : cugraph.Graph
The cugraph graph object holding the graph structure. Node and edge attributes are
dropped.
If the input graph is undirected, DGL converts it to a directed graph
by :func:`cugraph.Graph.to_directed`.
Returns
-------
DGLGraph
The created graph.
Examples
--------
The following example uses PyTorch backend.
>>> import dgl
>>> import cugraph
>>> import cudf
Create a cugraph graph.
>>> cugraph_g = cugraph.Graph(directed=True)
>>> df = cudf.DataFrame({"source":[0, 1, 2, 3],
"destination":[1, 2, 3, 0]})
>>> cugraph_g.from_cudf_edgelist(df)
Convert it into a DGLGraph
>>> g = dgl.from_cugraph(cugraph_g)
>>> g.edges()
(tensor([1, 2, 3, 0], device='cuda:0'), tensor([2, 3, 0, 1], device='cuda:0'))
"""
if not cugraph_graph.is_directed():
cugraph_graph = cugraph_graph.to_directed()
edges = cugraph_graph.edges()
src_t = F.zerocopy_from_dlpack(edges['src'].to_dlpack())
dst_t = F.zerocopy_from_dlpack(edges['dst'].to_dlpack())
g = graph((src_t,dst_t))
return g
############################################################ ############################################################
# Internal APIs # Internal APIs
############################################################ ############################################################
......
...@@ -439,7 +439,7 @@ class CoraGraphDataset(CitationGraphDataset): ...@@ -439,7 +439,7 @@ class CoraGraphDataset(CitationGraphDataset):
graph structure, node features and labels. graph structure, node features and labels.
- ``ndata['train_mask']`` mask for training node set - ``ndata['train_mask']``: mask for training node set
- ``ndata['val_mask']``: mask for validation node set - ``ndata['val_mask']``: mask for validation node set
- ``ndata['test_mask']``: mask for test node set - ``ndata['test_mask']``: mask for test node set
- ``ndata['feat']``: node feature - ``ndata['feat']``: node feature
...@@ -590,7 +590,7 @@ class CiteseerGraphDataset(CitationGraphDataset): ...@@ -590,7 +590,7 @@ class CiteseerGraphDataset(CitationGraphDataset):
graph structure, node features and labels. graph structure, node features and labels.
- ``ndata['train_mask']`` mask for training node set - ``ndata['train_mask']``: mask for training node set
- ``ndata['val_mask']``: mask for validation node set - ``ndata['val_mask']``: mask for validation node set
- ``ndata['test_mask']``: mask for test node set - ``ndata['test_mask']``: mask for test node set
- ``ndata['feat']``: node feature - ``ndata['feat']``: node feature
...@@ -738,7 +738,7 @@ class PubmedGraphDataset(CitationGraphDataset): ...@@ -738,7 +738,7 @@ class PubmedGraphDataset(CitationGraphDataset):
graph structure, node features and labels. graph structure, node features and labels.
- ``ndata['train_mask']`` mask for training node set - ``ndata['train_mask']``: mask for training node set
- ``ndata['val_mask']``: mask for validation node set - ``ndata['val_mask']``: mask for validation node set
- ``ndata['test_mask']``: mask for test node set - ``ndata['test_mask']``: mask for test node set
- ``ndata['feat']``: node feature - ``ndata['feat']``: node feature
......
...@@ -8,7 +8,6 @@ import traceback ...@@ -8,7 +8,6 @@ import traceback
import abc import abc
from .utils import download, extract_archive, get_download_dir, makedirs from .utils import download, extract_archive, get_download_dir, makedirs
from ..utils import retry_method_with_fix from ..utils import retry_method_with_fix
from .._ffi.base import __version__
class DGLDataset(object): class DGLDataset(object):
r"""The basic DGL dataset for creating graph datasets. r"""The basic DGL dataset for creating graph datasets.
...@@ -238,17 +237,13 @@ class DGLDataset(object): ...@@ -238,17 +237,13 @@ class DGLDataset(object):
def save_dir(self): def save_dir(self):
r"""Directory to save the processed dataset. r"""Directory to save the processed dataset.
""" """
return self._save_dir + "_v{}".format(__version__) return self._save_dir
@property @property
def save_path(self): def save_path(self):
r"""Path to save the processed dataset. r"""Path to save the processed dataset.
""" """
if hasattr(self, '_reorder'): return os.path.join(self._save_dir, self.name)
path = 'reordered' if self._reorder else 'un_reordered'
return os.path.join(self._save_dir, self.name, path)
else:
return os.path.join(self._save_dir, self.name)
@property @property
def verbose(self): def verbose(self):
......
...@@ -50,6 +50,7 @@ class FlickrDataset(DGLBuiltinDataset): ...@@ -50,6 +50,7 @@ class FlickrDataset(DGLBuiltinDataset):
Examples Examples
-------- --------
>>> from dgl.data import FlickrDataset
>>> dataset = FlickrDataset() >>> dataset = FlickrDataset()
>>> dataset.num_classes >>> dataset.num_classes
7 7
...@@ -151,9 +152,9 @@ class FlickrDataset(DGLBuiltinDataset): ...@@ -151,9 +152,9 @@ class FlickrDataset(DGLBuiltinDataset):
- ``ndata['label']``: node label - ``ndata['label']``: node label
- ``ndata['feat']``: node feature - ``ndata['feat']``: node feature
- ``ndata['train_mask']`` mask for training node set - ``ndata['train_mask']``: mask for training node set
- ``ndata['val_mask']``: mask for validation node set - ``ndata['val_mask']``: mask for validation node set
- ``ndata['test_mask']:`` mask for test node set - ``ndata['test_mask']``: mask for test node set
""" """
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
......
...@@ -17,7 +17,6 @@ from .graph_serialize import save_graphs, load_graphs, load_labels ...@@ -17,7 +17,6 @@ from .graph_serialize import save_graphs, load_graphs, load_labels
from .tensor_serialize import save_tensors, load_tensors from .tensor_serialize import save_tensors, load_tensors
from .. import backend as F from .. import backend as F
from .._ffi.base import __version__
__all__ = ['loadtxt','download', 'check_sha1', 'extract_archive', __all__ = ['loadtxt','download', 'check_sha1', 'extract_archive',
'get_download_dir', 'Subset', 'split_dataset', 'save_graphs', 'get_download_dir', 'Subset', 'split_dataset', 'save_graphs',
...@@ -241,7 +240,7 @@ def get_download_dir(): ...@@ -241,7 +240,7 @@ def get_download_dir():
dirname : str dirname : str
Path to the download directory Path to the download directory
""" """
default_dir = os.path.join(os.path.expanduser('~'), '.dgl_v{}'.format(__version__)) default_dir = os.path.join(os.path.expanduser('~'), '.dgl')
dirname = os.environ.get('DGL_DOWNLOAD_DIR', default_dir) dirname = os.environ.get('DGL_DOWNLOAD_DIR', default_dir)
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
......
...@@ -50,6 +50,7 @@ class WikiCSDataset(DGLBuiltinDataset): ...@@ -50,6 +50,7 @@ class WikiCSDataset(DGLBuiltinDataset):
Examples Examples
-------- --------
>>> from dgl.data import WikiCSDataset
>>> dataset = WikiCSDataset() >>> dataset = WikiCSDataset()
>>> dataset.num_classes >>> dataset.num_classes
10 10
......
...@@ -151,9 +151,9 @@ class YelpDataset(DGLBuiltinDataset): ...@@ -151,9 +151,9 @@ class YelpDataset(DGLBuiltinDataset):
- ``ndata['label']``: node label - ``ndata['label']``: node label
- ``ndata['feat']``: node feature - ``ndata['feat']``: node feature
- ``ndata['train_mask']`` mask for training node set - ``ndata['train_mask']``: mask for training node set
- ``ndata['val_mask']``: mask for validation node set - ``ndata['val_mask']``: mask for validation node set
- ``ndata['test_mask']:`` mask for test node set - ``ndata['test_mask']``: mask for test node set
""" """
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
......
...@@ -28,7 +28,6 @@ from .base import BlockSampler, as_edge_prediction_sampler ...@@ -28,7 +28,6 @@ from .base import BlockSampler, as_edge_prediction_sampler
from .. import backend as F from .. import backend as F
from ..distributed import DistGraph from ..distributed import DistGraph
from ..multiprocessing import call_once_and_share from ..multiprocessing import call_once_and_share
from ..cuda import stream as dgl_stream
PYTORCH_VER = LooseVersion(torch.__version__) PYTORCH_VER = LooseVersion(torch.__version__)
PYTHON_EXIT_STATUS = False PYTHON_EXIT_STATUS = False
...@@ -158,7 +157,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset): ...@@ -158,7 +157,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
def __iter__(self): def __iter__(self):
indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last) indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last)
id_tensor = self._id_tensor[indices.to(self._device)] id_tensor = self._id_tensor[indices]
return _TensorizedDatasetIter( return _TensorizedDatasetIter(
id_tensor, self.batch_size, self.drop_last, self._mapping_keys, self._shuffle) id_tensor, self.batch_size, self.drop_last, self._mapping_keys, self._shuffle)
...@@ -224,12 +223,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset): ...@@ -224,12 +223,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
"""Shuffles the dataset.""" """Shuffles the dataset."""
# Only rank 0 does the actual shuffling. The other ranks wait for it. # Only rank 0 does the actual shuffling. The other ranks wait for it.
if self.rank == 0: if self.rank == 0:
if self._device == torch.device('cpu'): np.random.shuffle(self._indices[:self.num_indices].numpy())
np.random.shuffle(self._indices[:self.num_indices].numpy())
else:
self._indices[:self.num_indices] = self._indices[
torch.randperm(self.num_indices, device=self._device)]
if not self.drop_last: if not self.drop_last:
# pad extra # pad extra
self._indices[self.num_indices:] = \ self._indices[self.num_indices:] = \
...@@ -240,7 +234,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset): ...@@ -240,7 +234,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
start = self.num_samples * self.rank start = self.num_samples * self.rank
end = self.num_samples * (self.rank + 1) end = self.num_samples * (self.rank + 1)
indices = _divide_by_worker(self._indices[start:end], self.batch_size, self.drop_last) indices = _divide_by_worker(self._indices[start:end], self.batch_size, self.drop_last)
id_tensor = self._id_tensor[indices.to(self._device)] id_tensor = self._id_tensor[indices]
return _TensorizedDatasetIter( return _TensorizedDatasetIter(
id_tensor, self.batch_size, self.drop_last, self._mapping_keys, self._shuffle) id_tensor, self.batch_size, self.drop_last, self._mapping_keys, self._shuffle)
...@@ -305,6 +299,18 @@ def _await_or_return(x): ...@@ -305,6 +299,18 @@ def _await_or_return(x):
else: else:
return x return x
def _record_stream(x, stream):
if stream is None:
return x
if isinstance(x, torch.Tensor):
x.record_stream(stream)
return x
elif isinstance(x, _PrefetchedGraphFeatures):
node_feats = recursive_apply(x.node_feats, _record_stream, stream)
edge_feats = recursive_apply(x.edge_feats, _record_stream, stream)
return _PrefetchedGraphFeatures(node_feats, edge_feats)
else:
return x
def _prefetch(batch, dataloader, stream): def _prefetch(batch, dataloader, stream):
# feats has the same nested structure of batch, except that # feats has the same nested structure of batch, except that
...@@ -316,12 +322,21 @@ def _prefetch(batch, dataloader, stream): ...@@ -316,12 +322,21 @@ def _prefetch(batch, dataloader, stream):
# #
# Once the futures are fetched, this function waits for them to complete by # Once the futures are fetched, this function waits for them to complete by
# calling its wait() method. # calling its wait() method.
with torch.cuda.stream(stream), dgl_stream(stream): if stream is not None:
current_stream = torch.cuda.current_stream()
current_stream.wait_stream(stream)
else:
current_stream = None
with torch.cuda.stream(stream):
# fetch node/edge features # fetch node/edge features
feats = recursive_apply(batch, _prefetch_for, dataloader) feats = recursive_apply(batch, _prefetch_for, dataloader)
feats = recursive_apply(feats, _await_or_return) feats = recursive_apply(feats, _await_or_return)
# transfer input nodes/seed nodes/sampled subgraph feats = recursive_apply(feats, _record_stream, current_stream)
# transfer input nodes/seed nodes
# TODO(Xin): sampled subgraph is transferred in the default stream
# because heterograph doesn't support .record_stream() for now
batch = recursive_apply(batch, lambda x: x.to(dataloader.device, non_blocking=True)) batch = recursive_apply(batch, lambda x: x.to(dataloader.device, non_blocking=True))
batch = recursive_apply(batch, _record_stream, current_stream)
stream_event = stream.record_event() if stream is not None else None stream_event = stream.record_event() if stream is not None else None
return batch, feats, stream_event return batch, feats, stream_event
...@@ -941,7 +956,7 @@ class EdgeDataLoader(DataLoader): ...@@ -941,7 +956,7 @@ class EdgeDataLoader(DataLoader):
if use_uva: if use_uva:
device = torch.cuda.current_device() device = torch.cuda.current_device()
else: else:
device = self.graph.device device = graph.device
device = _get_device(device) device = _get_device(device)
if isinstance(graph_sampler, BlockSampler): if isinstance(graph_sampler, BlockSampler):
......
...@@ -173,28 +173,19 @@ class CustomPool: ...@@ -173,28 +173,19 @@ class CustomPool:
self.process_list[i].join() self.process_list[i].join()
def initialize(ip_config, num_servers=1, num_workers=0, def initialize(ip_config, max_queue_size=MAX_QUEUE_SIZE,
max_queue_size=MAX_QUEUE_SIZE, net_type='socket', net_type='socket', num_worker_threads=1):
num_worker_threads=1):
"""Initialize DGL's distributed module """Initialize DGL's distributed module
This function initializes DGL's distributed module. It acts differently in server This function initializes DGL's distributed module. It acts differently in server
or client modes. In the server mode, it runs the server code and never returns. or client modes. In the server mode, it runs the server code and never returns.
In the client mode, it builds connections with servers for communication and In the client mode, it builds connections with servers for communication and
creates worker processes for distributed sampling. `num_workers` specifies creates worker processes for distributed sampling.
the number of sampling worker processes per trainer process.
Users also have to provide the number of server processes on each machine in order
to connect to all the server processes in the cluster of machines correctly.
Parameters Parameters
---------- ----------
ip_config: str ip_config: str
File path of ip_config file File path of ip_config file
num_servers : int
The number of server processes on each machine. This argument is deprecated in DGL 0.7.0.
num_workers: int
Number of worker process on each machine. The worker processes are used
for distributed sampling. This argument is deprecated in DGL 0.7.0.
max_queue_size : int max_queue_size : int
Maximal size (bytes) of client queue buffer (~20 GB on default). Maximal size (bytes) of client queue buffer (~20 GB on default).
...@@ -205,7 +196,7 @@ def initialize(ip_config, num_servers=1, num_workers=0, ...@@ -205,7 +196,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,
Default: ``'socket'`` Default: ``'socket'``
num_worker_threads: int num_worker_threads: int
The number of threads in a worker process. The number of OMP threads in each sampler process.
Note Note
---- ----
...@@ -240,14 +231,8 @@ def initialize(ip_config, num_servers=1, num_workers=0, ...@@ -240,14 +231,8 @@ def initialize(ip_config, num_servers=1, num_workers=0,
serv.start() serv.start()
sys.exit() sys.exit()
else: else:
if os.environ.get('DGL_NUM_SAMPLER') is not None: num_workers = int(os.environ.get('DGL_NUM_SAMPLER', 0))
num_workers = int(os.environ.get('DGL_NUM_SAMPLER')) num_servers = int(os.environ.get('DGL_NUM_SERVER', 1))
else:
num_workers = 0
if os.environ.get('DGL_NUM_SERVER') is not None:
num_servers = int(os.environ.get('DGL_NUM_SERVER'))
else:
num_servers = 1
group_id = int(os.environ.get('DGL_GROUP_ID', 0)) group_id = int(os.environ.get('DGL_GROUP_ID', 0))
rpc.reset() rpc.reset()
global SAMPLER_POOL global SAMPLER_POOL
......
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
from ..convert import heterograph as dgl_heterograph from ..convert import heterograph as dgl_heterograph
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from ..transforms import compact_graphs from ..transforms import compact_graphs, sort_csr_by_tag, sort_csc_by_tag
from .. import heterograph_index from .. import heterograph_index
from .. import backend as F from .. import backend as F
from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all
...@@ -336,9 +336,28 @@ class DistGraphServer(KVServer): ...@@ -336,9 +336,28 @@ class DistGraphServer(KVServer):
self.client_g, _, _, self.gpb, graph_name, \ self.client_g, _, _, self.gpb, graph_name, \
ntypes, etypes = load_partition(part_config, self.part_id, load_feats=False) ntypes, etypes = load_partition(part_config, self.part_id, load_feats=False)
print('load ' + graph_name) print('load ' + graph_name)
# formatting dtype
# TODO(Rui) Formatting forcely is not a perfect solution.
# We'd better store all dtypes when mapping to shared memory
# and map back with original dtypes.
for k, dtype in FIELD_DICT.items():
if k in self.client_g.ndata:
self.client_g.ndata[k] = F.astype(
self.client_g.ndata[k], dtype)
if k in self.client_g.edata:
self.client_g.edata[k] = F.astype(
self.client_g.edata[k], dtype)
# Create the graph formats specified the users. # Create the graph formats specified the users.
self.client_g = self.client_g.formats(graph_format) self.client_g = self.client_g.formats(graph_format)
self.client_g.create_formats_() self.client_g.create_formats_()
# Sort underlying matrix beforehand to avoid runtime overhead during sampling.
if len(etypes) > 1:
if 'csr' in graph_format:
self.client_g = sort_csr_by_tag(
self.client_g, tag=self.client_g.edata[ETYPE], tag_type='edge')
if 'csc' in graph_format:
self.client_g = sort_csc_by_tag(
self.client_g, tag=self.client_g.edata[ETYPE], tag_type='edge')
if not disable_shared_mem: if not disable_shared_mem:
self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format) self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format)
...@@ -1113,7 +1132,8 @@ class DistGraph: ...@@ -1113,7 +1132,8 @@ class DistGraph:
gpb = self.get_partition_book() gpb = self.get_partition_book()
if len(gpb.etypes) > 1: if len(gpb.etypes) > 1:
# if etype is a canonical edge type (str, str, str), extract the edge type # if etype is a canonical edge type (str, str, str), extract the edge type
if len(etype) == 3: if isinstance(etype, tuple):
assert len(etype) == 3, 'Invalid canonical etype: {}'.format(etype)
etype = etype[1] etype = etype[1]
edges = gpb.map_to_homo_eid(edges, etype) edges = gpb.map_to_homo_eid(edges, etype)
src, dst = dist_find_edges(self, edges) src, dst = dist_find_edges(self, edges)
...@@ -1160,7 +1180,7 @@ class DistGraph: ...@@ -1160,7 +1180,7 @@ class DistGraph:
if isinstance(edges, dict): if isinstance(edges, dict):
# TODO(zhengda) we need to directly generate subgraph of all relations with # TODO(zhengda) we need to directly generate subgraph of all relations with
# one invocation. # one invocation.
if isinstance(edges, tuple): if isinstance(list(edges.keys())[0], tuple):
subg = {etype: self.find_edges(edges[etype], etype[1]) for etype in edges} subg = {etype: self.find_edges(edges[etype], etype[1]) for etype in edges}
else: else:
subg = {} subg = {}
...@@ -1244,14 +1264,14 @@ class DistGraph: ...@@ -1244,14 +1264,14 @@ class DistGraph:
self._client.barrier() self._client.barrier()
def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None, def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None,
exclude_edges=None, replace=False, exclude_edges=None, replace=False, etype_sorted=True,
output_device=None): output_device=None):
# pylint: disable=unused-argument # pylint: disable=unused-argument
"""Sample neighbors from a distributed graph.""" """Sample neighbors from a distributed graph."""
# Currently prob, exclude_edges, output_device, and edge_dir are ignored. # Currently prob, exclude_edges, output_device, and edge_dir are ignored.
if len(self.etypes) > 1: if len(self.etypes) > 1:
frontier = graph_services.sample_etype_neighbors( frontier = graph_services.sample_etype_neighbors(
self, seed_nodes, ETYPE, fanout, replace=replace) self, seed_nodes, ETYPE, fanout, replace=replace, etype_sorted=etype_sorted)
else: else:
frontier = graph_services.sample_neighbors( frontier = graph_services.sample_neighbors(
self, seed_nodes, fanout, replace=replace) self, seed_nodes, fanout, replace=replace)
......
...@@ -164,21 +164,23 @@ class SamplingRequest(Request): ...@@ -164,21 +164,23 @@ class SamplingRequest(Request):
class SamplingRequestEtype(Request): class SamplingRequestEtype(Request):
"""Sampling Request""" """Sampling Request"""
def __init__(self, nodes, etype_field, fan_out, edge_dir='in', prob=None, replace=False): def __init__(self, nodes, etype_field, fan_out, edge_dir='in',
prob=None, replace=False, etype_sorted=True):
self.seed_nodes = nodes self.seed_nodes = nodes
self.edge_dir = edge_dir self.edge_dir = edge_dir
self.prob = prob self.prob = prob
self.replace = replace self.replace = replace
self.fan_out = fan_out self.fan_out = fan_out
self.etype_field = etype_field self.etype_field = etype_field
self.etype_sorted = etype_sorted
def __setstate__(self, state): def __setstate__(self, state):
self.seed_nodes, self.edge_dir, self.prob, self.replace, \ self.seed_nodes, self.edge_dir, self.prob, self.replace, \
self.fan_out, self.etype_field = state self.fan_out, self.etype_field, self.etype_sorted = state
def __getstate__(self): def __getstate__(self):
return self.seed_nodes, self.edge_dir, self.prob, self.replace, \ return self.seed_nodes, self.edge_dir, self.prob, self.replace, \
self.fan_out, self.etype_field self.fan_out, self.etype_field, self.etype_sorted
def process_request(self, server_state): def process_request(self, server_state):
local_g = server_state.graph local_g = server_state.graph
...@@ -190,7 +192,8 @@ class SamplingRequestEtype(Request): ...@@ -190,7 +192,8 @@ class SamplingRequestEtype(Request):
self.fan_out, self.fan_out,
self.edge_dir, self.edge_dir,
self.prob, self.prob,
self.replace) self.replace,
self.etype_sorted)
return SubgraphResponse(global_src, global_dst, global_eids) return SubgraphResponse(global_src, global_dst, global_eids)
class EdgesRequest(Request): class EdgesRequest(Request):
...@@ -418,7 +421,8 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb): ...@@ -418,7 +421,8 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
hg.edges[etype].data[EID] = edge_ids[etype] hg.edges[etype].data[EID] = edge_ids[etype]
return hg return hg
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None, replace=False): def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in',
prob=None, replace=False, etype_sorted=True):
"""Sample from the neighbors of the given nodes from a distributed graph. """Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...@@ -471,6 +475,8 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -471,6 +475,8 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
For sampling without replacement, if fanout > the number of neighbors, all the For sampling without replacement, if fanout > the number of neighbors, all the
neighbors are sampled. If fanout == -1, all neighbors are collected. neighbors are sampled. If fanout == -1, all neighbors are collected.
etype_sorted : bool, optional
Indicates whether etypes are sorted.
Returns Returns
------- -------
...@@ -496,10 +502,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -496,10 +502,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
nodes = F.cat(homo_nids, 0) nodes = F.cat(homo_nids, 0)
def issue_remote_req(node_ids): def issue_remote_req(node_ids):
return SamplingRequestEtype(node_ids, etype_field, fanout, edge_dir=edge_dir, return SamplingRequestEtype(node_ids, etype_field, fanout, edge_dir=edge_dir,
prob=prob, replace=replace) prob=prob, replace=replace, etype_sorted=etype_sorted)
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
return _sample_etype_neighbors(local_g, partition_book, local_nids, return _sample_etype_neighbors(local_g, partition_book, local_nids,
etype_field, fanout, edge_dir, prob, replace) etype_field, fanout, edge_dir, prob, replace,
etype_sorted=etype_sorted)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if not gpb.is_homogeneous: if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb) return _frontier_to_heterogeneous_graph(g, frontier, gpb)
......
...@@ -16,11 +16,11 @@ class DistEmbedding: ...@@ -16,11 +16,11 @@ class DistEmbedding:
To support efficient training on a graph with many nodes, the embeddings support sparse To support efficient training on a graph with many nodes, the embeddings support sparse
updates. That is, only the embeddings involved in a mini-batch computation are updated. updates. That is, only the embeddings involved in a mini-batch computation are updated.
Currently, DGL provides only one optimizer: `SparseAdagrad`. DGL will provide more Please refer to `Distributed Optimizers <https://docs.dgl.ai/api/python/dgl.distributed.html#
optimizers in the future. distributed-embedding-optimizer>`__ for available optimizers in DGL.
Distributed embeddings are sharded and stored in a cluster of machines in the same way as Distributed embeddings are sharded and stored in a cluster of machines in the same way as
py:meth:`dgl.distributed.DistTensor`, except that distributed embeddings are trainable. :class:`dgl.distributed.DistTensor`, except that distributed embeddings are trainable.
Because distributed embeddings are sharded Because distributed embeddings are sharded
in the same way as nodes and edges of a distributed graph, it is usually much more in the same way as nodes and edges of a distributed graph, it is usually much more
efficient to access than the sparse embeddings provided by the deep learning frameworks. efficient to access than the sparse embeddings provided by the deep learning frameworks.
......
...@@ -255,7 +255,7 @@ class SparseAdagrad(DistSparseGradOptimizer): ...@@ -255,7 +255,7 @@ class SparseAdagrad(DistSparseGradOptimizer):
update_event.record() update_event.record()
# update emb # update emb
std_values = grad_state.add_(eps).sqrt_() std_values = grad_state.sqrt_().add_(eps)
tmp = clr * grad_values / std_values tmp = clr * grad_values / std_values
tmp_dst = tmp.to(state_dev, non_blocking=True) tmp_dst = tmp.to(state_dev, non_blocking=True)
......
...@@ -45,7 +45,7 @@ class EGATConv(nn.Module): ...@@ -45,7 +45,7 @@ class EGATConv(nn.Module):
num_heads : int num_heads : int
Number of attention heads. Number of attention heads.
bias : bool, optional bias : bool, optional
If True, add bias term to :math: `f_{ij}^{\prime}`. Defaults: ``True``. If True, add bias term to :math:`f_{ij}^{\prime}`. Defaults: ``True``.
Examples Examples
---------- ----------
...@@ -170,16 +170,16 @@ class EGATConv(nn.Module): ...@@ -170,16 +170,16 @@ class EGATConv(nn.Module):
Returns Returns
------- -------
pair of torch.Tensor pair of torch.Tensor
node output features followed by edge output features node output features followed by edge output features.
The node output feature of shape :math:`(N, H, D_{out})` The node output feature is of shape :math:`(N, H, D_{out})`
The edge output feature of shape :math:`(F, H, F_{out})` The edge output feature is of shape :math:`(F, H, F_{out})`
where: where:
:math:`H` is the number of heads, :math:`H` is the number of heads,
:math:`D_{out}` is size of output node feature, :math:`D_{out}` is size of output node feature,
:math:`F_{out}` is size of output edge feature. :math:`F_{out}` is size of output edge feature.
torch.Tensor, optional torch.Tensor, optional
The attention values of shape :math:`(E, H, 1)`. The attention values of shape :math:`(E, H, 1)`.
This is returned only when :attr: `get_attention` is ``True``. This is returned only when :attr:`get_attention` is ``True``.
""" """
with graph.local_scope(): with graph.local_scope():
......
...@@ -151,7 +151,7 @@ class GCN2Conv(nn.Module): ...@@ -151,7 +151,7 @@ class GCN2Conv(nn.Module):
nn.init.normal_(self.weight1) nn.init.normal_(self.weight1)
if not self._project_initial_features: if not self._project_initial_features:
nn.init.normal_(self.weight2) nn.init.normal_(self.weight2)
if self._bias is not None: if self._bias:
nn.init.zeros_(self.bias) nn.init.zeros_(self.bias)
def set_allow_zero_in_degree(self, set_value): def set_allow_zero_in_degree(self, set_value):
...@@ -265,8 +265,8 @@ class GCN2Conv(nn.Module): ...@@ -265,8 +265,8 @@ class GCN2Conv(nn.Module):
feat_0, feat_0, self.weight2, beta=(1 - self.beta), alpha=self.beta feat_0, feat_0, self.weight2, beta=(1 - self.beta), alpha=self.beta
) )
if self._bias is not None: if self._bias:
rst = rst + self._bias rst = rst + self.bias
if self._activation is not None: if self._activation is not None:
rst = self._activation(rst) rst = self._activation(rst)
......
...@@ -49,16 +49,20 @@ class RelGraphConv(nn.Module): ...@@ -49,16 +49,20 @@ class RelGraphConv(nn.Module):
out_feat : int out_feat : int
Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`. Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
num_rels : int num_rels : int
Number of relations. . Number of relations.
regularizer : str, optional regularizer : str, optional
Which weight regularizer to use "basis" or "bdd": Which weight regularizer to use ("basis", "bdd" or ``None``):
- "basis" is short for basis-decomposition. - "basis" is for basis-decomposition.
- "bdd" is short for block-diagonal-decomposition. - "bdd" is for block-diagonal-decomposition.
- ``None`` applies no regularization.
Default applies no regularization. Default: ``None``.
num_bases : int, optional num_bases : int, optional
Number of bases. Needed when ``regularizer`` is specified. Default: ``None``. Number of bases. It comes into effect when a regularizer is applied.
If ``None``, it uses number of relations (``num_rels``). Default: ``None``.
Note that ``in_feat`` and ``out_feat`` must be divisible by ``num_bases``
when applying "bdd" regularizer.
bias : bool, optional bias : bool, optional
True if bias is added. Default: ``True``. True if bias is added. Default: ``True``.
activation : callable, optional activation : callable, optional
...@@ -67,8 +71,8 @@ class RelGraphConv(nn.Module): ...@@ -67,8 +71,8 @@ class RelGraphConv(nn.Module):
True to include self loop message. Default: ``True``. True to include self loop message. Default: ``True``.
dropout : float, optional dropout : float, optional
Dropout rate. Default: ``0.0`` Dropout rate. Default: ``0.0``
layer_norm: float, optional layer_norm: bool, optional
Add layer norm. Default: ``False`` True to add layer norm. Default: ``False``
Examples Examples
-------- --------
...@@ -102,6 +106,8 @@ class RelGraphConv(nn.Module): ...@@ -102,6 +106,8 @@ class RelGraphConv(nn.Module):
dropout=0.0, dropout=0.0,
layer_norm=False): layer_norm=False):
super().__init__() super().__init__()
if regularizer is not None and num_bases is None:
num_bases = num_rels
self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases) self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases)
self.bias = bias self.bias = bias
self.activation = activation self.activation = activation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment