Unverified commit 3f138eba, authored by Quan (Andy) Gan and committed by GitHub

[Bugfix] Bug fixes in new dataloader (#3727)



* fixes

* fix

* more fixes

* update

* oops

* lint?

* temporarily revert - will fix in another PR

* more fixes

* skipping mxnet test

* address comments

* fix DDP

* fix edge dataloader exclusion problems

* stupid bug

* fix

* use_uvm option

* fix

* fixes

* fixes

* fixes

* fixes

* add evaluation for cluster gcn and ddp

* stupid bug again

* fixes

* move sanity checks to only support DGLGraphs

* pytorch lightning compatibility fixes

* remove

* poke

* more fixes

* fix

* fix

* disable test

* docstrings

* why is it getting a memory leak?

* fix

* update

* updates and temporarily disable forkingpickler

* update

* fix?

* fix?

* oops

* oops

* fix

* lint

* huh

* uh

* update

* fix

* made it memory efficient

* refine exclude interface

* fix tutorial

* fix tutorial

* fix graph duplication in CPU dataloader workers

* lint

* lint

* Revert "lint"

This reverts commit 805484dd553695111b5fb37f2125214a6b7276e9.

* Revert "lint"

This reverts commit 0bce411b2b415c2ab770343949404498436dc8b2.

* Revert "fix graph duplication in CPU dataloader workers"

This reverts commit 9e3a8cf34c175d3093c773f6bb023b155f2bd27f.
Co-authored-by: xiny <xiny@nvidia.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 7b9afbfa
"""DGL PyTorch DataLoaders""" """DGL PyTorch DataLoaders"""
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from queue import Queue from queue import Queue, Empty, Full
import itertools import itertools
import threading import threading
from distutils.version import LooseVersion from distutils.version import LooseVersion
...@@ -8,23 +8,33 @@ import random ...@@ -8,23 +8,33 @@ import random
import math import math
import inspect import inspect
import re import re
import atexit
import os
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from ..base import NID, EID, dgl_warning from ..base import NID, EID
from ..batch import batch as batch_graphs from ..batch import batch as batch_graphs
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
from .. import ndarray as nd from .. import ndarray as nd
from ..utils import ( from ..utils import (
recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads, recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads,
create_shared_mem_array, get_shared_mem_array) create_shared_mem_array, get_shared_mem_array, context_of, pin_memory_inplace)
from ..frame import LazyFeature from ..frame import LazyFeature
from ..storages import wrap_storage from ..storages import wrap_storage
from .base import BlockSampler, EdgeBlockSampler from .base import BlockSampler, EdgeBlockSampler
from .. import backend as F from .. import backend as F
PYTHON_EXIT_STATUS = False
def _set_python_exit_flag():
global PYTHON_EXIT_STATUS
PYTHON_EXIT_STATUS = True
atexit.register(_set_python_exit_flag)
prefetcher_timeout = int(os.environ.get('DGL_PREFETCHER_TIMEOUT', '10'))
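The prefetcher timeout above is read once at module import time. As a usage sketch (not part of the patch), a user with slow feature storage could raise it by setting the environment variable before importing DGL:

import os
os.environ['DGL_PREFETCHER_TIMEOUT'] = '60'   # default is 10 seconds

import dgl   # the dataloader module reads the variable when it is first imported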
 class _TensorizedDatasetIter(object):
     def __init__(self, dataset, batch_size, drop_last, mapping_keys):
         self.dataset = dataset
@@ -54,7 +64,8 @@ class _TensorizedDatasetIter(object):
     def __next__(self):
         batch = self._next_indices()
         if self.mapping_keys is None:
-            return batch
+            # clone() fixes #3755, probably.  Not sure why.  Need to take a look afterwards.
+            return batch.clone()

         # convert the type-ID pairs to dictionary
         type_ids = batch[:, 0]
@@ -67,28 +78,31 @@ class _TensorizedDatasetIter(object):
         type_id_offset = type_id_count.cumsum(0).tolist()
         type_id_offset.insert(0, 0)
         id_dict = {
-            self.mapping_keys[type_id_uniq[i]]: indices[type_id_offset[i]:type_id_offset[i+1]]
+            self.mapping_keys[type_id_uniq[i]]:
+                indices[type_id_offset[i]:type_id_offset[i+1]].clone()
             for i in range(len(type_id_uniq))}
         return id_dict

 def _get_id_tensor_from_mapping(indices, device, keys):
     lengths = torch.LongTensor([
-        (indices[k].shape[0] if k in indices else 0) for k in keys], device=device)
+        (indices[k].shape[0] if k in indices else 0) for k in keys]).to(device)
     type_ids = torch.arange(len(keys), device=device).repeat_interleave(lengths)
     all_indices = torch.cat([indices[k] for k in keys if k in indices])
     return torch.stack([type_ids, all_indices], 1)

-def _divide_by_worker(dataset):
+def _divide_by_worker(dataset, batch_size, drop_last):
     num_samples = dataset.shape[0]
     worker_info = torch.utils.data.get_worker_info()
     if worker_info:
-        chunk_size = num_samples // worker_info.num_workers
-        left_over = num_samples % worker_info.num_workers
-        start = (chunk_size * worker_info.id) + min(left_over, worker_info.id)
-        end = start + chunk_size + (worker_info.id < left_over)
-        assert worker_info.id < worker_info.num_workers - 1 or end == num_samples
+        num_batches = (num_samples + (0 if drop_last else batch_size - 1)) // batch_size
+        num_batches_per_worker = num_batches // worker_info.num_workers
+        left_over = num_batches % worker_info.num_workers
+        start = (num_batches_per_worker * worker_info.id) + min(left_over, worker_info.id)
+        end = start + num_batches_per_worker + (worker_info.id < left_over)
+        start *= batch_size
+        end = min(end * batch_size, num_samples)
         dataset = dataset[start:end]
     return dataset
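To make the new worker split concrete, here is a small self-contained sketch of the same arithmetic (the function name and the driver are illustrative, not part of the patch). The point of the change is that workers are assigned whole batches, so batch boundaries no longer shift with the number of workers:

def split_for_worker(num_samples, batch_size, drop_last, worker_id, num_workers):
    # mirrors _divide_by_worker above: partition by batch, then convert to sample offsets
    num_batches = (num_samples + (0 if drop_last else batch_size - 1)) // batch_size
    per_worker, left_over = divmod(num_batches, num_workers)
    start = per_worker * worker_id + min(left_over, worker_id)
    end = start + per_worker + (worker_id < left_over)
    return start * batch_size, min(end * batch_size, num_samples)

# 10 samples, batch size 4, 2 workers: worker 0 yields batches [0:4) and [4:8),
# worker 1 yields [8:10); every batch is produced by exactly one worker.
assert split_for_worker(10, 4, False, 0, 2) == (0, 8)
assert split_for_worker(10, 4, False, 1, 2) == (8, 10)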
@@ -98,31 +112,39 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
     When the dataset is on the GPU, this significantly reduces the overhead.
     """
     def __init__(self, indices, batch_size, drop_last):
+        name, _ = _generate_shared_mem_name_id()
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
             self._device = next(iter(indices.values())).device
-            self._tensor_dataset = _get_id_tensor_from_mapping(
+            self._id_tensor = _get_id_tensor_from_mapping(
                 indices, self._device, self._mapping_keys)
         else:
-            self._tensor_dataset = indices
+            self._id_tensor = indices
             self._device = indices.device
             self._mapping_keys = None
+        # Use a shared memory array to permute indices for shuffling.  This is to make sure that
+        # the worker processes can see it when persistent_workers=True, where self._indices
+        # would not be duplicated every epoch.
+        self._indices = create_shared_mem_array(name, (self._id_tensor.shape[0],), torch.int64)
+        self._indices[:] = torch.arange(self._id_tensor.shape[0])
         self.batch_size = batch_size
         self.drop_last = drop_last
+        self.shared_mem_name = name
+        self.shared_mem_size = self._indices.shape[0]

     def shuffle(self):
         """Shuffle the dataset."""
         # TODO: may need an in-place shuffle kernel
-        perm = torch.randperm(self._tensor_dataset.shape[0], device=self._device)
-        self._tensor_dataset[:] = self._tensor_dataset[perm]
+        self._indices[:] = self._indices[torch.randperm(self._indices.shape[0])]

     def __iter__(self):
-        dataset = _divide_by_worker(self._tensor_dataset)
+        indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last)
+        id_tensor = self._id_tensor[indices.to(self._device)]
         return _TensorizedDatasetIter(
-            dataset, self.batch_size, self.drop_last, self._mapping_keys)
+            id_tensor, self.batch_size, self.drop_last, self._mapping_keys)

     def __len__(self):
-        num_samples = self._tensor_dataset.shape[0]
+        num_samples = self._id_tensor.shape[0]
         return (num_samples + (0 if self.drop_last else (self.batch_size - 1))) // self.batch_size
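A minimal sketch of the indirection introduced above, using plain PyTorch: the ID tensor itself is never reordered; only a shared permutation is shuffled in place, which is what persistent workers observe across epochs. Tensor.share_memory_() stands in here for DGL's create_shared_mem_array:

import torch

ids = torch.tensor([10, 11, 12, 13, 14])              # fixed ID tensor (never reordered)
order = torch.arange(ids.shape[0]).share_memory_()    # analogue of the shared _indices
order[:] = order[torch.randperm(order.shape[0])]      # in-place shuffle, visible to workers
epoch_ids = ids[order]                                # what __iter__ hands to the iterator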
 def _get_shared_mem_name(id_):
@@ -168,20 +190,20 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
         self.shared_mem_size = self.total_size if not self.drop_last else len(indices)
         self.num_indices = len(indices)

+        if isinstance(indices, Mapping):
+            self._device = next(iter(indices.values())).device
+            self._id_tensor = _get_id_tensor_from_mapping(
+                indices, self._device, self._mapping_keys)
+        else:
+            self._id_tensor = indices
+            self._device = self._id_tensor.device
         if self.rank == 0:
             name, id_ = _generate_shared_mem_name_id()
-            if isinstance(indices, Mapping):
-                device = next(iter(indices.values())).device
-                id_tensor = _get_id_tensor_from_mapping(indices, device, self._mapping_keys)
-                self._tensor_dataset = create_shared_mem_array(
-                    name, (self.shared_mem_size, 2), torch.int64)
-                self._tensor_dataset[:id_tensor.shape[0], :] = id_tensor
-            else:
-                self._tensor_dataset = create_shared_mem_array(
-                    name, (self.shared_mem_size,), torch.int64)
-                self._tensor_dataset[:len(indices)] = indices
-            self._device = self._tensor_dataset.device
-            meta_info = torch.LongTensor([id_, self._tensor_dataset.shape[0]])
+            self._indices = create_shared_mem_array(
+                name, (self.shared_mem_size,), torch.int64)
+            self._indices[:self._id_tensor.shape[0]] = torch.arange(self._id_tensor.shape[0])
+            meta_info = torch.LongTensor([id_, self._indices.shape[0]])
         else:
             meta_info = torch.LongTensor([0, 0])
@@ -194,43 +216,41 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
         if self.rank != 0:
             id_, num_samples = meta_info.tolist()
             name = _get_shared_mem_name(id_)
-            if isinstance(indices, Mapping):
-                indices_shared = get_shared_mem_array(name, (num_samples, 2), torch.int64)
-            else:
-                indices_shared = get_shared_mem_array(name, (num_samples,), torch.int64)
-            self._tensor_dataset = indices_shared
-            self._device = indices_shared.device
+            indices_shared = get_shared_mem_array(name, (num_samples,), torch.int64)
+            self._indices = indices_shared
+            self.shared_mem_name = name

     def shuffle(self):
         """Shuffles the dataset."""
         # Only rank 0 does the actual shuffling.  The other ranks wait for it.
         if self.rank == 0:
-            self._tensor_dataset[:self.num_indices] = self._tensor_dataset[
+            self._indices[:self.num_indices] = self._indices[
                 torch.randperm(self.num_indices, device=self._device)]
             if not self.drop_last:
                 # pad extra
-                self._tensor_dataset[self.num_indices:] = \
-                    self._tensor_dataset[:self.total_size - self.num_indices]
+                self._indices[self.num_indices:] = \
+                    self._indices[:self.total_size - self.num_indices]
         dist.barrier()

     def __iter__(self):
         start = self.num_samples * self.rank
         end = self.num_samples * (self.rank + 1)
-        dataset = _divide_by_worker(self._tensor_dataset[start:end])
+        indices = _divide_by_worker(self._indices[start:end], self.batch_size, self.drop_last)
+        id_tensor = self._id_tensor[indices.to(self._device)]
         return _TensorizedDatasetIter(
-            dataset, self.batch_size, self.drop_last, self._mapping_keys)
+            id_tensor, self.batch_size, self.drop_last, self._mapping_keys)

     def __len__(self):
         return (self.num_samples + (0 if self.drop_last else (self.batch_size - 1))) // \
             self.batch_size
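The "pad extra" branch in shuffle() above follows the usual DistributedSampler arithmetic. A hypothetical worked example (num_samples and total_size are computed outside this excerpt; the values below simply assume the standard ceil-divide scheme): with 10 indices and 4 ranks, total_size rounds up to 12 and the first 2 indices are repeated as padding, so every rank iterates over exactly 3 indices.

import math

num_indices, world_size = 10, 4
num_samples = math.ceil(num_indices / world_size)   # 3 indices per rank
total_size = num_samples * world_size                # 12
padding = total_size - num_indices                   # 2 repeated indices, as in shuffle() above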
-def _prefetch_update_feats(feats, frames, types, get_storage_func, id_name, device, pin_memory):
+def _prefetch_update_feats(feats, frames, types, get_storage_func, id_name, device, pin_prefetcher):
     for tid, frame in enumerate(frames):
         type_ = types[tid]
         default_id = frame.get(id_name, None)
         for key in frame.keys():
-            column = frame[key]
+            column = frame._columns[key]
             if isinstance(column, LazyFeature):
                 parent_key = column.name or key
                 if column.id_ is None and default_id is None:
@@ -238,7 +258,7 @@ def _prefetch_update_feats(feats, frames, types, get_storage_func, id_name, devi
                         'Found a LazyFeature with no ID specified, '
                         'and the graph does not have dgl.NID or dgl.EID columns')
                 feats[tid, key] = get_storage_func(parent_key, type_).fetch(
-                    column.id_ or default_id, device, pin_memory)
+                    column.id_ or default_id, device, pin_prefetcher)

 # This class exists to avoid recursion into the feature dictionary returned by the
@@ -254,10 +274,10 @@ def _prefetch_for_subgraph(subg, dataloader):
     node_feats, edge_feats = {}, {}
     _prefetch_update_feats(
         node_feats, subg._node_frames, subg.ntypes, dataloader.graph.get_node_storage,
-        NID, dataloader.device, dataloader.pin_memory)
+        NID, dataloader.device, dataloader.pin_prefetcher)
     _prefetch_update_feats(
         edge_feats, subg._edge_frames, subg.canonical_etypes, dataloader.graph.get_edge_storage,
-        EID, dataloader.device, dataloader.pin_memory)
+        EID, dataloader.device, dataloader.pin_prefetcher)
     return _PrefetchedGraphFeatures(node_feats, edge_feats)

@@ -266,7 +286,7 @@ def _prefetch_for(item, dataloader):
         return _prefetch_for_subgraph(item, dataloader)
     elif isinstance(item, LazyFeature):
         return dataloader.other_storages[item.name].fetch(
-            item.id_, dataloader.device, dataloader.pin_memory)
+            item.id_, dataloader.device, dataloader.pin_prefetcher)
     else:
         return None
@@ -313,8 +333,17 @@ def _assign_for(item, feat):
     else:
         return item

+def _put_if_event_not_set(queue, result, event):
+    while not event.is_set():
+        try:
+            queue.put(result, timeout=1.0)
+            break
+        except Full:
+            continue

-def _prefetcher_entry(dataloader_it, dataloader, queue, num_threads, use_alternate_streams):
+def _prefetcher_entry(
+        dataloader_it, dataloader, queue, num_threads, use_alternate_streams,
+        done_event):
     # PyTorch will set the number of threads to 1 which slows down pin_memory() calls
     # in main process if a prefetching thread is created.
     if num_threads is not None:
@@ -327,20 +356,27 @@ def _prefetcher_entry(dataloader_it, dataloader, queue, num_threads, use_alterna
         stream = None

     try:
-        for batch in dataloader_it:
+        while not done_event.is_set():
+            try:
+                batch = next(dataloader_it)
+            except StopIteration:
+                break
             batch = recursive_apply(batch, restore_parent_storage_columns, dataloader.graph)
             feats = _prefetch(batch, dataloader, stream)
-            queue.put((
+            _put_if_event_not_set(queue, (
                 # batch will be already in pinned memory as per the behavior of
                 # PyTorch DataLoader.
-                recursive_apply(batch, lambda x: x.to(dataloader.device, non_blocking=True)),
+                recursive_apply(
+                    batch, lambda x: x.to(dataloader.device, non_blocking=True)),
                 feats,
                 stream.record_event() if stream is not None else None,
-                None))
-        queue.put((None, None, None, None))
+                None),
+                done_event)
+        _put_if_event_not_set(queue, (None, None, None, None), done_event)
     except:  # pylint: disable=bare-except
-        queue.put((None, None, None, ExceptionWrapper(where='in prefetcher')))
+        _put_if_event_not_set(
+            queue, (None, None, None, ExceptionWrapper(where='in prefetcher')), done_event)
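The change above turns the prefetcher into a shutdown-aware producer: every blocking put() is retried in short intervals so the thread can notice done_event and exit instead of hanging on a full queue. A generic, DGL-free sketch of the same pattern (names are illustrative):

import threading
from queue import Queue, Full

def _put_unless_done(q, item, done):
    while not done.is_set():
        try:
            q.put(item, timeout=1.0)
            return True
        except Full:
            continue
    return False

def producer(q, done):
    for item in range(5):
        if not _put_unless_done(q, item, done):
            return                        # consumer asked us to stop
    _put_unless_done(q, None, done)       # sentinel, like the (None, None, None, None) tuple

done = threading.Event()
q = Queue(maxsize=1)
t = threading.Thread(target=producer, args=(q, done), daemon=True)
t.start()
items = [q.get() for _ in range(6)]       # 0, 1, 2, 3, 4, then the sentinel
done.set()
t.join()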
 # DGLHeteroGraphs have the semantics of lazy feature slicing with subgraphs.  Such behavior depends
@@ -400,15 +436,18 @@ class _PrefetchingIter(object):
         self.dataloader_it = dataloader_it
         self.dataloader = dataloader
         self.graph_sampler = self.dataloader.graph_sampler
-        self.pin_memory = self.dataloader.pin_memory
+        self.pin_prefetcher = self.dataloader.pin_prefetcher
         self.num_threads = num_threads

         self.use_thread = use_thread
         self.use_alternate_streams = use_alternate_streams
+        self._shutting_down = False
         if use_thread:
+            self._done_event = threading.Event()
             thread = threading.Thread(
                 target=_prefetcher_entry,
-                args=(dataloader_it, dataloader, self.queue, num_threads, use_alternate_streams),
+                args=(dataloader_it, dataloader, self.queue, num_threads,
+                      use_alternate_streams, self._done_event),
                 daemon=True)
             thread.start()
             self.thread = thread
@@ -416,6 +455,31 @@ class _PrefetchingIter(object):
     def __iter__(self):
         return self
def _shutdown(self):
# Sometimes when Python is exiting complicated operations like
# self.queue.get_nowait() will hang. So we set it to no-op and let Python handle
# the rest since the thread is daemonic.
# PyTorch takes the same solution.
if PYTHON_EXIT_STATUS is True or PYTHON_EXIT_STATUS is None:
return
if not self._shutting_down:
try:
self._shutting_down = True
self._done_event.set()
try:
self.queue.get_nowait() # In case the thread is blocking on put().
except: # pylint: disable=bare-except
pass
self.thread.join()
except: # pylint: disable=bare-except
pass
def __del__(self):
if self.use_thread:
self._shutdown()
     def _next_non_threaded(self):
         batch = next(self.dataloader_it)
         batch = recursive_apply(batch, restore_parent_storage_columns, self.dataloader.graph)
@@ -430,7 +494,11 @@ class _PrefetchingIter(object):
         return batch, feats, stream_event

     def _next_threaded(self):
-        batch, feats, stream_event, exception = self.queue.get()
+        try:
+            batch, feats, stream_event, exception = self.queue.get(timeout=prefetcher_timeout)
+        except Empty:
+            raise RuntimeError(
+                f'Prefetcher thread timed out at {prefetcher_timeout} seconds.')
         if batch is None:
             self.thread.join()
             if exception is None:
@@ -485,23 +553,100 @@ def create_tensorized_dataset(indices, batch_size, drop_last, use_ddp, ddp_seed)
     return TensorizedDataset(indices, batch_size, drop_last)

+def _get_device(device):
+    device = torch.device(device)
+    if device.type == 'cuda' and device.index is None:
+        device = torch.device('cuda', torch.cuda.current_device())
+    return device
 class DataLoader(torch.utils.data.DataLoader):
     """DataLoader class."""
     def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
                  ddp_seed=0, batch_size=1, drop_last=False, shuffle=False,
-                 use_prefetch_thread=False, use_alternate_streams=True, **kwargs):
+                 use_prefetch_thread=None, use_alternate_streams=None,
+                 pin_prefetcher=None, use_uva=False, **kwargs):
+        # (BarclayII) I hoped that pin_prefetcher can be merged into PyTorch's native
+        # pin_memory argument.  But our neighbor samplers and subgraph samplers
+        # return indices, which could be CUDA tensors (e.g. during UVA sampling)
+        # hence cannot be pinned.  PyTorch's native pin memory thread does not ignore
+        # CUDA tensors when pinning and will crash.  To enable pin memory for prefetching
+        # features and disable pin memory for sampler's return value, I had to use
+        # a different argument.  Of course I could change the meaning of pin_memory
+        # to pinning prefetched features and disable pin memory for sampler's returns
+        # no matter what, but I doubt if it's reasonable.
         self.graph = graph
+        self.indices = indices      # For PyTorch-Lightning
+        num_workers = kwargs.get('num_workers', 0)

         try:
             if isinstance(indices, Mapping):
                 indices = {k: (torch.tensor(v) if not torch.is_tensor(v) else v)
                            for k, v in indices.items()}
+                indices_device = next(iter(indices.values())).device
             else:
                 indices = torch.tensor(indices) if not torch.is_tensor(indices) else indices
+                indices_device = indices.device
         except:     # pylint: disable=bare-except
             # ignore when it fails to convert to torch Tensors.
             pass
self.device = _get_device(device)
# Sanity check - we only check for DGLGraphs.
if isinstance(self.graph, DGLHeteroGraph):
# Check graph and indices device as well as num_workers
if use_uva:
if self.graph.device.type != 'cpu':
raise ValueError('Graph must be on CPU if UVA sampling is enabled.')
if num_workers > 0:
raise ValueError('num_workers must be 0 if UVA sampling is enabled.')
# Create all the formats and pin the features - custom GraphStorages
# will need to do that themselves.
self.graph.create_formats_()
self.graph.pin_memory_()
for frame in itertools.chain(self.graph._node_frames, self.graph._edge_frames):
for col in frame._columns.values():
pin_memory_inplace(col.data)
indices = recursive_apply(indices, lambda x: x.to(self.device))
else:
if self.graph.device != indices_device:
raise ValueError(
'Expect graph and indices to be on the same device. '
'If you wish to use UVA sampling, please set use_uva=True.')
if self.graph.device.type == 'cuda':
if num_workers > 0:
raise ValueError('num_workers must be 0 if graph and indices are on CUDA.')
# Check pin_prefetcher and use_prefetch_thread - should be only effective
# if performing CPU sampling but output device is CUDA
if not (self.device.type == 'cuda' and self.graph.device.type == 'cpu'):
if pin_prefetcher is True:
raise ValueError(
'pin_prefetcher=True is only effective when device=cuda and '
'sampling is performed on CPU.')
if pin_prefetcher is None:
pin_prefetcher = False
if use_prefetch_thread is True:
raise ValueError(
'use_prefetch_thread=True is only effective when device=cuda and '
'sampling is performed on CPU.')
if pin_prefetcher is None:
pin_prefetcher = False
else:
if pin_prefetcher is None:
pin_prefetcher = True
if use_prefetch_thread is None:
use_prefetch_thread = True
# Check use_alternate_streams
if use_alternate_streams is None:
use_alternate_streams = (
self.device.type == 'cuda' and self.graph.device.type == 'cpu' and
not use_uva)
         if (torch.is_tensor(indices) or (
                 isinstance(indices, Mapping) and
                 all(torch.is_tensor(v) for v in indices.values()))):
@@ -511,17 +656,18 @@ class DataLoader(torch.utils.data.DataLoader):
             self.dataset = indices

         self.ddp_seed = ddp_seed
-        self._shuffle_dataset = shuffle
+        self.use_ddp = use_ddp
+        self.use_uva = use_uva
+        self.shuffle = shuffle
+        self.drop_last = drop_last
         self.graph_sampler = graph_sampler
-        self.device = torch.device(device)
         self.use_alternate_streams = use_alternate_streams
-        if self.device.type == 'cuda' and self.device.index is None:
-            self.device = torch.device('cuda', torch.cuda.current_device())
+        self.pin_prefetcher = pin_prefetcher
         self.use_prefetch_thread = use_prefetch_thread
         worker_init_fn = WorkerInitWrapper(kwargs.get('worker_init_fn', None))

         # Instantiate all the formats if the number of workers is greater than 0.
-        if kwargs.get('num_workers', 0) > 0 and hasattr(self.graph, 'create_formats_'):
+        if num_workers > 0 and hasattr(self.graph, 'create_formats_'):
             self.graph.create_formats_()

         self.other_storages = {}
@@ -534,7 +680,7 @@ class DataLoader(torch.utils.data.DataLoader):
             **kwargs)

     def __iter__(self):
-        if self._shuffle_dataset:
+        if self.shuffle:
             self.dataset.shuffle()
         # When using multiprocessing PyTorch sometimes set the number of PyTorch threads to 1
         # when spawning new Python threads.  This drastically slows down pinning features.
@@ -551,30 +697,377 @@

 # Alias
 class NodeDataLoader(DataLoader):
-    """NodeDataLoader class."""
    """PyTorch dataloader for batch-iterating over a set of nodes, generating the list
of message flow graphs (MFGs) as computation dependency of the said minibatch.
Parameters
----------
graph : DGLGraph
The graph.
indices : Tensor or dict[ntype, Tensor]
The node set to compute outputs.
graph_sampler : object
The neighborhood sampler. It could be any object that has a :attr:`sample`
        method. The :attr:`sample` method must take in a graph object and either a tensor
of node indices or a dict of such tensors.
device : device context, optional
The device of the generated MFGs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
By default this value is the same as the device of :attr:`g`.
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set for each
participating process appropriately using
:class:`torch.utils.data.distributed.DistributedSampler`.
Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
ddp_seed : int, optional
The seed for shuffling the dataset in
:class:`torch.utils.data.distributed.DistributedSampler`.
Only effective when :attr:`use_ddp` is True.
use_uva : bool, optional
Whether to use Unified Virtual Addressing (UVA) to directly sample the graph
and slice the features from CPU into GPU. Setting it to True will pin the
graph and feature tensors into pinned memory.
Default: False.
use_prefetch_thread : bool, optional
(Advanced option)
Spawns a new Python thread to perform feature slicing
asynchronously. Can make things faster at the cost of GPU memory.
Default: True if the graph is on CPU and :attr:`device` is CUDA. False otherwise.
use_alternate_streams : bool, optional
(Advanced option)
        Whether to slice and transfer the features to GPU on a non-default stream.
Default: True if the graph is on CPU, :attr:`device` is CUDA, and :attr:`use_uva`
is False. False otherwise.
pin_prefetcher : bool, optional
(Advanced option)
Whether to pin the feature tensors into pinned memory.
Default: True if the graph is on CPU and :attr:`device` is CUDA. False otherwise.
batch_size : int, optional
drop_last : bool, optional
shuffle : bool, optional
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
Examples
--------
To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
a homogeneous graph where each node takes messages from all neighbors (assume
the backend is PyTorch):
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by turning
on the `use_ddp` option:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler, use_ddp=True,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks)
Notes
-----
Please refer to
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`
and :ref:`User Guide Section 6 <guide-minibatch>` for usage.
**Tips for selecting the proper device**
* If the input graph :attr:`g` is on GPU, the output device :attr:`device` must be the same GPU
and :attr:`num_workers` must be zero. In this case, the sampling and subgraph construction
      will take place on the GPU. This is the recommended setting when using a single GPU and
the whole graph fits in GPU memory.
* If the input graph :attr:`g` is on CPU while the output device :attr:`device` is GPU, then
depending on the value of :attr:`use_uva`:
- If :attr:`use_uva` is set to True, the sampling and subgraph construction will happen
on GPU even if the GPU itself cannot hold the entire graph. This is the recommended
setting unless there are operations not supporting UVA. :attr:`num_workers` must be 0
in this case.
- Otherwise, both the sampling and subgraph construction will take place on the CPU.
"""
 class EdgeDataLoader(DataLoader):
-    """EdgeDataLoader class."""
    """PyTorch dataloader for batch-iterating over a set of edges, generating the list
of message flow graphs (MFGs) as computation dependency of the said minibatch for
edge classification, edge regression, and link prediction.
For each iteration, the object will yield
* A tensor of input nodes necessary for computing the representation on edges, or
a dictionary of node type names and such tensors.
* A subgraph that contains only the edges in the minibatch and their incident nodes.
Note that the graph has an identical metagraph with the original graph.
* If a negative sampler is given, another graph that contains the "negative edges",
connecting the source and destination nodes yielded from the given negative sampler.
* A list of MFGs necessary for computing the representation of the incident nodes
of the edges in the minibatch.
For more details, please refer to :ref:`guide-minibatch-edge-classification-sampler`
and :ref:`guide-minibatch-link-classification-sampler`.
Parameters
----------
g : DGLGraph
The graph.
indices : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs.
graph_sampler : object
The neighborhood sampler. It could be any object that has a :attr:`sample`
        method. The :attr:`sample` method must take in a graph object and either a tensor
of node indices or a dict of such tensors.
device : device context, optional
The device of the generated MFGs and graphs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
By default this value is the same as the device of :attr:`g`.
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set for each
participating process appropriately using
:class:`torch.utils.data.distributed.DistributedSampler`.
Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
ddp_seed : int, optional
The seed for shuffling the dataset in
:class:`torch.utils.data.distributed.DistributedSampler`.
Only effective when :attr:`use_ddp` is True.
use_prefetch_thread : bool, optional
(Advanced option)
Spawns a new Python thread to perform feature slicing
asynchronously. Can make things faster at the cost of GPU memory.
Default: True if the graph is on CPU and :attr:`device` is CUDA. False otherwise.
use_alternate_streams : bool, optional
(Advanced option)
        Whether to slice and transfer the features to GPU on a non-default stream.
Default: True if the graph is on CPU, :attr:`device` is CUDA, and :attr:`use_uva`
is False. False otherwise.
pin_prefetcher : bool, optional
(Advanced option)
Whether to pin the feature tensors into pinned memory.
Default: True if the graph is on CPU and :attr:`device` is CUDA. False otherwise.
exclude : str, optional
Whether and how to exclude dependencies related to the sampled edges in the
minibatch. Possible values are
* None, for not excluding any edges.
* ``self``, for excluding only the edges sampled as seed edges in this minibatch.
* ``reverse_id``, for excluding not only the edges sampled in the minibatch but
also their reverse edges of the same edge type. Requires the argument
:attr:`reverse_eids`.
* ``reverse_types``, for excluding not only the edges sampled in the minibatch
but also their reverse edges of different types but with the same IDs.
Requires the argument :attr:`reverse_etypes`.
* A callable which takes in a tensor or a dictionary of tensors and their
canonical edge types and returns a tensor or dictionary of tensors to
exclude.
reverse_eids : Tensor or dict[etype, Tensor], optional
A tensor of reverse edge ID mapping. The i-th element indicates the ID of
the i-th edge's reverse edge.
If the graph is heterogeneous, this argument requires a dictionary of edge
types and the reverse edge ID mapping tensors.
See the description of the argument with the same name in the docstring of
:class:`~dgl.dataloading.EdgeCollator` for more details.
reverse_etypes : dict[etype, etype], optional
The mapping from the original edge types to their reverse edge types.
See the description of the argument with the same name in the docstring of
:class:`~dgl.dataloading.EdgeCollator` for more details.
negative_sampler : callable, optional
The negative sampler.
See the description of the argument with the same name in the docstring of
:class:`~dgl.dataloading.EdgeCollator` for more details.
use_uva : bool, optional
Whether to use Unified Virtual Addressing (UVA) to directly sample the graph
and slice the features from CPU into GPU. Setting it to True will pin the
graph and feature tensors into pinned memory.
Default: False.
batch_size : int, optional
drop_last : bool, optional
shuffle : bool, optional
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
Examples
--------
The following example shows how to train a 3-layer GNN for edge classification on a
set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes
messages from all neighbors.
Say that you have an array of source node IDs ``src`` and another array of destination
node IDs ``dst``. One can make it bidirectional by adding another set of edges
that connects from ``dst`` to ``src``:
>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))
    The ID of an edge and the ID of its reverse edge then differ by exactly ``|E|``,
where ``|E|`` is the length of your source/destination array. The reverse edge
mapping can be obtained by
>>> E = len(src)
>>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])
Note that the sampled edges as well as their reverse edges are removed from
    computation dependencies of the incident nodes.  That is, these edges will not
    be involved in neighbor sampling and message aggregation.  This is a common trick
to avoid information leakage.
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, exclude='reverse_id',
... reverse_eids=reverse_eids,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a
homogeneous graph where each node takes messages from all neighbors (assume the
backend is PyTorch), with 5 uniformly chosen negative samples per edge:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, exclude='reverse_id',
... reverse_eids=reverse_eids, negative_sampler=neg_sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
For heterogeneous graphs, the reverse of an edge may have a different edge type
from the original edge. For instance, consider that you have an array of
    user-item clicks, represented by a user array ``user`` and an item array ``item``.
You may want to build a heterogeneous graph with a user-click-item relation and an
item-clicked-by-user relation.
>>> g = dgl.heterograph({
... ('user', 'click', 'item'): (user, item),
... ('item', 'clicked-by', 'user'): (item, user)})
To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with
type ``click``, you can write
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, {'click': train_eid}, sampler, exclude='reverse_types',
... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type
``click``, you can write
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, exclude='reverse_types',
... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
... negative_sampler=neg_sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
turning on the :attr:`use_ddp` option:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, use_ddp=True, exclude='reverse_id',
... reverse_eids=reverse_eids,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
Notes
-----
Please refer to
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`
and :ref:`User Guide Section 6 <guide-minibatch>` for usage.
**Tips for selecting the proper device**
* If the input graph :attr:`g` is on GPU, the output device :attr:`device` must be the same GPU
and :attr:`num_workers` must be zero. In this case, the sampling and subgraph construction
      will take place on the GPU. This is the recommended setting when using a single GPU and
the whole graph fits in GPU memory.
* If the input graph :attr:`g` is on CPU while the output device :attr:`device` is GPU, then
depending on the value of :attr:`use_uva`:
- If :attr:`use_uva` is set to True, the sampling and subgraph construction will happen
on GPU even if the GPU itself cannot hold the entire graph. This is the recommended
setting unless there are operations not supporting UVA. :attr:`num_workers` must be 0
in this case.
- Otherwise, both the sampling and subgraph construction will take place on the CPU.
"""
     def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
                  ddp_seed=0, batch_size=1, drop_last=False, shuffle=False,
                  use_prefetch_thread=False, use_alternate_streams=True,
+                 pin_prefetcher=False,
                  exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None,
-                 g_sampling=None, **kwargs):
-        if g_sampling is not None:
-            dgl_warning(
-                "g_sampling is deprecated. "
-                "Please merge g_sampling and the original graph into one graph and use "
-                "the exclude argument to specify which edges you don't want to sample.")
+                 use_uva=False, **kwargs):
+        device = _get_device(device)

         if isinstance(graph_sampler, BlockSampler):
+            if reverse_eids is not None:
+                if use_uva:
+                    reverse_eids = recursive_apply(reverse_eids, lambda x: x.to(device))
+                else:
+                    reverse_eids_device = context_of(reverse_eids)
+                    indices_device = context_of(indices)
+                    if indices_device != reverse_eids_device:
+                        raise ValueError('Expect the same device for indices and reverse_eids')
             graph_sampler = EdgeBlockSampler(
                 graph_sampler, exclude=exclude, reverse_eids=reverse_eids,
-                reverse_etypes=reverse_etypes, negative_sampler=negative_sampler)
+                reverse_etypes=reverse_etypes, negative_sampler=negative_sampler,
+                prefetch_node_feats=graph_sampler.prefetch_node_feats,
+                prefetch_labels=graph_sampler.prefetch_labels,
+                prefetch_edge_feats=graph_sampler.prefetch_edge_feats)

         super().__init__(
             graph, indices, graph_sampler, device=device, use_ddp=use_ddp, ddp_seed=ddp_seed,
             batch_size=batch_size, drop_last=drop_last, shuffle=shuffle,
             use_prefetch_thread=use_prefetch_thread, use_alternate_streams=use_alternate_streams,
+            pin_prefetcher=pin_prefetcher, use_uva=use_uva,
             **kwargs)
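The comment in DataLoader.__init__ above explains why pin_prefetcher exists as a knob separate from PyTorch's pin_memory: samplers may hand back CUDA index tensors, and those cannot be pinned. A small illustration of that constraint (requires a GPU; on the PyTorch versions we have seen, the call raises RuntimeError):

import torch

cpu_feat = torch.randn(4, 8)
assert cpu_feat.pin_memory().is_pinned()        # CPU tensors can be pinned

cuda_idx = torch.arange(4, device='cuda')       # what a UVA sampler may return
try:
    cuda_idx.pin_memory()                       # what pin_memory=True would attempt
except RuntimeError as err:
    print('cannot pin a CUDA tensor:', err)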
......
@@ -56,6 +56,12 @@ class LazyFeature(object):
         """No-op.  For compatibility of :meth:`Frame.__repr__` method."""
         return self

+    def pin_memory_(self):
+        """No-op.  For compatibility of :meth:`Frame.pin_memory_` method."""

+    def unpin_memory_(self):
+        """No-op.  For compatibility of :meth:`Frame.unpin_memory_` method."""

 class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
     """The column scheme.
@@ -142,6 +148,7 @@ class Column(TensorStorage):
         self.scheme = scheme if scheme else infer_scheme(storage)
         self.index = index
         self.device = device
+        self.pinned = False

     def __len__(self):
         """The number of features (number of rows) in this column."""
@@ -183,6 +190,7 @@ class Column(TensorStorage):
         """Update the column data."""
         self.index = None
         self.storage = val
+        self.pinned = False

     def to(self, device, **kwargs):  # pylint: disable=invalid-name
         """ Return a new column with columns copy to the targeted device (cpu/gpu).
@@ -330,6 +338,10 @@ class Column(TensorStorage):
     def __copy__(self):
         return self.clone()

+    def fetch(self, indices, device, pin_memory=False):
+        _ = self.data  # materialize in case of lazy slicing & data transfer
+        return super().fetch(indices, device, pin_memory=False)

 class Frame(MutableMapping):
     """The columnar storage for node/edge features.
@@ -702,3 +714,15 @@ class Frame(MutableMapping):
     def __repr__(self):
         return repr(dict(self))
def pin_memory_(self):
"""Registers the data of every column into pinned memory, materializing them if
necessary."""
for column in self._columns.values():
column.pin_memory_()
def unpin_memory_(self):
"""Unregisters the data of every column from pinned memory, materializing them
if necessary."""
for column in self._columns.values():
column.unpin_memory_()
@@ -5474,7 +5474,7 @@ class DGLHeteroGraph(object):
         Materialization of new sparse formats for pinned graphs is not allowed.
         To avoid implicit formats materialization during training,
-        you should create all the needed formats before pinnning.
+        you should create all the needed formats before pinning.
         But cloning and materialization is fine.  See the examples below.

         Returns
@@ -5530,6 +5530,7 @@ class DGLHeteroGraph(object):
         if F.device_type(self.device) != 'cpu':
             raise DGLError("The graph structure must be on CPU to be pinned.")
         self._graph.pin_memory_()
+        return self

     def unpin_memory_(self):
@@ -5546,6 +5547,7 @@ class DGLHeteroGraph(object):
         if not self._graph.is_pinned():
             return self
         self._graph.unpin_memory_()
+        return self

     def is_pinned(self):
......
"""Module for heterogeneous graph index class definition.""" """Module for heterogeneous graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import itertools import itertools
import numpy as np import numpy as np
import scipy import scipy
...@@ -1365,4 +1366,27 @@ class HeteroPickleStates(ObjectBase): ...@@ -1365,4 +1366,27 @@ class HeteroPickleStates(ObjectBase):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs) _CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs)
def _forking_rebuild(pk_state):
meta, arrays = pk_state
arrays = [F.to_dgl_nd(arr) for arr in arrays]
states = _CAPI_DGLCreateHeteroPickleStates(meta, arrays)
return _CAPI_DGLHeteroForkingUnpickle(states)
def _forking_reduce(graph_index):
states = _CAPI_DGLHeteroForkingPickle(graph_index)
arrays = [F.from_dgl_nd(arr) for arr in states.arrays]
# Similar to what being mentioned in HeteroGraphIndex.__getstate__, we need to save
# the tensors as an attribute of the original graph index object. Otherwise
# PyTorch will throw weird errors like bad value(s) in fds_to_keep or unable to
# resize file.
graph_index._forking_pk_state = (states.meta, arrays)
return _forking_rebuild, (graph_index._forking_pk_state,)
if not (F.get_preferred_backend() == 'mxnet' and sys.version_info.minor <= 6):
# Python 3.6 MXNet crashes with the following statement; remove until we no longer support
# 3.6 (which is EOL anyway).
from multiprocessing.reduction import ForkingPickler
ForkingPickler.register(HeteroGraphIndex, _forking_reduce)
_init_api("dgl.heterograph_index") _init_api("dgl.heterograph_index")
@@ -222,6 +222,7 @@ class SparseMatrix(ObjectBase):
 _set_class_ndarray(NDArray)
 _init_api("dgl.ndarray")
+_init_api("dgl.ndarray.uvm", __name__)

 # An array representing null (no value) that can be safely converted to
 # other backend tensors.
......
@@ -3,7 +3,8 @@ from .. import backend as F
 from .base import *
 from .numpy import *

+# Defines the name TensorStorage
 if F.get_preferred_backend() == 'pytorch':
-    from .pytorch_tensor import *
+    from .pytorch_tensor import PyTorchTensorStorage as TensorStorage
 else:
-    from .tensor import *
+    from .tensor import BaseTensorStorage as TensorStorage
"""Feature storages for PyTorch tensors.""" """Feature storages for PyTorch tensors."""
import torch import torch
from .base import FeatureStorage, register_storage_wrapper from .base import register_storage_wrapper
from .tensor import BaseTensorStorage
from ..utils import gather_pinned_tensor_rows
def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory): def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory):
result = torch.empty( result = torch.empty(
...@@ -15,18 +17,26 @@ def _fetch_cuda(indices, tensor, device): ...@@ -15,18 +17,26 @@ def _fetch_cuda(indices, tensor, device):
return torch.index_select(tensor, 0, indices).to(device) return torch.index_select(tensor, 0, indices).to(device)
@register_storage_wrapper(torch.Tensor) @register_storage_wrapper(torch.Tensor)
class TensorStorage(FeatureStorage): class PyTorchTensorStorage(BaseTensorStorage):
"""Feature storages for slicing a PyTorch tensor.""" """Feature storages for slicing a PyTorch tensor."""
def __init__(self, tensor):
self.storage = tensor
self.feature_shape = tensor.shape[1:]
self.is_cuda = (tensor.device.type == 'cuda')
def fetch(self, indices, device, pin_memory=False): def fetch(self, indices, device, pin_memory=False):
device = torch.device(device) device = torch.device(device)
if not self.is_cuda: storage_device_type = self.storage.device.type
indices_device_type = indices.device.type
if storage_device_type != 'cuda':
if indices_device_type == 'cuda':
if self.storage.is_pinned():
return gather_pinned_tensor_rows(self.storage, indices)
else:
raise ValueError(
f'Got indices on device {indices.device} whereas the feature tensor '
f'is on {self.storage.device}. Please either (1) move the graph '
f'to GPU with to() method, or (2) pin the graph with '
f'pin_memory_() method.')
# CPU to CPU or CUDA - use pin_memory and async transfer if possible # CPU to CPU or CUDA - use pin_memory and async transfer if possible
return _fetch_cpu(indices, self.storage, self.feature_shape, device, pin_memory) else:
return _fetch_cpu(indices, self.storage, self.storage.shape[1:], device,
pin_memory)
else: else:
# CUDA to CUDA or CPU # CUDA to CUDA or CPU
return _fetch_cuda(indices, self.storage, device) return _fetch_cuda(indices, self.storage, device)
"""Feature storages for tensors across different frameworks.""" """Feature storages for tensors across different frameworks."""
from .base import FeatureStorage from .base import FeatureStorage
from .. import backend as F from .. import backend as F
from ..utils import recursive_apply_pair
def _fetch(indices, tensor, device): class BaseTensorStorage(FeatureStorage):
return F.copy_to(F.gather_row(tensor, indices), device)
class TensorStorage(FeatureStorage):
"""FeatureStorage that synchronously slices features from a tensor and transfers """FeatureStorage that synchronously slices features from a tensor and transfers
it to the given device. it to the given device.
""" """
...@@ -14,4 +10,4 @@ class TensorStorage(FeatureStorage): ...@@ -14,4 +10,4 @@ class TensorStorage(FeatureStorage):
self.storage = tensor self.storage = tensor
def fetch(self, indices, device, pin_memory=False): # pylint: disable=unused-argument def fetch(self, indices, device, pin_memory=False): # pylint: disable=unused-argument
return recursive_apply_pair(indices, self.storage, _fetch, device) return F.copy_to(F.gather_row(tensor, indices), device)
@@ -5,3 +5,4 @@ from .checks import *
 from .shared_mem import *
 from .filter import *
 from .exception import *
+from .pin_memory import *
@@ -937,4 +937,8 @@ def recursive_apply_pair(data1, data2, fn, *args, **kwargs):
     else:
         return fn(data1, data2, *args, **kwargs)

+def context_of(data):
+    """Return the device of the data which can be either a tensor or a dict of tensors."""
+    return F.context(next(iter(data.values())) if isinstance(data, Mapping) else data)

 _init_api("dgl.utils.internal")
"""Utility functions related to pinned memory tensors."""
from .. import backend as F
from .._ffi.function import _init_api
def pin_memory_inplace(tensor):
"""Register the tensor into pinned memory in-place (i.e. without copying)."""
F.to_dgl_nd(tensor).pin_memory_()
def unpin_memory_inplace(tensor):
"""Unregister the tensor from pinned memory in-place (i.e. without copying)."""
F.to_dgl_nd(tensor).unpin_memory_()
def gather_pinned_tensor_rows(tensor, rows):
"""Directly gather rows from a CPU tensor given an indices array on CUDA devices,
and returns the result on the same CUDA device without copying.
Parameters
----------
tensor : Tensor
The tensor. Must be in pinned memory.
rows : Tensor
The rows to gather. Must be a CUDA tensor.
Returns
-------
Tensor
The result with the same device as :attr:`rows`.
"""
return F.from_dgl_nd(_CAPI_DGLIndexSelectCPUFromGPU(F.to_dgl_nd(tensor), F.to_dgl_nd(rows)))
_init_api("dgl.ndarray.uvm", __name__)
...@@ -27,7 +27,7 @@ NDArray IndexSelect(NDArray array, IdArray index) { ...@@ -27,7 +27,7 @@ NDArray IndexSelect(NDArray array, IdArray index) {
shape.emplace_back(array->shape[d]); shape.emplace_back(array->shape[d]);
} }
// use index->ctx for kDLCPUPinned array // use index->ctx for pinned array
NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx); NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);
if (len == 0) if (len == 0)
return ret; return ret;
......
...@@ -24,7 +24,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) { ...@@ -24,7 +24,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
int64_t num_feat = 1; int64_t num_feat = 1;
std::vector<int64_t> shape{len}; std::vector<int64_t> shape{len};
CHECK_EQ(array->ctx.device_type, kDLCPUPinned); CHECK(array.IsPinned());
CHECK_EQ(index->ctx.device_type, kDLGPU); CHECK_EQ(index->ctx.device_type, kDLGPU);
for (int d = 1; d < array->ndim; ++d) { for (int d = 1; d < array->ndim; ++d) {
...@@ -72,6 +72,8 @@ template NDArray IndexSelectCPUFromGPU<int64_t, int32_t>(NDArray, IdArray); ...@@ -72,6 +72,8 @@ template NDArray IndexSelectCPUFromGPU<int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray); template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<float, int32_t>(NDArray, IdArray); template NDArray IndexSelectCPUFromGPU<float, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<float, int64_t>(NDArray, IdArray); template NDArray IndexSelectCPUFromGPU<float, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<double, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<double, int64_t>(NDArray, IdArray);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -15,7 +15,7 @@ namespace aten { ...@@ -15,7 +15,7 @@ namespace aten {
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) { NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
CHECK_EQ(array->ctx.device_type, kDLCPUPinned) CHECK(array.IsPinned())
<< "Only the CPUPinned device type input array is supported"; << "Only the CPUPinned device type input array is supported";
CHECK_EQ(index->ctx.device_type, kDLGPU) CHECK_EQ(index->ctx.device_type, kDLGPU)
<< "Only the GPU device type input index is supported"; << "Only the GPU device type input index is supported";
......
...@@ -83,12 +83,6 @@ BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs) { ...@@ -83,12 +83,6 @@ BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs) {
rst.out_len /= rst.reduce_size;  // out_len is divided by reduce_size in dot. rst.out_len /= rst.reduce_size;  // out_len is divided by reduce_size in dot.
} }
} }
#ifdef DEBUG
LOG(INFO) << "lhs_len: " << rst.lhs_len << " " <<
"rhs_len: " << rst.rhs_len << " " <<
"out_len: " << rst.out_len << " " <<
"reduce_size: " << rst.reduce_size << std::endl;
#endif
return rst; return rst;
} }
......
...@@ -236,7 +236,7 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -236,7 +236,7 @@ class HeteroGraph : public BaseHeteroGraph {
* \brief Pin all relation graphs of the current graph. * \brief Pin all relation graphs of the current graph.
* \note The graph will be pinned inplace. Behavior depends on the current context, * \note The graph will be pinned inplace. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDLCPU: will be pinned;
* kDLCPUPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
...@@ -245,7 +245,7 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -245,7 +245,7 @@ class HeteroGraph : public BaseHeteroGraph {
/*! /*!
* \brief Unpin all relation graphs of the current graph. * \brief Unpin all relation graphs of the current graph.
* \note The graph will be unpinned inplace. Behavior depends on the current context, * \note The graph will be unpinned inplace. Behavior depends on the current context,
* kDLCPUPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
...@@ -272,6 +272,18 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -272,6 +272,18 @@ class HeteroGraph : public BaseHeteroGraph {
return relation_graphs_; return relation_graphs_;
} }
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
GetRelationGraph(etype)->SetCOOMatrix(0, coo);
}
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
GetRelationGraph(etype)->SetCSRMatrix(0, csr);
}
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
GetRelationGraph(etype)->SetCSCMatrix(0, csc);
}
private: private:
// To create empty class // To create empty class
friend class Serializer; friend class Serializer;
......
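At the Python level, the contract in the notes above is that pinning is in place and effectively idempotent: a CPU graph gets its arrays registered as pinned, an already pinned graph is returned unchanged, and a GPU graph is rejected; unpinning mirrors this. A hedged sketch using the graph-level methods referenced elsewhere in this commit (pin_memory_ appears in the error message earlier; is_pinned and unpin_memory_ are assumed counterparts, and exact names may differ across DGL versions). A CUDA-enabled build is assumed.

    # Hedged sketch of the pin/unpin contract described in the notes above;
    # do not treat the method names as exact API documentation.
    import torch
    import dgl

    g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))   # CPU graph

    g.pin_memory_()        # kDLCPU: relation graph arrays get pinned in place
    g.pin_memory_()        # already pinned: directly returns
    assert g.is_pinned()

    g.unpin_memory_()      # pinned: unpinned in place
    g.unpin_memory_()      # not pinned: directly returns

    # g.to('cuda:0').pin_memory_()   # kDLGPU: invalid, raises an error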
...@@ -173,13 +173,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType") ...@@ -173,13 +173,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
    *rv = hg->Context();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsPinned") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsPinned")
......
...@@ -51,6 +51,42 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) { ...@@ -51,6 +51,42 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
return states; return states;
} }
HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) {
HeteroPickleStates states;
dmlc::MemoryStringStream ofs(&states.meta);
dmlc::Stream *strm = &ofs;
strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
strm->Write(graph->NumVerticesPerType());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto created_formats = graph->GetCreatedFormats();
auto allowed_formats = graph->GetAllowedFormats();
strm->Write(created_formats);
strm->Write(allowed_formats);
if (created_formats & COO_CODE) {
const auto &coo = graph->GetCOOMatrix(etype);
strm->Write(coo.row_sorted);
strm->Write(coo.col_sorted);
states.arrays.push_back(coo.row);
states.arrays.push_back(coo.col);
}
if (created_formats & CSR_CODE) {
const auto &csr = graph->GetCSRMatrix(etype);
strm->Write(csr.sorted);
states.arrays.push_back(csr.indptr);
states.arrays.push_back(csr.indices);
states.arrays.push_back(csr.data);
}
if (created_formats & CSC_CODE) {
const auto &csc = graph->GetCSCMatrix(etype);
strm->Write(csc.sorted);
states.arrays.push_back(csc.indptr);
states.arrays.push_back(csc.indices);
states.arrays.push_back(csc.data);
}
}
return states;
}
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
char *buf = const_cast<char *>(states.meta.c_str()); // a readonly stream? char *buf = const_cast<char *>(states.meta.c_str()); // a readonly stream?
dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size()); dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
...@@ -137,6 +173,76 @@ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states) { ...@@ -137,6 +173,76 @@ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states) {
return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type); return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
} }
HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
char *buf = const_cast<char *>(states.meta.c_str()); // a readonly stream?
dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
dmlc::Stream *strm = &ifs;
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
GraphPtr metagraph = meta_imgraph;
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
std::vector<int64_t> num_nodes_per_type;
CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";
auto array_itr = states.arrays.begin();
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
const auto& pair = metagraph->FindEdge(etype);
const dgl_type_t srctype = pair.first;
const dgl_type_t dsttype = pair.second;
const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
int64_t num_src = num_nodes_per_type[srctype];
int64_t num_dst = num_nodes_per_type[dsttype];
dgl_format_code_t created_formats, allowed_formats;
CHECK(strm->Read(&created_formats)) << "Invalid code for created formats";
CHECK(strm->Read(&allowed_formats)) << "Invalid code for allowed formats";
HeteroGraphPtr relgraph = nullptr;
if (created_formats & COO_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 2);
const auto &row = *(array_itr++);
const auto &col = *(array_itr++);
bool rsorted;
bool csorted;
CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
if (!relgraph)
relgraph = CreateFromCOO(num_vtypes, coo, allowed_formats);
else
relgraph->SetCOOMatrix(0, coo);
}
if (created_formats & CSR_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 3);
const auto &indptr = *(array_itr++);
const auto &indices = *(array_itr++);
const auto &edge_id = *(array_itr++);
bool sorted;
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
if (!relgraph)
relgraph = CreateFromCSR(num_vtypes, csr, allowed_formats);
else
relgraph->SetCSRMatrix(0, csr);
}
if (created_formats & CSC_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 3);
const auto &indptr = *(array_itr++);
const auto &indices = *(array_itr++);
const auto &edge_id = *(array_itr++);
bool sorted;
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted);
if (!relgraph)
relgraph = CreateFromCSC(num_vtypes, csc, allowed_formats);
else
relgraph->SetCSCMatrix(0, csc);
}
relgraphs[etype] = relgraph;
}
return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
}
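The forking variants above serialize only the metagraph, the format codes, and a few sortedness flags into states.meta, while the COO/CSR/CSC arrays stay as live NDArray handles in states.arrays: two arrays per materialized COO (row, col) and three per CSR or CSC (indptr, indices, data) for each edge type, which is exactly what the CHECK_GE guards in the unpickler verify. A small self-contained sketch of that layout arithmetic; the format-code bit values here are assumptions for illustration, not the actual constants.

    # Self-check of the array layout implied by the forking pickle format above.
    COO_CODE, CSR_CODE, CSC_CODE = 0x1, 0x2, 0x4   # assumed bit values, illustration only

    def expected_num_arrays(created_formats_per_etype):
        total = 0
        for code in created_formats_per_etype:
            total += 2 * bool(code & COO_CODE)   # COO stores (row, col)
            total += 3 * bool(code & CSR_CODE)   # CSR stores (indptr, indices, data)
            total += 3 * bool(code & CSC_CODE)   # CSC stores (indptr, indices, data)
        return total

    # e.g. one edge type with COO and CSC materialized -> 5 entries in states.arrays
    assert expected_num_arrays([COO_CODE | CSC_CODE]) == 5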
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef st = args[0]; HeteroPickleStatesRef st = args[0];
...@@ -186,6 +292,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle") ...@@ -186,6 +292,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
*rv = HeteroPickleStatesRef(st); *rv = HeteroPickleStatesRef(st);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingPickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef ref = args[0];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
*st = HeteroForkingPickle(ref.sptr());
*rv = HeteroPickleStatesRef(st);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef ref = args[0]; HeteroPickleStatesRef ref = args[0];
...@@ -203,6 +317,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle") ...@@ -203,6 +317,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
*rv = HeteroGraphRef(graph); *rv = HeteroGraphRef(graph);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef ref = args[0];
HeteroGraphPtr graph = HeteroForkingUnpickle(*ref.sptr());
*rv = HeteroGraphRef(graph);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0]; GraphRef metagraph = args[0];
......
...@@ -31,7 +31,7 @@ HeteroSubgraph ExcludeCertainEdges( ...@@ -31,7 +31,7 @@ HeteroSubgraph ExcludeCertainEdges(
sg.induced_edges[etype]->shape[0], sg.induced_edges[etype]->shape[0],
sg.induced_edges[etype]->dtype.bits, sg.induced_edges[etype]->dtype.bits,
sg.induced_edges[etype]->ctx); sg.induced_edges[etype]->ctx);
if (exclude_edges[etype].GetSize() == 0) { if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) {
remain_edges[etype] = edge_ids; remain_edges[etype] = edge_ids;
remain_induced_edges[etype] = sg.induced_edges[etype]; remain_induced_edges[etype] = sg.induced_edges[etype];
continue; continue;
......
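The added edge_ids.GetSize() == 0 condition extends the early return to edge types with no sampled edges at all, so the exclusion lookup never runs on empty inputs and the sampled set passes through unchanged. A rough PyTorch analogue of the remaining-edge computation; the names are illustrative and torch.isin requires PyTorch 1.10 or newer.

    # Rough PyTorch equivalent of the early-return above: when either the sampled
    # edge set or the exclusion set is empty, the remaining edges are just the
    # sampled edges themselves.
    import torch

    def remaining_edges(sampled_eids, excluded_eids):
        if excluded_eids.numel() == 0 or sampled_eids.numel() == 0:
            return sampled_eids                       # nothing to exclude / nothing sampled
        keep = ~torch.isin(sampled_eids, excluded_eids)
        return sampled_eids[keep]

    print(remaining_edges(torch.tensor([3, 7, 9]), torch.tensor([7])))              # tensor([3, 9])
    print(remaining_edges(torch.tensor([], dtype=torch.long), torch.tensor([7])))   # empty tensor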
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
* \brief frequency hashmap - used to select top-k frequency edges of each node * \brief frequency hashmap - used to select top-k frequency edges of each node
*/ */
#include <cub/cub.cuh>
#include <algorithm> #include <algorithm>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include "../../../runtime/cuda/cuda_common.h" #include "../../../runtime/cuda/cuda_common.h"
#include "../../../array/cuda/atomic.cuh" #include "../../../array/cuda/atomic.cuh"
#include "../../../array/cuda/dgl_cub.cuh"
#include "frequency_hashmap.cuh" #include "frequency_hashmap.cuh"
namespace dgl { namespace dgl {
......