Unverified Commit 3f138eba authored by Quan (Andy) Gan, committed by GitHub

[Bugfix] Bug fixes in new dataloader (#3727)



* fixes

* fix

* more fixes

* update

* oops

* lint?

* temporarily revert - will fix in another PR

* more fixes

* skipping mxnet test

* address comments

* fix DDP

* fix edge dataloader exclusion problems

* stupid bug

* fix

* use_uvm option

* fix

* fixes

* fixes

* fixes

* fixes

* add evaluation for cluster gcn and ddp

* stupid bug again

* fixes

* move sanity checks to only support DGLGraphs

* pytorch lightning compatibility fixes

* remove

* poke

* more fixes

* fix

* fix

* disable test

* docstrings

* why is it getting a memory leak?

* fix

* update

* updates and temporarily disable forkingpickler

* update

* fix?

* fix?

* oops

* oops

* fix

* lint

* huh

* uh

* update

* fix

* made it memory efficient

* refine exclude interface

* fix tutorial

* fix tutorial

* fix graph duplication in CPU dataloader workers

* lint

* lint

* Revert "lint"

This reverts commit 805484dd553695111b5fb37f2125214a6b7276e9.

* Revert "lint"

This reverts commit 0bce411b2b415c2ab770343949404498436dc8b2.

* Revert "fix graph duplication in CPU dataloader workers"

This reverts commit 9e3a8cf34c175d3093c773f6bb023b155f2bd27f.
Co-authored-by: xiny <xiny@nvidia.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 7b9afbfa
@@ -56,6 +56,12 @@ class LazyFeature(object):
        """No-op. For compatibility of :meth:`Frame.__repr__` method."""
        return self

+    def pin_memory_(self):
+        """No-op. For compatibility of :meth:`Frame.pin_memory_` method."""
+
+    def unpin_memory_(self):
+        """No-op. For compatibility of :meth:`Frame.unpin_memory_` method."""
+
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
    """The column scheme.
@@ -142,6 +148,7 @@ class Column(TensorStorage):
        self.scheme = scheme if scheme else infer_scheme(storage)
        self.index = index
        self.device = device
+        self.pinned = False

    def __len__(self):
        """The number of features (number of rows) in this column."""
@@ -183,6 +190,7 @@ class Column(TensorStorage):
        """Update the column data."""
        self.index = None
        self.storage = val
+        self.pinned = False

    def to(self, device, **kwargs):  # pylint: disable=invalid-name
        """ Return a new column with columns copy to the targeted device (cpu/gpu).
@@ -330,6 +338,10 @@ class Column(TensorStorage):
    def __copy__(self):
        return self.clone()

+    def fetch(self, indices, device, pin_memory=False):
+        _ = self.data  # materialize in case of lazy slicing & data transfer
+        return super().fetch(indices, device, pin_memory=False)
+
class Frame(MutableMapping):
    """The columnar storage for node/edge features.
@@ -702,3 +714,15 @@ class Frame(MutableMapping):
    def __repr__(self):
        return repr(dict(self))

+    def pin_memory_(self):
+        """Registers the data of every column into pinned memory, materializing them if
+        necessary."""
+        for column in self._columns.values():
+            column.pin_memory_()
+
+    def unpin_memory_(self):
+        """Unregisters the data of every column from pinned memory, materializing them
+        if necessary."""
+        for column in self._columns.values():
+            column.unpin_memory_()
@@ -5474,7 +5474,7 @@ class DGLHeteroGraph(object):
        Materialization of new sparse formats for pinned graphs is not allowed.
        To avoid implicit formats materialization during training,
-        you should create all the needed formats before pinnning.
+        you should create all the needed formats before pinning.
        But cloning and materialization is fine. See the examples below.

        Returns
@@ -5530,6 +5530,7 @@ class DGLHeteroGraph(object):
        if F.device_type(self.device) != 'cpu':
            raise DGLError("The graph structure must be on CPU to be pinned.")
        self._graph.pin_memory_()
        return self

    def unpin_memory_(self):
@@ -5546,6 +5547,7 @@ class DGLHeteroGraph(object):
        if not self._graph.is_pinned():
            return self
        self._graph.unpin_memory_()
        return self

    def is_pinned(self):
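A minimal usage sketch of the pinning API above (assuming a CUDA build of DGL; rand_graph is just an example constructor):

import dgl

g = dgl.rand_graph(1000, 5000)     # structure lives on CPU
g = g.pin_memory_()                # page-lock it in place; returns the graph itself
assert g.is_pinned()
g = g.unpin_memory_()              # also returns the graph, so calls can be chained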
......
"""Module for heterogeneous graph index class definition.""" """Module for heterogeneous graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import itertools import itertools
import numpy as np import numpy as np
import scipy import scipy
...@@ -1365,4 +1366,27 @@ class HeteroPickleStates(ObjectBase): ...@@ -1365,4 +1366,27 @@ class HeteroPickleStates(ObjectBase):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs) _CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs)
def _forking_rebuild(pk_state):
meta, arrays = pk_state
arrays = [F.to_dgl_nd(arr) for arr in arrays]
states = _CAPI_DGLCreateHeteroPickleStates(meta, arrays)
return _CAPI_DGLHeteroForkingUnpickle(states)
def _forking_reduce(graph_index):
states = _CAPI_DGLHeteroForkingPickle(graph_index)
arrays = [F.from_dgl_nd(arr) for arr in states.arrays]
# Similar to what being mentioned in HeteroGraphIndex.__getstate__, we need to save
# the tensors as an attribute of the original graph index object. Otherwise
# PyTorch will throw weird errors like bad value(s) in fds_to_keep or unable to
# resize file.
graph_index._forking_pk_state = (states.meta, arrays)
return _forking_rebuild, (graph_index._forking_pk_state,)
if not (F.get_preferred_backend() == 'mxnet' and sys.version_info.minor <= 6):
# Python 3.6 MXNet crashes with the following statement; remove until we no longer support
# 3.6 (which is EOL anyway).
from multiprocessing.reduction import ForkingPickler
ForkingPickler.register(HeteroGraphIndex, _forking_reduce)
_init_api("dgl.heterograph_index") _init_api("dgl.heterograph_index")
@@ -222,6 +222,7 @@ class SparseMatrix(ObjectBase):
_set_class_ndarray(NDArray)
_init_api("dgl.ndarray")
+_init_api("dgl.ndarray.uvm", __name__)

# An array representing null (no value) that can be safely converted to
# other backend tensors.
......
@@ -3,7 +3,8 @@ from .. import backend as F
from .base import *
from .numpy import *

+# Defines the name TensorStorage
if F.get_preferred_backend() == 'pytorch':
-    from .pytorch_tensor import *
+    from .pytorch_tensor import PyTorchTensorStorage as TensorStorage
else:
-    from .tensor import *
+    from .tensor import BaseTensorStorage as TensorStorage
"""Feature storages for PyTorch tensors.""" """Feature storages for PyTorch tensors."""
import torch import torch
from .base import FeatureStorage, register_storage_wrapper from .base import register_storage_wrapper
from .tensor import BaseTensorStorage
from ..utils import gather_pinned_tensor_rows
def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory): def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory):
result = torch.empty( result = torch.empty(
...@@ -15,18 +17,26 @@ def _fetch_cuda(indices, tensor, device): ...@@ -15,18 +17,26 @@ def _fetch_cuda(indices, tensor, device):
return torch.index_select(tensor, 0, indices).to(device) return torch.index_select(tensor, 0, indices).to(device)
@register_storage_wrapper(torch.Tensor) @register_storage_wrapper(torch.Tensor)
class TensorStorage(FeatureStorage): class PyTorchTensorStorage(BaseTensorStorage):
"""Feature storages for slicing a PyTorch tensor.""" """Feature storages for slicing a PyTorch tensor."""
def __init__(self, tensor):
self.storage = tensor
self.feature_shape = tensor.shape[1:]
self.is_cuda = (tensor.device.type == 'cuda')
def fetch(self, indices, device, pin_memory=False): def fetch(self, indices, device, pin_memory=False):
device = torch.device(device) device = torch.device(device)
if not self.is_cuda: storage_device_type = self.storage.device.type
indices_device_type = indices.device.type
if storage_device_type != 'cuda':
if indices_device_type == 'cuda':
if self.storage.is_pinned():
return gather_pinned_tensor_rows(self.storage, indices)
else:
raise ValueError(
f'Got indices on device {indices.device} whereas the feature tensor '
f'is on {self.storage.device}. Please either (1) move the graph '
f'to GPU with to() method, or (2) pin the graph with '
f'pin_memory_() method.')
# CPU to CPU or CUDA - use pin_memory and async transfer if possible # CPU to CPU or CUDA - use pin_memory and async transfer if possible
return _fetch_cpu(indices, self.storage, self.feature_shape, device, pin_memory) else:
return _fetch_cpu(indices, self.storage, self.storage.shape[1:], device,
pin_memory)
else: else:
# CUDA to CUDA or CPU # CUDA to CUDA or CPU
return _fetch_cuda(indices, self.storage, device) return _fetch_cuda(indices, self.storage, device)
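A small sketch of the fetch() dispatch above, assuming the module is importable as dgl.storages.pytorch_tensor (the collapsed diff does not show file names):

import torch
from dgl.storages.pytorch_tensor import PyTorchTensorStorage  # assumed module path

feat = torch.arange(20, dtype=torch.float32).view(10, 2)
store = PyTorchTensorStorage(feat)
rows = store.fetch(torch.tensor([0, 3, 7]), 'cpu')   # CPU storage + CPU indices -> _fetch_cpu
# CPU storage + CUDA indices takes the gather_pinned_tensor_rows (UVM) path when the
# tensor is pinned, and raises the ValueError above otherwise; CUDA storage always
# goes through _fetch_cuda.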
"""Feature storages for tensors across different frameworks.""" """Feature storages for tensors across different frameworks."""
from .base import FeatureStorage from .base import FeatureStorage
from .. import backend as F from .. import backend as F
from ..utils import recursive_apply_pair
def _fetch(indices, tensor, device): class BaseTensorStorage(FeatureStorage):
return F.copy_to(F.gather_row(tensor, indices), device)
class TensorStorage(FeatureStorage):
"""FeatureStorage that synchronously slices features from a tensor and transfers """FeatureStorage that synchronously slices features from a tensor and transfers
it to the given device. it to the given device.
""" """
...@@ -14,4 +10,4 @@ class TensorStorage(FeatureStorage): ...@@ -14,4 +10,4 @@ class TensorStorage(FeatureStorage):
self.storage = tensor self.storage = tensor
def fetch(self, indices, device, pin_memory=False): # pylint: disable=unused-argument def fetch(self, indices, device, pin_memory=False): # pylint: disable=unused-argument
return recursive_apply_pair(indices, self.storage, _fetch, device) return F.copy_to(F.gather_row(tensor, indices), device)
@@ -5,3 +5,4 @@ from .checks import *
from .shared_mem import *
from .filter import *
from .exception import *
+from .pin_memory import *
@@ -937,4 +937,8 @@ def recursive_apply_pair(data1, data2, fn, *args, **kwargs):
    else:
        return fn(data1, data2, *args, **kwargs)

+def context_of(data):
+    """Return the device of the data which can be either a tensor or a dict of tensors."""
+    return F.context(next(iter(data.values())) if isinstance(data, Mapping) else data)
+
_init_api("dgl.utils.internal")
"""Utility functions related to pinned memory tensors."""
from .. import backend as F
from .._ffi.function import _init_api
def pin_memory_inplace(tensor):
"""Register the tensor into pinned memory in-place (i.e. without copying)."""
F.to_dgl_nd(tensor).pin_memory_()
def unpin_memory_inplace(tensor):
"""Unregister the tensor from pinned memory in-place (i.e. without copying)."""
F.to_dgl_nd(tensor).unpin_memory_()
def gather_pinned_tensor_rows(tensor, rows):
"""Directly gather rows from a CPU tensor given an indices array on CUDA devices,
and returns the result on the same CUDA device without copying.
Parameters
----------
tensor : Tensor
The tensor. Must be in pinned memory.
rows : Tensor
The rows to gather. Must be a CUDA tensor.
Returns
-------
Tensor
The result with the same device as :attr:`rows`.
"""
return F.from_dgl_nd(_CAPI_DGLIndexSelectCPUFromGPU(F.to_dgl_nd(tensor), F.to_dgl_nd(rows)))
_init_api("dgl.ndarray.uvm", __name__)
@@ -27,7 +27,7 @@ NDArray IndexSelect(NDArray array, IdArray index) {
    shape.emplace_back(array->shape[d]);
  }
-  // use index->ctx for kDLCPUPinned array
+  // use index->ctx for pinned array
  NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);
  if (len == 0)
    return ret;
......
@@ -24,7 +24,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
  int64_t num_feat = 1;
  std::vector<int64_t> shape{len};

-  CHECK_EQ(array->ctx.device_type, kDLCPUPinned);
+  CHECK(array.IsPinned());
  CHECK_EQ(index->ctx.device_type, kDLGPU);

  for (int d = 1; d < array->ndim; ++d) {
@@ -72,6 +72,8 @@ template NDArray IndexSelectCPUFromGPU<int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<float, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<float, int64_t>(NDArray, IdArray);
+template NDArray IndexSelectCPUFromGPU<double, int32_t>(NDArray, IdArray);
+template NDArray IndexSelectCPUFromGPU<double, int64_t>(NDArray, IdArray);

}  // namespace impl
}  // namespace aten
......
@@ -15,7 +15,7 @@ namespace aten {
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
#ifdef DGL_USE_CUDA
-  CHECK_EQ(array->ctx.device_type, kDLCPUPinned)
+  CHECK(array.IsPinned())
    << "Only the CPUPinned device type input array is supported";
  CHECK_EQ(index->ctx.device_type, kDLGPU)
    << "Only the GPU device type input index is supported";
......
...@@ -83,12 +83,6 @@ BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs) { ...@@ -83,12 +83,6 @@ BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs) {
rst.out_len /= rst.reduce_size; // out_len is divied by reduce_size in dot. rst.out_len /= rst.reduce_size; // out_len is divied by reduce_size in dot.
} }
} }
#ifdef DEBUG
LOG(INFO) << "lhs_len: " << rst.lhs_len << " " <<
"rhs_len: " << rst.rhs_len << " " <<
"out_len: " << rst.out_len << " " <<
"reduce_size: " << rst.reduce_size << std::endl;
#endif
return rst; return rst;
} }
......
@@ -236,7 +236,7 @@ class HeteroGraph : public BaseHeteroGraph {
   * \brief Pin all relation graphs of the current graph.
   * \note The graph will be pinned inplace. Behavior depends on the current context,
   *       kDLCPU: will be pinned;
-  *       kDLCPUPinned: directly return;
+  *       IsPinned: directly return;
   *       kDLGPU: invalid, will throw an error.
   *       The context check is deferred to pinning the NDArray.
   */
@@ -245,7 +245,7 @@ class HeteroGraph : public BaseHeteroGraph {
  /*!
   * \brief Unpin all relation graphs of the current graph.
   * \note The graph will be unpinned inplace. Behavior depends on the current context,
-  *       kDLCPUPinned: will be unpinned;
+  *       IsPinned: will be unpinned;
   *       others: directly return.
   *       The context check is deferred to unpinning the NDArray.
   */
@@ -272,6 +272,18 @@ class HeteroGraph : public BaseHeteroGraph {
    return relation_graphs_;
  }

+  void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
+    GetRelationGraph(etype)->SetCOOMatrix(0, coo);
+  }
+
+  void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
+    GetRelationGraph(etype)->SetCSRMatrix(0, csr);
+  }
+
+  void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
+    GetRelationGraph(etype)->SetCSCMatrix(0, csc);
+  }
+
 private:
  // To create empty class
  friend class Serializer;
......
...@@ -173,13 +173,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType") ...@@ -173,13 +173,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
// The Python side only recognizes CPU and GPU device type. *rv = hg->Context();
// Use is_pinned() to checked whether the object is
// on page-locked memory
if (hg->Context().device_type == kDLCPUPinned)
*rv = DLContext{kDLCPU, 0};
else
*rv = hg->Context();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsPinned") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsPinned")
......
@@ -51,6 +51,42 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
  return states;
}

+HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) {
+  HeteroPickleStates states;
+  dmlc::MemoryStringStream ofs(&states.meta);
+  dmlc::Stream *strm = &ofs;
+  strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
+  strm->Write(graph->NumVerticesPerType());
+  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
+    auto created_formats = graph->GetCreatedFormats();
+    auto allowed_formats = graph->GetAllowedFormats();
+    strm->Write(created_formats);
+    strm->Write(allowed_formats);
+    if (created_formats & COO_CODE) {
+      const auto &coo = graph->GetCOOMatrix(etype);
+      strm->Write(coo.row_sorted);
+      strm->Write(coo.col_sorted);
+      states.arrays.push_back(coo.row);
+      states.arrays.push_back(coo.col);
+    }
+    if (created_formats & CSR_CODE) {
+      const auto &csr = graph->GetCSRMatrix(etype);
+      strm->Write(csr.sorted);
+      states.arrays.push_back(csr.indptr);
+      states.arrays.push_back(csr.indices);
+      states.arrays.push_back(csr.data);
+    }
+    if (created_formats & CSC_CODE) {
+      const auto &csc = graph->GetCSCMatrix(etype);
+      strm->Write(csc.sorted);
+      states.arrays.push_back(csc.indptr);
+      states.arrays.push_back(csc.indices);
+      states.arrays.push_back(csc.data);
+    }
+  }
+  return states;
+}
+
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
  char *buf = const_cast<char *>(states.meta.c_str());  // a readonly stream?
  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
@@ -137,6 +173,76 @@ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states) {
  return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
}

+HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
+  char *buf = const_cast<char *>(states.meta.c_str());  // a readonly stream?
+  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
+  dmlc::Stream *strm = &ifs;
+  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
+  CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
+  GraphPtr metagraph = meta_imgraph;
+  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
+  std::vector<int64_t> num_nodes_per_type;
+  CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";
+  auto array_itr = states.arrays.begin();
+  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
+    const auto& pair = metagraph->FindEdge(etype);
+    const dgl_type_t srctype = pair.first;
+    const dgl_type_t dsttype = pair.second;
+    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
+    int64_t num_src = num_nodes_per_type[srctype];
+    int64_t num_dst = num_nodes_per_type[dsttype];
+    dgl_format_code_t created_formats, allowed_formats;
+    CHECK(strm->Read(&created_formats)) << "Invalid code for created formats";
+    CHECK(strm->Read(&allowed_formats)) << "Invalid code for allowed formats";
+    HeteroGraphPtr relgraph = nullptr;
+    if (created_formats & COO_CODE) {
+      CHECK_GE(states.arrays.end() - array_itr, 2);
+      const auto &row = *(array_itr++);
+      const auto &col = *(array_itr++);
+      bool rsorted;
+      bool csorted;
+      CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
+      CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
+      auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
+      if (!relgraph)
+        relgraph = CreateFromCOO(num_vtypes, coo, allowed_formats);
+      else
+        relgraph->SetCOOMatrix(0, coo);
+    }
+    if (created_formats & CSR_CODE) {
+      CHECK_GE(states.arrays.end() - array_itr, 3);
+      const auto &indptr = *(array_itr++);
+      const auto &indices = *(array_itr++);
+      const auto &edge_id = *(array_itr++);
+      bool sorted;
+      CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
+      auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
+      if (!relgraph)
+        relgraph = CreateFromCSR(num_vtypes, csr, allowed_formats);
+      else
+        relgraph->SetCSRMatrix(0, csr);
+    }
+    if (created_formats & CSC_CODE) {
+      CHECK_GE(states.arrays.end() - array_itr, 3);
+      const auto &indptr = *(array_itr++);
+      const auto &indices = *(array_itr++);
+      const auto &edge_id = *(array_itr++);
+      bool sorted;
+      CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
+      auto csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted);
+      if (!relgraph)
+        relgraph = CreateFromCSC(num_vtypes, csc, allowed_formats);
+      else
+        relgraph->SetCSCMatrix(0, csc);
+    }
+    relgraphs[etype] = relgraph;
+  }
+  return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
+}
+
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef st = args[0];
@@ -186,6 +292,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
    *rv = HeteroPickleStatesRef(st);
  });

+DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingPickle")
+.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    HeteroGraphRef ref = args[0];
+    std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
+    *st = HeteroForkingPickle(ref.sptr());
+    *rv = HeteroPickleStatesRef(st);
+  });
+
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef ref = args[0];
@@ -203,6 +317,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
    *rv = HeteroGraphRef(graph);
  });

+DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingUnpickle")
+.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    HeteroPickleStatesRef ref = args[0];
+    HeteroGraphPtr graph = HeteroForkingUnpickle(*ref.sptr());
+    *rv = HeteroGraphRef(graph);
+  });
+
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef metagraph = args[0];
......
@@ -31,7 +31,7 @@ HeteroSubgraph ExcludeCertainEdges(
        sg.induced_edges[etype]->shape[0],
        sg.induced_edges[etype]->dtype.bits,
        sg.induced_edges[etype]->ctx);
-    if (exclude_edges[etype].GetSize() == 0) {
+    if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) {
      remain_edges[etype] = edge_ids;
      remain_induced_edges[etype] = sg.induced_edges[etype];
      continue;
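A hedged Python-side sketch of the case this guard fixes: excluding edges from a sampled subgraph that may contain zero edges of some type. EdgeDataLoader and exclude='self' are the 0.8-era API and an assumption about how this code path is reached:

import torch
import dgl

g = dgl.rand_graph(100, 400)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
loader = dgl.dataloading.EdgeDataLoader(
    g, torch.arange(g.num_edges()), sampler,
    exclude='self', batch_size=32)             # exclusion must also handle empty subgraphs
input_nodes, pair_graph, blocks = next(iter(loader))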
......
@@ -4,12 +4,12 @@
 * \brief frequency hashmap - used to select top-k frequency edges of each node
 */

-#include <cub/cub.cuh>
#include <algorithm>
#include <tuple>
#include <utility>
#include "../../../runtime/cuda/cuda_common.h"
#include "../../../array/cuda/atomic.cuh"
+#include "../../../array/cuda/dgl_cub.cuh"
#include "frequency_hashmap.cuh"

namespace dgl {
......