Unverified Commit 6e58f5f1 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Auto format with manual fix of 3 files and add pylint: disable= too-many-lines for functional.py. (#5330)

* blabal

* 2more

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent f17c0d62
"""DGL PyTorch DataLoaders""" """DGL PyTorch DataLoaders"""
from collections.abc import Mapping, Sequence import atexit
from queue import Queue, Empty, Full import inspect
import itertools import itertools
import threading
import math import math
import inspect
import re
import atexit
import os import os
import re
import threading
from collections.abc import Mapping, Sequence
from contextlib import contextmanager from contextlib import contextmanager
import psutil from queue import Empty, Full, Queue
import numpy as np import numpy as np
import psutil
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from ..base import NID, EID, dgl_warning, DGLError
from ..batch import batch as batch_graphs
from .._ffi.base import is_tensor_adaptor_enabled
from ..heterograph import DGLGraph
from ..utils import (
recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads, get_num_threads,
get_numa_nodes_cores, dtype_of, version)
from ..frame import LazyFeature
from ..storages import wrap_storage
from .. import backend as F from .. import backend as F
from .._ffi.base import is_tensor_adaptor_enabled
from ..base import dgl_warning, DGLError, EID, NID
from ..batch import batch as batch_graphs
from ..distributed import DistGraph from ..distributed import DistGraph
from ..frame import LazyFeature
from ..heterograph import DGLGraph
from ..multiprocessing import call_once_and_share from ..multiprocessing import call_once_and_share
from ..storages import wrap_storage
from ..utils import (
dtype_of,
ExceptionWrapper,
get_num_threads,
get_numa_nodes_cores,
recursive_apply,
recursive_apply_pair,
set_num_threads,
version,
)
PYTORCH_VER = version.parse(torch.__version__)
PYTHON_EXIT_STATUS = False


def _set_python_exit_flag():
    global PYTHON_EXIT_STATUS
    PYTHON_EXIT_STATUS = True


atexit.register(_set_python_exit_flag)

prefetcher_timeout = int(os.environ.get("DGL_PREFETCHER_TIMEOUT", "30"))
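
# Illustrative note (not part of the diff): the timeout above is read from the
# environment once, when this module is imported, so it has to be set beforehand,
# e.g.
#
#     >>> import os
#     >>> os.environ["DGL_PREFETCHER_TIMEOUT"] = "60"  # allow 60s per prefetched batch
#     >>> import dgl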
class _TensorizedDatasetIter(object):
    def __init__(self, dataset, batch_size, drop_last, mapping_keys, shuffle):
@@ -60,7 +73,7 @@ class _TensorizedDatasetIter(object):
            if self.drop_last:
                raise StopIteration
            end_idx = num_items
        batch = self.dataset[self.index : end_idx]
        self.index += self.batch_size
        return batch
@@ -79,28 +92,34 @@ class _TensorizedDatasetIter(object):
        else:
            if not self.shuffle:
                dgl_warning(
                    "The current output_nodes are out of order even if set shuffle "
                    "to False in Dataloader, the reason is that the current version "
                    "of torch does not support stable sort. "
                    "Please update torch to 1.10.0 or higher to fix it."
                )
            type_ids_sortidx = torch.argsort(type_ids)
            type_ids = type_ids[type_ids_sortidx]
            indices = indices[type_ids_sortidx]
        type_id_uniq, type_id_count = torch.unique_consecutive(
            type_ids, return_counts=True
        )
        type_id_uniq = type_id_uniq.tolist()
        type_id_offset = type_id_count.cumsum(0).tolist()
        type_id_offset.insert(0, 0)
        id_dict = {
            self.mapping_keys[type_id_uniq[i]]: indices[
                type_id_offset[i] : type_id_offset[i + 1]
            ].clone()
            for i in range(len(type_id_uniq))
        }
        return id_dict


def _get_id_tensor_from_mapping(indices, device, keys):
    dtype = dtype_of(indices)
    id_tensor = torch.empty(
        sum(v.shape[0] for v in indices.values()), 2, dtype=dtype, device=device
    )
    offset = 0
    for i, k in enumerate(keys):
@@ -108,8 +127,8 @@ def _get_id_tensor_from_mapping(indices, device, keys):
            continue
        index = indices[k]
        length = index.shape[0]
        id_tensor[offset : offset + length, 0] = i
        id_tensor[offset : offset + length, 1] = index
        offset += length
    return id_tensor
@@ -118,10 +137,14 @@ def _divide_by_worker(dataset, batch_size, drop_last):
    num_samples = dataset.shape[0]
    worker_info = torch.utils.data.get_worker_info()
    if worker_info:
        num_batches = (
            num_samples + (0 if drop_last else batch_size - 1)
        ) // batch_size
        num_batches_per_worker = num_batches // worker_info.num_workers
        left_over = num_batches % worker_info.num_workers
        start = (num_batches_per_worker * worker_info.id) + min(
            left_over, worker_info.id
        )
        end = start + num_batches_per_worker + (worker_info.id < left_over)
        start *= batch_size
        end = min(end * batch_size, num_samples)
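        # Illustrative worked example: with num_samples=40, batch_size=4,
        # drop_last=False and 3 workers, there are 10 batches;
        # num_batches_per_worker=3 and left_over=1, so the workers get batch
        # ranges [0, 4), [4, 7) and [7, 10), i.e. sample ranges [0, 16),
        # [16, 28) and [28, 40).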
@@ -133,12 +156,16 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
    """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
    When the dataset is on the GPU, this significantly reduces the overhead.
    """

    def __init__(
        self, indices, batch_size, drop_last, shuffle, use_shared_memory
    ):
        if isinstance(indices, Mapping):
            self._mapping_keys = list(indices.keys())
            self._device = next(iter(indices.values())).device
            self._id_tensor = _get_id_tensor_from_mapping(
                indices, self._device, self._mapping_keys
            )
        else:
            self._id_tensor = indices
            self._device = indices.device
@@ -146,7 +173,9 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
        # Use a shared memory array to permute indices for shuffling. This is to make sure that
        # the worker processes can see it when persistent_workers=True, where self._indices
        # would not be duplicated every epoch.
        self._indices = torch.arange(
            self._id_tensor.shape[0], dtype=torch.int64
        )
        if use_shared_memory:
            self._indices.share_memory_()
        self.batch_size = batch_size
@@ -158,14 +187,24 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
        np.random.shuffle(self._indices.numpy())

    def __iter__(self):
        indices = _divide_by_worker(
            self._indices, self.batch_size, self.drop_last
        )
        id_tensor = self._id_tensor[indices]
        return _TensorizedDatasetIter(
            id_tensor,
            self.batch_size,
            self.drop_last,
            self._mapping_keys,
            self._shuffle,
        )

    def __len__(self):
        num_samples = self._id_tensor.shape[0]
        return (
            num_samples + (0 if self.drop_last else (self.batch_size - 1))
        ) // self.batch_size
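
    # Illustrative usage sketch (inferred from the code above, not part of the
    # original file): iterating a TensorizedDataset yields index batches directly.
    #
    #     >>> ds = TensorizedDataset(torch.arange(10), batch_size=4,
    #     ...                        drop_last=False, shuffle=False,
    #     ...                        use_shared_memory=False)
    #     >>> [b.tolist() for b in ds]
    #     [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    #     >>> len(ds)
    #     3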


class DDPTensorizedDataset(torch.utils.data.IterableDataset):
    """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
@@ -174,6 +213,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
    This class additionally saves the index tensor in shared memory and therefore
    avoids duplicating the same index tensor during shuffling.
    """

    def __init__(self, indices, batch_size, drop_last, ddp_seed, shuffle):
        if isinstance(indices, Mapping):
            self._mapping_keys = list(indices.keys())
@@ -191,7 +231,9 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
        self._shuffle = shuffle

        if self.drop_last and len_indices % self.num_replicas != 0:
            self.num_samples = math.ceil(
                (len_indices - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len_indices / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
@@ -199,20 +241,27 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
        # of indices since we will need to pad it after shuffling to make it evenly
        # divisible before every epoch. If drop_last is False, we create an array
        # with the same size as the indices so we can trim it later.
        self.shared_mem_size = (
            self.total_size if not self.drop_last else len_indices
        )
        self.num_indices = len_indices

        if isinstance(indices, Mapping):
            self._device = next(iter(indices.values())).device
            self._id_tensor = call_once_and_share(
                lambda: _get_id_tensor_from_mapping(
                    indices, self._device, self._mapping_keys
                ),
                (self.num_indices, 2),
                dtype_of(indices),
            )
        else:
            self._id_tensor = indices
            self._device = self._id_tensor.device

        self._indices = call_once_and_share(
            self._create_shared_indices, (self.shared_mem_size,), torch.int64
        )

    def _create_shared_indices(self):
        indices = torch.empty(self.shared_mem_size, dtype=torch.int64)
@@ -225,27 +274,38 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
        """Shuffles the dataset."""
        # Only rank 0 does the actual shuffling. The other ranks wait for it.
        if self.rank == 0:
            np.random.shuffle(self._indices[: self.num_indices].numpy())
            if not self.drop_last:
                # pad extra
                self._indices[self.num_indices :] = self._indices[
                    : self.total_size - self.num_indices
                ]
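                # Illustrative example: with 10 indices and 4 replicas,
                # num_samples is 3 and total_size is 12, so the first 2
                # shuffled indices are repeated at the end as padding.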
        dist.barrier()

    def __iter__(self):
        start = self.num_samples * self.rank
        end = self.num_samples * (self.rank + 1)
        indices = _divide_by_worker(
            self._indices[start:end], self.batch_size, self.drop_last
        )
        id_tensor = self._id_tensor[indices]
        return _TensorizedDatasetIter(
            id_tensor,
            self.batch_size,
            self.drop_last,
            self._mapping_keys,
            self._shuffle,
        )

    def __len__(self):
        return (
            self.num_samples + (0 if self.drop_last else (self.batch_size - 1))
        ) // self.batch_size


def _prefetch_update_feats(
    feats, frames, types, get_storage_func, id_name, device, pin_prefetcher
):
    for tid, frame in enumerate(frames):
        type_ = types[tid]
        default_id = frame.get(id_name, None)
@@ -255,16 +315,19 @@ def _prefetch_update_feats(feats, frames, types, get_storage_func, id_name, devi
            parent_key = column.name or key
            if column.id_ is None and default_id is None:
                raise DGLError(
                    "Found a LazyFeature with no ID specified, "
                    "and the graph does not have dgl.NID or dgl.EID columns"
                )
            feats[tid, key] = get_storage_func(parent_key, type_).fetch(
                column.id_ or default_id, device, pin_prefetcher
            )


# This class exists to avoid recursion into the feature dictionary returned by the
# prefetcher when calling recursive_apply().
class _PrefetchedGraphFeatures(object):
    __slots__ = ["node_feats", "edge_feats"]

    def __init__(self, node_feats, edge_feats):
        self.node_feats = node_feats
        self.edge_feats = edge_feats
@@ -273,11 +336,23 @@ class _PrefetchedGraphFeatures(object):
def _prefetch_for_subgraph(subg, dataloader):
    node_feats, edge_feats = {}, {}
    _prefetch_update_feats(
        node_feats,
        subg._node_frames,
        subg.ntypes,
        dataloader.graph.get_node_storage,
        NID,
        dataloader.device,
        dataloader.pin_prefetcher,
    )
    _prefetch_update_feats(
        edge_feats,
        subg._edge_frames,
        subg.canonical_etypes,
        dataloader.graph.get_edge_storage,
        EID,
        dataloader.device,
        dataloader.pin_prefetcher,
    )
    return _PrefetchedGraphFeatures(node_feats, edge_feats)


@@ -286,13 +361,14 @@ def _prefetch_for(item, dataloader):
        return _prefetch_for_subgraph(item, dataloader)
    elif isinstance(item, LazyFeature):
        return dataloader.other_storages[item.name].fetch(
            item.id_, dataloader.device, dataloader.pin_prefetcher
        )
    else:
        return None


def _await_or_return(x):
    if hasattr(x, "wait"):
        return x.wait()
    elif isinstance(x, _PrefetchedGraphFeatures):
        node_feats = recursive_apply(x.node_feats, _await_or_return)
@@ -301,10 +377,11 @@ def _await_or_return(x):
    else:
        return x


def _record_stream(x, stream):
    if stream is None:
        return x
    if hasattr(x, "record_stream"):
        x.record_stream(stream)
        return x
    elif isinstance(x, _PrefetchedGraphFeatures):
@@ -314,6 +391,7 @@ def _record_stream(x, stream):
    else:
        return x


def _prefetch(batch, dataloader, stream):
    # feats has the same nested structure of batch, except that
    # (1) each subgraph is replaced with a pair of node features and edge features, both
@@ -335,7 +413,9 @@ def _prefetch(batch, dataloader, stream):
    feats = recursive_apply(feats, _await_or_return)
    feats = recursive_apply(feats, _record_stream, current_stream)
    # transfer input nodes/seed nodes/subgraphs
    batch = recursive_apply(
        batch, lambda x: x.to(dataloader.device, non_blocking=True)
    )
    batch = recursive_apply(batch, _record_stream, current_stream)
    stream_event = stream.record_event() if stream is not None else None
    return batch, feats, stream_event
@@ -356,6 +436,7 @@ def _assign_for(item, feat):
    else:
        return item


def _put_if_event_not_set(queue, result, event):
    while not event.is_set():
        try:
@@ -364,8 +445,10 @@ def _put_if_event_not_set(queue, result, event):
        except Full:
            continue


def _prefetcher_entry(
    dataloader_it, dataloader, queue, num_threads, stream, done_event
):
    # PyTorch will set the number of threads to 1 which slows down pin_memory() calls
    # in main process if a prefetching thread is created.
    if num_threads is not None:
@@ -377,13 +460,20 @@ def _prefetcher_entry(
                batch = next(dataloader_it)
            except StopIteration:
                break
            batch = recursive_apply(
                batch, restore_parent_storage_columns, dataloader.graph
            )
            batch, feats, stream_event = _prefetch(batch, dataloader, stream)
            _put_if_event_not_set(
                queue, (batch, feats, stream_event, None), done_event
            )
        _put_if_event_not_set(queue, (None, None, None, None), done_event)
    except:  # pylint: disable=bare-except
        _put_if_event_not_set(
            queue,
            (None, None, None, ExceptionWrapper(where="in prefetcher")),
            done_event,
        )


# DGLGraphs have the semantics of lazy feature slicing with subgraphs. Such behavior depends
@@ -400,10 +490,11 @@ def remove_parent_storage_columns(item, g):
        return item

    for subframe, frame in zip(
        itertools.chain(item._node_frames, item._edge_frames),
        itertools.chain(g._node_frames, g._edge_frames),
    ):
        for key in list(subframe.keys()):
            subcol = subframe._columns[key]  # directly get the column object
            if isinstance(subcol, LazyFeature):
                continue
            col = frame._columns.get(key, None)
@@ -422,8 +513,9 @@ def restore_parent_storage_columns(item, g):
        return item

    for subframe, frame in zip(
        itertools.chain(item._node_frames, item._edge_frames),
        itertools.chain(g._node_frames, g._edge_frames),
    ):
        for key in subframe.keys():
            subcol = subframe._columns[key]
            if isinstance(subcol, LazyFeature):
@@ -446,7 +538,7 @@ class _PrefetchingIter(object):
        self.use_thread = dataloader.use_prefetch_thread
        self.use_alternate_streams = dataloader.use_alternate_streams
        self.device = self.dataloader.device
        if self.use_alternate_streams and self.device.type == "cuda":
            self.stream = torch.cuda.Stream(device=self.device)
        else:
            self.stream = None
@@ -455,9 +547,16 @@ class _PrefetchingIter(object):
            self._done_event = threading.Event()
            thread = threading.Thread(
                target=_prefetcher_entry,
                args=(
                    dataloader_it,
                    dataloader,
                    self.queue,
                    num_threads,
                    self.stream,
                    self._done_event,
                ),
                daemon=True,
            )
            thread.start()
            self.thread = thread
@@ -477,12 +576,12 @@ class _PrefetchingIter(object):
                self._done_event.set()

                try:
                    self.queue.get_nowait()  # In case the thread is blocking on put().
                except:  # pylint: disable=bare-except
                    pass

                self.thread.join()
            except:  # pylint: disable=bare-except
                pass

    def __del__(self):
@@ -491,16 +590,23 @@ class _PrefetchingIter(object):
    def _next_non_threaded(self):
        batch = next(self.dataloader_it)
        batch = recursive_apply(
            batch, restore_parent_storage_columns, self.dataloader.graph
        )
        batch, feats, stream_event = _prefetch(
            batch, self.dataloader, self.stream
        )
        return batch, feats, stream_event

    def _next_threaded(self):
        try:
            batch, feats, stream_event, exception = self.queue.get(
                timeout=prefetcher_timeout
            )
        except Empty:
            raise RuntimeError(
                f"Prefetcher thread timed out at {prefetcher_timeout} seconds."
            )
        if batch is None:
            self.thread.join()
            if exception is None:
@@ -509,8 +615,11 @@ class _PrefetchingIter(object):
        return batch, feats, stream_event

    def __next__(self):
        batch, feats, stream_event = (
            self._next_non_threaded()
            if not self.use_thread
            else self._next_threaded()
        )
        batch = recursive_apply_pair(batch, feats, _assign_for)
        if stream_event is not None:
            stream_event.wait()
@@ -522,6 +631,7 @@ class CollateWrapper(object):
    """Wraps a collate function with :func:`remove_parent_storage_columns` for serializing
    from PyTorch DataLoader workers.
    """

    def __init__(self, sample_func, g, use_uva, device):
        self.sample_func = sample_func
        self.g = g
@@ -529,8 +639,8 @@ class CollateWrapper(object):
        self.device = device

    def __call__(self, items):
        graph_device = getattr(self.g, "device", None)
        if self.use_uva or (graph_device != torch.device("cpu")):
            # Only copy the indices to the given device if in UVA mode or the graph
            # is not on CPU.
            items = recursive_apply(items, lambda x: x.to(self.device))
@@ -542,6 +652,7 @@ class WorkerInitWrapper(object):
    """Wraps the :attr:`worker_init_fn` argument of the DataLoader to set the number of DGL
    OMP threads to 1 for PyTorch DataLoader workers.
    """

    def __init__(self, func):
        self.func = func
@@ -551,25 +662,37 @@ class WorkerInitWrapper(object):
        self.func(worker_id)


def create_tensorized_dataset(
    indices,
    batch_size,
    drop_last,
    use_ddp,
    ddp_seed,
    shuffle,
    use_shared_memory,
):
    """Converts a given indices tensor to a TensorizedDataset, an IterableDataset
    that returns views of the original tensor, to reduce overhead from having
    a list of scalar tensors in default PyTorch DataLoader implementation.
    """
    if use_ddp:
        # DDP always uses shared memory
        return DDPTensorizedDataset(
            indices, batch_size, drop_last, ddp_seed, shuffle
        )
    else:
        return TensorizedDataset(
            indices, batch_size, drop_last, shuffle, use_shared_memory
        )


def _get_device(device):
    device = torch.device(device)
    if device.type == "cuda" and device.index is None:
        device = torch.device("cuda", torch.cuda.current_device())
    return device
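    # Illustrative example: _get_device("cuda") normalizes a bare CUDA device to
    # an indexed one, e.g. torch.device("cuda", 0) when GPU 0 is current, while
    # _get_device("cpu") is returned unchanged.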


class DataLoader(torch.utils.data.DataLoader):
    """Sampled graph data loader. Wrap a :class:`~dgl.DGLGraph` and a
    :class:`~dgl.dataloading.Sampler` into an iterable over mini-batches of samples.
@@ -695,10 +818,24 @@ class DataLoader(torch.utils.data.DataLoader):
    - Otherwise, both the sampling and subgraph construction will take place on the CPU.
    """

    def __init__(
        self,
        graph,
        indices,
        graph_sampler,
        device=None,
        use_ddp=False,
        ddp_seed=0,
        batch_size=1,
        drop_last=False,
        shuffle=False,
        use_prefetch_thread=None,
        use_alternate_streams=None,
        pin_prefetcher=None,
        use_uva=False,
        **kwargs,
    ):
        # (BarclayII) PyTorch Lightning sometimes will recreate a DataLoader from an existing
        # DataLoader with modifications to the original arguments. The arguments are retrieved
        # from the attributes with the same name, and because we change certain arguments
@@ -710,8 +847,8 @@ class DataLoader(torch.utils.data.DataLoader):
        # is indeed in kwargs and it's already a CollateWrapper object, we can assume that
        # the arguments come from a previously created DGL DataLoader, and directly initialize
        # the new DataLoader from kwargs without any changes.
        if isinstance(kwargs.get("collate_fn", None), CollateWrapper):
            assert batch_size is None  # must be None
            # restore attributes
            self.graph = graph
            self.indices = indices
@@ -725,14 +862,15 @@ class DataLoader(torch.utils.data.DataLoader):
            self.use_alternate_streams = use_alternate_streams
            self.pin_prefetcher = pin_prefetcher
            self.use_uva = use_uva
            kwargs["batch_size"] = None
            super().__init__(**kwargs)
            return

        if isinstance(graph, DistGraph):
            raise TypeError(
                "Please use dgl.dataloading.DistNodeDataLoader or "
                "dgl.dataloading.DistEdgeDataLoader for DistGraphs."
            )
        # (BarclayII) I hoped that pin_prefetcher can be merged into PyTorch's native
        # pin_memory argument. But our neighbor samplers and subgraph samplers
        # return indices, which could be CUDA tensors (e.g. during UVA sampling)
@@ -743,26 +881,34 @@ class DataLoader(torch.utils.data.DataLoader):
        # to pinning prefetched features and disable pin memory for sampler's returns
        # no matter what, but I doubt if it's reasonable.
        self.graph = graph
        self.indices = indices  # For PyTorch-Lightning
        num_workers = kwargs.get("num_workers", 0)

        indices_device = None
        try:
            if isinstance(indices, Mapping):
                indices = {
                    k: (torch.tensor(v) if not torch.is_tensor(v) else v)
                    for k, v in indices.items()
                }
                indices_device = next(iter(indices.values())).device
            else:
                indices = (
                    torch.tensor(indices)
                    if not torch.is_tensor(indices)
                    else indices
                )
                indices_device = indices.device
        except:  # pylint: disable=bare-except
            # ignore when it fails to convert to torch Tensors.
            pass

        if indices_device is None:
            if not hasattr(indices, "device"):
                raise AttributeError(
                    'Custom indices dataset requires a "device" \
                    attribute indicating where the indices is.'
                )
            indices_device = indices.device

        if device is None:
@@ -776,10 +922,14 @@ class DataLoader(torch.utils.data.DataLoader):
        if isinstance(self.graph, DGLGraph):
            # Check graph and indices device as well as num_workers
            if use_uva:
                if self.graph.device.type != "cpu":
                    raise ValueError(
                        "Graph must be on CPU if UVA sampling is enabled."
                    )
                if num_workers > 0:
                    raise ValueError(
                        "num_workers must be 0 if UVA sampling is enabled."
                    )

                # Create all the formats and pin the features - custom GraphStorages
                # will need to do that themselves.
@@ -788,16 +938,23 @@ class DataLoader(torch.utils.data.DataLoader):
            else:
                if self.graph.device != indices_device:
                    raise ValueError(
                        "Expect graph and indices to be on the same device when use_uva=False. "
                    )
                if self.graph.device.type == "cuda" and num_workers > 0:
                    raise ValueError(
                        "num_workers must be 0 if graph and indices are on CUDA."
                    )
                if self.graph.device.type == "cpu" and num_workers > 0:
                    # Instantiate all the formats if the number of workers is greater than 0.
                    self.graph.create_formats_()

            # Check pin_prefetcher and use_prefetch_thread - should be only effective
            # if performing CPU sampling but output device is CUDA
            if (
                self.device.type == "cuda"
                and self.graph.device.type == "cpu"
                and not use_uva
            ):
                if pin_prefetcher is None:
                    pin_prefetcher = True
                if use_prefetch_thread is None:
@@ -805,15 +962,17 @@ class DataLoader(torch.utils.data.DataLoader):
            else:
                if pin_prefetcher is True:
                    raise ValueError(
                        "pin_prefetcher=True is only effective when device=cuda and "
                        "sampling is performed on CPU."
                    )
                if pin_prefetcher is None:
                    pin_prefetcher = False

                if use_prefetch_thread is True:
                    raise ValueError(
                        "use_prefetch_thread=True is only effective when device=cuda and "
                        "sampling is performed on CPU."
                    )
                if use_prefetch_thread is None:
                    use_prefetch_thread = False
@@ -826,16 +985,25 @@ class DataLoader(torch.utils.data.DataLoader):
                    and is_tensor_adaptor_enabled()
                )
            elif use_alternate_streams and not is_tensor_adaptor_enabled():
                dgl_warning(
                    "use_alternate_streams is turned off because "
                    "TensorAdaptor is not available."
                )
                use_alternate_streams = False

        if torch.is_tensor(indices) or (
            isinstance(indices, Mapping)
            and all(torch.is_tensor(v) for v in indices.values())
        ):
            self.dataset = create_tensorized_dataset(
                indices,
                batch_size,
                drop_last,
                use_ddp,
                ddp_seed,
                shuffle,
                kwargs.get("persistent_workers", False),
            )
        else:
            self.dataset = indices
@@ -850,35 +1018,43 @@ class DataLoader(torch.utils.data.DataLoader):
        self.use_prefetch_thread = use_prefetch_thread
        self.cpu_affinity_enabled = False

        worker_init_fn = WorkerInitWrapper(kwargs.get("worker_init_fn", None))

        self.other_storages = {}

        super().__init__(
            self.dataset,
            collate_fn=CollateWrapper(
                self.graph_sampler.sample, graph, self.use_uva, self.device
            ),
            batch_size=None,
            pin_memory=self.pin_prefetcher,
            worker_init_fn=worker_init_fn,
            **kwargs,
        )

    def __iter__(self):
        if self.device.type == "cpu" and not self.cpu_affinity_enabled:
            link = "https://docs.dgl.ai/tutorials/cpu/cpu_best_practises.html"
            dgl_warning(
                f"Dataloader CPU affinity opt is not enabled, consider switching it on "
                f"(see enable_cpu_affinity() or CPU best practices for DGL [{link}])"
            )

        if self.shuffle:
            self.dataset.shuffle()
        # When using multiprocessing PyTorch sometimes set the number of PyTorch threads to 1
        # when spawning new Python threads. This drastically slows down pinning features.
        num_threads = torch.get_num_threads() if self.num_workers > 0 else None
        return _PrefetchingIter(
            self, super().__iter__(), num_threads=num_threads
        )
    @contextmanager
    def enable_cpu_affinity(
        self, loader_cores=None, compute_cores=None, verbose=True
    ):
        """Helper method for enabling cpu affinity for compute threads and dataloader workers
        Only for CPU devices
        Uses only NUMA node 0 by default for multi-node systems
@@ -900,16 +1076,21 @@ class DataLoader(torch.utils.data.DataLoader):
            with dataloader.enable_cpu_affinity():
                <training loop>
        """
        if self.device.type == "cpu":
            if not self.num_workers > 0:
                raise Exception(
                    "ERROR: affinity should be used with at least one DL worker"
                )
            if loader_cores and len(loader_cores) != self.num_workers:
                raise Exception(
                    "ERROR: cpu_affinity incorrect "
                    "number of loader_cores={} for num_workers={}".format(
                        loader_cores, self.num_workers
                    )
                )

            # False positive E0203 (access-member-before-definition) linter warning
            worker_init_fn_old = self.worker_init_fn  # pylint: disable=E0203
            affinity_old = psutil.Process().cpu_affinity()
            nthreads_old = get_num_threads()
@@ -920,8 +1101,11 @@ class DataLoader(torch.utils.data.DataLoader):
                try:
                    psutil.Process().cpu_affinity([loader_cores[worker_id]])
                except:
                    raise Exception(
                        "ERROR: cannot use affinity id={} cpu={}".format(
                            worker_id, loader_cores
                        )
                    )
                worker_init_fn_old(worker_id)
@@ -931,13 +1115,15 @@ class DataLoader(torch.utils.data.DataLoader):
                    # take one thread per each node 0 core
                    node0_cores = [cpus[0] for core_id, cpus in numa_info[0]]
                else:
                    node0_cores = list(range(psutil.cpu_count(logical=False)))

                if len(node0_cores) <= self.num_workers:
                    raise Exception("ERROR: more workers than available cores")

                loader_cores = loader_cores or node0_cores[0 : self.num_workers]
                compute_cores = [
                    cpu for cpu in node0_cores if cpu not in loader_cores
                ]
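                # Illustrative example: on a machine whose NUMA node 0 exposes
                # physical cores 0-7 and with num_workers=2, the two DataLoader
                # workers are pinned to cores [0, 1] and the main (compute)
                # process uses cores [2, 3, 4, 5, 6, 7].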
            try:
                psutil.Process().cpu_affinity(compute_cores)
@@ -946,8 +1132,11 @@ class DataLoader(torch.utils.data.DataLoader):
                self.cpu_affinity_enabled = True

                if verbose:
                    print(
                        f"{self.num_workers} DL workers are assigned to cpus "
                        f"{loader_cores}, main process will use cpus "
                        f"{compute_cores}"
                    )

                yield
            finally:
@@ -970,16 +1159,18 @@ class DataLoader(torch.utils.data.DataLoader):
# GraphDataLoader loads a set of graphs so it's not relevant to the above. They are currently
# copied from the old DataLoader implementation.
def _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed):
    # Note: will change the content of dataloader_kwargs
    dist_sampler_kwargs = {"shuffle": dataloader_kwargs.get("shuffle", False)}
    dataloader_kwargs["shuffle"] = False
    dist_sampler_kwargs["seed"] = ddp_seed
    dist_sampler_kwargs["drop_last"] = dataloader_kwargs.get("drop_last", False)
    dataloader_kwargs["drop_last"] = False

    return DistributedSampler(dataset, **dist_sampler_kwargs)


class GraphCollator(object):
    """Given a set of graphs as well as their graph-level data, the collate function will batch the
    graphs into a batched graph, and stack the tensors into a single bigger tensor. If the
@@ -998,13 +1189,15 @@ class GraphCollator(object):
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
    """

    def __init__(self):
        self.graph_collate_err_msg_format = (
            "graph_collate: batch must contain DGLGraph, tensors, numpy arrays, "
            "numbers, dicts or lists; found {}"
        )
        self.np_str_obj_array_pattern = re.compile(r"[SaUO]")

    # This implementation is based on torch.utils.data._utils.collate.default_collate
    def collate(self, items):
        """This function is similar to ``torch.utils.data._utils.collate.default_collate``.
        It combines the sampled graphs and corresponding graph-level data
@@ -1028,12 +1221,23 @@ class GraphCollator(object):
            return batched_graphs
        elif F.is_tensor(elem):
            return F.stack(items, 0)
        elif (
            elem_type.__module__ == "numpy"
            and elem_type.__name__ != "str_"
            and elem_type.__name__ != "string_"
        ):
            if (
                elem_type.__name__ == "ndarray"
                or elem_type.__name__ == "memmap"
            ):
                # array of string classes and object
                if (
                    self.np_str_obj_array_pattern.search(elem.dtype.str)
                    is not None
                ):
                    raise TypeError(
                        self.graph_collate_err_msg_format.format(elem.dtype)
                    )

                return self.collate([F.tensor(b) for b in items])
            elif elem.shape == ():  # scalars
@@ -1046,19 +1250,24 @@ class GraphCollator(object):
            return items
        elif isinstance(elem, Mapping):
            return {key: self.collate([d[key] for d in items]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
            return elem_type(
                *(self.collate(samples) for samples in zip(*items))
            )
        elif isinstance(elem, Sequence):
            # check to make sure that the elements in batch have consistent size
            item_iter = iter(items)
            elem_size = len(next(item_iter))
            if not all(len(elem) == elem_size for elem in item_iter):
                raise RuntimeError(
                    "each element in list of batch should be of equal size"
                )
            transposed = zip(*items)
            return [self.collate(samples) for samples in transposed]

        raise TypeError(self.graph_collate_err_msg_format.format(elem_type))
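
    # Illustrative usage sketch (inferred from the collate logic above, not part
    # of the original file):
    #
    #     >>> collator = GraphCollator()
    #     >>> g1, g2 = dgl.rand_graph(3, 4), dgl.rand_graph(5, 6)
    #     >>> bg, labels = collator.collate([(g1, torch.tensor(0)), (g2, torch.tensor(1))])
    #
    # ``bg`` is a single batched DGLGraph with 8 nodes and ``labels`` is a
    # stacked tensor of shape (2,).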


class GraphDataLoader(torch.utils.data.DataLoader):
    """Batched graph data loader.
@@ -1113,9 +1322,12 @@ class GraphDataLoader(torch.utils.data.DataLoader):
    ...     for batched_graph, labels in dataloader:
    ...         train_on(batched_graph, labels)
    """

    collator_arglist = inspect.getfullargspec(GraphCollator).args

    def __init__(
        self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs
    ):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
@@ -1126,13 +1338,17 @@ class GraphDataLoader(torch.utils.data.DataLoader):
        self.use_ddp = use_ddp
        if use_ddp:
            self.dist_sampler = _create_dist_sampler(
                dataset, dataloader_kwargs, ddp_seed
            )
            dataloader_kwargs["sampler"] = self.dist_sampler

        if collate_fn is None and kwargs.get("batch_size", 1) is not None:
            collate_fn = GraphCollator(**collator_kwargs).collate

        super().__init__(
            dataset=dataset, collate_fn=collate_fn, **dataloader_kwargs
        )

    def set_epoch(self, epoch):
        """Sets the epoch number for the underlying sampler which ensures all replicas
@@ -1150,4 +1366,4 @@ class GraphDataLoader(torch.utils.data.DataLoader):
        if self.use_ddp:
            self.dist_sampler.set_epoch(epoch)
        else:
            raise DGLError("set_epoch is only available when use_ddp is True.")
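
# Illustrative end-to-end sketch of the node DataLoader defined earlier in this
# file (assumes a homogeneous DGLGraph ``g`` and a tensor of training node IDs
# ``train_nids``; not part of the original file):
#
#     >>> sampler = dgl.dataloading.NeighborSampler([10, 15])
#     >>> dataloader = dgl.dataloading.DataLoader(
#     ...     g, train_nids, sampler,
#     ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=0)
#     >>> for input_nodes, output_nodes, blocks in dataloader:
#     ...     pass  # run the model on ``blocks``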
@@ -14,10 +14,12 @@
# limitations under the License.
#
"""Functional interface for transform"""
# pylint: disable= too-many-lines
import copy
from collections import defaultdict
from collections.abc import Iterable, Mapping

import numpy as np
import scipy.sparse as sparse
import scipy.sparse.linalg
@@ -27,62 +29,70 @@ try:
except ImportError:
    pass

from .. import (
    backend as F,
    batch,
    convert,
    function,
    ndarray as nd,
    subgraph,
    utils,
)
from .._ffi.function import _init_api
from ..base import dgl_warning, DGLError, EID, NID
from ..frame import Frame
from ..heterograph import DGLBlock, DGLGraph
from ..heterograph_index import (
    create_heterograph_from_relations,
    create_metagraph_index,
)
from ..partition import (
    metis_partition,
    metis_partition_assignment,
    partition_graph_with_halo,
)
from ..sampling.neighbor import sample_neighbors
__all__ = [ __all__ = [
'line_graph', "line_graph",
'khop_adj', "khop_adj",
'khop_graph', "khop_graph",
'reverse', "reverse",
'to_bidirected', "to_bidirected",
'add_reverse_edges', "add_reverse_edges",
'laplacian_lambda_max', "laplacian_lambda_max",
'knn_graph', "knn_graph",
'segmented_knn_graph', "segmented_knn_graph",
'add_edges', "add_edges",
'add_nodes', "add_nodes",
'remove_edges', "remove_edges",
'remove_nodes', "remove_nodes",
'add_self_loop', "add_self_loop",
'remove_self_loop', "remove_self_loop",
'metapath_reachable_graph', "metapath_reachable_graph",
'compact_graphs', "compact_graphs",
'to_block', "to_block",
'to_simple', "to_simple",
'to_simple_graph', "to_simple_graph",
'sort_csr_by_tag', "sort_csr_by_tag",
'sort_csc_by_tag', "sort_csc_by_tag",
'metis_partition_assignment', "metis_partition_assignment",
'partition_graph_with_halo', "partition_graph_with_halo",
'metis_partition', "metis_partition",
'adj_product_graph', "adj_product_graph",
'adj_sum_graph', "adj_sum_graph",
'reorder_graph', "reorder_graph",
'norm_by_dst', "norm_by_dst",
'radius_graph', "radius_graph",
'random_walk_pe', "random_walk_pe",
'laplacian_pe', "laplacian_pe",
'to_half', "to_half",
'to_float', "to_float",
'to_double', "to_double",
'double_radius_node_labeling', "double_radius_node_labeling",
'shortest_dist', "shortest_dist",
'svd_pe' "svd_pe",
] ]
def pairwise_squared_distance(x): def pairwise_squared_distance(x):
...@@ -94,9 +104,11 @@ def pairwise_squared_distance(x): ...@@ -94,9 +104,11 @@ def pairwise_squared_distance(x):
# assuming that __matmul__ is always implemented (true for PyTorch, MXNet and Chainer) # assuming that __matmul__ is always implemented (true for PyTorch, MXNet and Chainer)
return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2) return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2)
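# A torch-only sketch of the identity used above, assuming x is an (N, D) tensor:
#   ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>
import torch

x = torch.rand(5, 3)
x2s = (x * x).sum(-1, keepdim=True)                                # (N, 1)
d2 = x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)     # (N, N)
assert torch.allclose(d2, torch.cdist(x, x) ** 2, atol=1e-4)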
#pylint: disable=invalid-name
def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean', # pylint: disable=invalid-name
exclude_self=False): def knn_graph(
x, k, algorithm="bruteforce-blas", dist="euclidean", exclude_self=False
):
r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN) r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN)
and return. and return.
...@@ -223,7 +235,7 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean', ...@@ -223,7 +235,7 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean',
d = F.ndim(x) d = F.ndim(x)
x_seg = x_size[0] * [x_size[1]] if d == 3 else [x_size[0]] x_seg = x_size[0] * [x_size[1]] if d == 3 else [x_size[0]]
if algorithm == 'bruteforce-blas': if algorithm == "bruteforce-blas":
result = _knn_graph_blas(x, k, dist=dist) result = _knn_graph_blas(x, k, dist=dist)
else: else:
if d == 3: if d == 3:
...@@ -238,7 +250,7 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean', ...@@ -238,7 +250,7 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean',
result.set_batch_num_nodes(num_nodes) result.set_batch_num_nodes(num_nodes)
# if any segment is too small for k, all algorithms reduce k for all segments # if any segment is too small for k, all algorithms reduce k for all segments
clamped_k = min(k, np.min(x_seg)) clamped_k = min(k, np.min(x_seg))
result.set_batch_num_edges(clamped_k*num_nodes) result.set_batch_num_edges(clamped_k * num_nodes)
if exclude_self: if exclude_self:
# remove_self_loop will update batch_num_edges as needed # remove_self_loop will update batch_num_edges as needed
...@@ -250,18 +262,21 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean', ...@@ -250,18 +262,21 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean',
# same degree as each other, so we can check that condition easily. # same degree as each other, so we can check that condition easily.
# The -1 is for the self edge removal. # The -1 is for the self edge removal.
clamped_k = min(k, np.min(x_seg)) - 1 clamped_k = min(k, np.min(x_seg)) - 1
if result.num_edges() != clamped_k*result.num_nodes(): if result.num_edges() != clamped_k * result.num_nodes():
# edges on any nodes with too high degree should all be length zero, # edges on any nodes with too high degree should all be length zero,
# so pick an arbitrary one to remove from each such node # so pick an arbitrary one to remove from each such node
degrees = result.in_degrees() degrees = result.in_degrees()
node_indices = F.nonzero_1d(degrees > clamped_k) node_indices = F.nonzero_1d(degrees > clamped_k)
edges_to_remove_graph = sample_neighbors(result, node_indices, 1, edge_dir='in') edges_to_remove_graph = sample_neighbors(
result, node_indices, 1, edge_dir="in"
)
edge_ids = edges_to_remove_graph.edata[EID] edge_ids = edges_to_remove_graph.edata[EID]
result = remove_edges(result, edge_ids) result = remove_edges(result, edge_ids)
return result return result
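# Illustrative usage sketch for knn_graph: a 4-NN graph over 100 random 2-D points.
import dgl
import torch

x = torch.rand(100, 2)                            # a 3-D (batch, n, dim) input also works
g = dgl.knn_graph(x, 4)                           # defaults: 'bruteforce-blas', euclidean
# g.num_edges() == 4 * 100; each point keeps its 4 nearest neighbours (itself included)
g_ns = dgl.knn_graph(x, 4, exclude_self=True)     # same query with self edges removed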
def _knn_graph_blas(x, k, dist='euclidean'):
def _knn_graph_blas(x, k, dist="euclidean"):
r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN). r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN).
This function first computes the distance matrix using BLAS matrix multiplication This function first computes the distance matrix using BLAS matrix multiplication
...@@ -291,13 +306,15 @@ def _knn_graph_blas(x, k, dist='euclidean'): ...@@ -291,13 +306,15 @@ def _knn_graph_blas(x, k, dist='euclidean'):
n_samples, n_points, _ = F.shape(x) n_samples, n_points, _ = F.shape(x)
if k > n_points: if k > n_points:
dgl_warning("'k' should be less than or equal to the number of points in 'x'" \ dgl_warning(
"expect k <= {0}, got k = {1}, use k = {0}".format(n_points, k)) "'k' should be less than or equal to the number of points in 'x'"
"expect k <= {0}, got k = {1}, use k = {0}".format(n_points, k)
)
k = n_points k = n_points
# if using cosine distance, normalize the input points first # if using cosine distance, normalize the input points first
# thus we can use euclidean distance to find knn equivalently. # thus we can use euclidean distance to find knn equivalently.
if dist == 'cosine': if dist == "cosine":
l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=2, keepdims=True)) l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=2, keepdims=True))
x = x / (l2_norm(x) + 1e-5) x = x / (l2_norm(x) + 1e-5)
...@@ -313,9 +330,16 @@ def _knn_graph_blas(x, k, dist='euclidean'): ...@@ -313,9 +330,16 @@ def _knn_graph_blas(x, k, dist='euclidean'):
dst = F.unsqueeze(dst, 0) + offset dst = F.unsqueeze(dst, 0) + offset
return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,)))) return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
#pylint: disable=invalid-name
def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean', # pylint: disable=invalid-name
exclude_self=False): def segmented_knn_graph(
x,
k,
segs,
algorithm="bruteforce-blas",
dist="euclidean",
exclude_self=False,
):
r"""Construct multiple graphs from multiple sets of points according to r"""Construct multiple graphs from multiple sets of points according to
k-nearest-neighbor (KNN) and return. k-nearest-neighbor (KNN) and return.
...@@ -424,7 +448,7 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean ...@@ -424,7 +448,7 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
if F.shape(x)[0] == 0: if F.shape(x)[0] == 0:
raise DGLError("Find empty point set") raise DGLError("Find empty point set")
if algorithm == 'bruteforce-blas': if algorithm == "bruteforce-blas":
result = _segmented_knn_graph_blas(x, k, segs, dist=dist) result = _segmented_knn_graph_blas(x, k, segs, dist=dist)
else: else:
out = knn(k, x, segs, algorithm=algorithm, dist=dist) out = knn(k, x, segs, algorithm=algorithm, dist=dist)
...@@ -435,7 +459,7 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean ...@@ -435,7 +459,7 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
result.set_batch_num_nodes(num_nodes) result.set_batch_num_nodes(num_nodes)
# if any segment is too small for k, all algorithms reduce k for all segments # if any segment is too small for k, all algorithms reduce k for all segments
clamped_k = min(k, np.min(segs)) clamped_k = min(k, np.min(segs))
result.set_batch_num_edges(clamped_k*num_nodes) result.set_batch_num_edges(clamped_k * num_nodes)
if exclude_self: if exclude_self:
# remove_self_loop will update batch_num_edges as needed # remove_self_loop will update batch_num_edges as needed
...@@ -447,18 +471,21 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean ...@@ -447,18 +471,21 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
# same degree as each other, so we can check that condition easily. # same degree as each other, so we can check that condition easily.
# The -1 is for the self edge removal. # The -1 is for the self edge removal.
clamped_k = min(k, np.min(segs)) - 1 clamped_k = min(k, np.min(segs)) - 1
if result.num_edges() != clamped_k*result.num_nodes(): if result.num_edges() != clamped_k * result.num_nodes():
# edges on any nodes with too high degree should all be length zero, # edges on any nodes with too high degree should all be length zero,
# so pick an arbitrary one to remove from each such node # so pick an arbitrary one to remove from each such node
degrees = result.in_degrees() degrees = result.in_degrees()
node_indices = F.nonzero_1d(degrees > clamped_k) node_indices = F.nonzero_1d(degrees > clamped_k)
edges_to_remove_graph = sample_neighbors(result, node_indices, 1, edge_dir='in') edges_to_remove_graph = sample_neighbors(
result, node_indices, 1, edge_dir="in"
)
edge_ids = edges_to_remove_graph.edata[EID] edge_ids = edges_to_remove_graph.edata[EID]
result = remove_edges(result, edge_ids) result = remove_edges(result, edge_ids)
return result return result
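# Illustrative usage sketch for segmented_knn_graph: two point sets stored back-to-back.
import dgl
import torch

x = torch.rand(12, 2)                             # 5 points for segment 0, 7 for segment 1
g = dgl.segmented_knn_graph(x, 3, [5, 7])         # 3-NN within each segment, never across
# g.batch_num_nodes() -> tensor([5, 7]); g.batch_num_edges() -> tensor([15, 21])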
def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'):
def _segmented_knn_graph_blas(x, k, segs, dist="euclidean"):
r"""Construct multiple graphs from multiple sets of points according to r"""Construct multiple graphs from multiple sets of points according to
k-nearest-neighbor (KNN). k-nearest-neighbor (KNN).
...@@ -484,7 +511,7 @@ def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'): ...@@ -484,7 +511,7 @@ def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'):
""" """
# if using cosine distance, normalize the input points first # if using cosine distance, normalize the input points first
# thus we can use euclidean distance to find knn equivalently. # thus we can use euclidean distance to find knn equivalently.
if dist == 'cosine': if dist == "cosine":
l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True)) l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))
x = x / (l2_norm(x) + 1e-5) x = x / (l2_norm(x) + 1e-5)
...@@ -492,22 +519,34 @@ def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'): ...@@ -492,22 +519,34 @@ def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'):
offset = np.insert(np.cumsum(segs), 0, 0) offset = np.insert(np.cumsum(segs), 0, 0)
min_seg_size = np.min(segs) min_seg_size = np.min(segs)
if k > min_seg_size: if k > min_seg_size:
dgl_warning("'k' should be less than or equal to the number of points in 'x'" \ dgl_warning(
"expect k <= {0}, got k = {1}, use k = {0}".format(min_seg_size, k)) "'k' should be less than or equal to the number of points in 'x'"
"expect k <= {0}, got k = {1}, use k = {0}".format(min_seg_size, k)
)
k = min_seg_size k = min_seg_size
h_list = F.split(x, segs, 0) h_list = F.split(x, segs, 0)
src = [ src = [
F.argtopk(pairwise_squared_distance(h_g), k, 1, descending=False) + F.argtopk(pairwise_squared_distance(h_g), k, 1, descending=False)
int(offset[i]) + int(offset[i])
for i, h_g in enumerate(h_list)] for i, h_g in enumerate(h_list)
]
src = F.cat(src, 0) src = F.cat(src, 0)
ctx = F.context(x) ctx = F.context(x)
dst = F.repeat(F.arange(0, n_total_points, ctx=ctx), k, dim=0) dst = F.repeat(F.arange(0, n_total_points, ctx=ctx), k, dim=0)
return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,)))) return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
def _nndescent_knn_graph(x, k, segs, num_iters=None, max_candidates=None,
delta=0.001, sample_rate=0.5, dist='euclidean'): def _nndescent_knn_graph(
x,
k,
segs,
num_iters=None,
max_candidates=None,
delta=0.001,
sample_rate=0.5,
dist="euclidean",
):
r"""Construct multiple graphs from multiple sets of points according to r"""Construct multiple graphs from multiple sets of points according to
**approximate** k-nearest-neighbor using NN-descent algorithm from paper **approximate** k-nearest-neighbor using NN-descent algorithm from paper
`Efficient k-nearest neighbor graph construction for generic similarity `Efficient k-nearest neighbor graph construction for generic similarity
...@@ -567,14 +606,16 @@ def _nndescent_knn_graph(x, k, segs, num_iters=None, max_candidates=None, ...@@ -567,14 +606,16 @@ def _nndescent_knn_graph(x, k, segs, num_iters=None, max_candidates=None,
# if using cosine distance, normalize the input points first # if using cosine distance, normalize the input points first
# thus we can use euclidean distance to find knn equivalently. # thus we can use euclidean distance to find knn equivalently.
if dist == 'cosine': if dist == "cosine":
l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True)) l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))
x = x / (l2_norm(x) + 1e-5) x = x / (l2_norm(x) + 1e-5)
# k must be less than or equal to min(segs) # k must be less than or equal to min(segs)
if k > F.min(segs, dim=0): if k > F.min(segs, dim=0):
raise DGLError("'k' must be less than or equal to the number of points in 'x'" raise DGLError(
"expect 'k' <= {}, got 'k' = {}".format(F.min(segs, dim=0), k)) "'k' must be less than or equal to the number of points in 'x'"
"expect 'k' <= {}, got 'k' = {}".format(F.min(segs, dim=0), k)
)
if delta < 0 or delta > 1: if delta < 0 or delta > 1:
raise DGLError("'delta' must in [0, 1], got 'delta' = {}".format(delta)) raise DGLError("'delta' must in [0, 1], got 'delta' = {}".format(delta))
...@@ -583,12 +624,21 @@ def _nndescent_knn_graph(x, k, segs, num_iters=None, max_candidates=None, ...@@ -583,12 +624,21 @@ def _nndescent_knn_graph(x, k, segs, num_iters=None, max_candidates=None,
out = F.zeros((2, num_points * k), F.dtype(segs), F.context(segs)) out = F.zeros((2, num_points * k), F.dtype(segs), F.context(segs))
# points, offsets, out, k, num_iters, max_candidates, delta # points, offsets, out, k, num_iters, max_candidates, delta
_CAPI_DGLNNDescent(F.to_dgl_nd(x), F.to_dgl_nd(offset), _CAPI_DGLNNDescent(
F.zerocopy_to_dgl_ndarray_for_write(out), F.to_dgl_nd(x),
k, num_iters, max_candidates, delta) F.to_dgl_nd(offset),
F.zerocopy_to_dgl_ndarray_for_write(out),
k,
num_iters,
max_candidates,
delta,
)
return out return out
def knn(k, x, x_segs, y=None, y_segs=None, algorithm='bruteforce', dist='euclidean'):
def knn(
k, x, x_segs, y=None, y_segs=None, algorithm="bruteforce", dist="euclidean"
):
r"""For each element in each segment in :attr:`y`, find :attr:`k` nearest r"""For each element in each segment in :attr:`y`, find :attr:`k` nearest
points in the same segment in :attr:`x`. If :attr:`y` is None, perform a self-query points in the same segment in :attr:`x`. If :attr:`y` is None, perform a self-query
over :attr:`x`. over :attr:`x`.
...@@ -658,7 +708,9 @@ def knn(k, x, x_segs, y=None, y_segs=None, algorithm='bruteforce', dist='euclide ...@@ -658,7 +708,9 @@ def knn(k, x, x_segs, y=None, y_segs=None, algorithm='bruteforce', dist='euclide
# TODO(lygztq) add support for querying different point sets using nn-descent. # TODO(lygztq) add support for querying different point sets using nn-descent.
if algorithm == "nn-descent": if algorithm == "nn-descent":
if y is not None or y_segs is not None: if y is not None or y_segs is not None:
raise DGLError("Currently 'nn-descent' only supports self-query cases.") raise DGLError(
"Currently 'nn-descent' only supports self-query cases."
)
return _nndescent_knn_graph(x, k, x_segs, dist=dist) return _nndescent_knn_graph(x, k, x_segs, dist=dist)
# self query # self query
...@@ -677,8 +729,12 @@ def knn(k, x, x_segs, y=None, y_segs=None, algorithm='bruteforce', dist='euclide ...@@ -677,8 +729,12 @@ def knn(k, x, x_segs, y=None, y_segs=None, algorithm='bruteforce', dist='euclide
# k should be less than or equal to min(x_segs) # k should be less than or equal to min(x_segs)
min_num_points = F.min(x_segs, dim=0) min_num_points = F.min(x_segs, dim=0)
if k > min_num_points: if k > min_num_points:
dgl_warning("'k' should be less than or equal to the number of points in 'x'" \ dgl_warning(
"expect k <= {0}, got k = {1}, use k = {0}".format(min_num_points, k)) "'k' should be less than or equal to the number of points in 'x'"
"expect k <= {0}, got k = {1}, use k = {0}".format(
min_num_points, k
)
)
k = F.as_scalar(min_num_points) k = F.as_scalar(min_num_points)
# invalid k # invalid k
...@@ -690,31 +746,43 @@ def knn(k, x, x_segs, y=None, y_segs=None, algorithm='bruteforce', dist='euclide ...@@ -690,31 +746,43 @@ def knn(k, x, x_segs, y=None, y_segs=None, algorithm='bruteforce', dist='euclide
raise DGLError("Find empty point set") raise DGLError("Find empty point set")
dist = dist.lower() dist = dist.lower()
dist_metric_list = ['euclidean', 'cosine'] dist_metric_list = ["euclidean", "cosine"]
if dist not in dist_metric_list: if dist not in dist_metric_list:
raise DGLError('Only {} are supported for distance ' raise DGLError(
'computation, got {}'.format(dist_metric_list, dist)) "Only {} are supported for distance "
"computation, got {}".format(dist_metric_list, dist)
)
x_offset = F.zeros((F.shape(x_segs)[0] + 1,), F.dtype(x_segs), F.context(x_segs)) x_offset = F.zeros(
(F.shape(x_segs)[0] + 1,), F.dtype(x_segs), F.context(x_segs)
)
x_offset[1:] = F.cumsum(x_segs, dim=0) x_offset[1:] = F.cumsum(x_segs, dim=0)
y_offset = F.zeros((F.shape(y_segs)[0] + 1,), F.dtype(y_segs), F.context(y_segs)) y_offset = F.zeros(
(F.shape(y_segs)[0] + 1,), F.dtype(y_segs), F.context(y_segs)
)
y_offset[1:] = F.cumsum(y_segs, dim=0) y_offset[1:] = F.cumsum(y_segs, dim=0)
out = F.zeros((2, F.shape(y)[0] * k), F.dtype(x_segs), F.context(x_segs)) out = F.zeros((2, F.shape(y)[0] * k), F.dtype(x_segs), F.context(x_segs))
# if using cosine distance, normalize the input points first # if using cosine distance, normalize the input points first
# thus we can use euclidean distance to find knn equivalently. # thus we can use euclidean distance to find knn equivalently.
if dist == 'cosine': if dist == "cosine":
l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True)) l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))
x = x / (l2_norm(x) + 1e-5) x = x / (l2_norm(x) + 1e-5)
y = y / (l2_norm(y) + 1e-5) y = y / (l2_norm(y) + 1e-5)
_CAPI_DGLKNN(F.to_dgl_nd(x), F.to_dgl_nd(x_offset), _CAPI_DGLKNN(
F.to_dgl_nd(y), F.to_dgl_nd(y_offset), F.to_dgl_nd(x),
k, F.zerocopy_to_dgl_ndarray_for_write(out), F.to_dgl_nd(x_offset),
algorithm) F.to_dgl_nd(y),
F.to_dgl_nd(y_offset),
k,
F.zerocopy_to_dgl_ndarray_for_write(out),
algorithm,
)
return out return out
def to_bidirected(g, copy_ndata=False, readonly=None): def to_bidirected(g, copy_ndata=False, readonly=None):
r"""Convert the graph to a bi-directional simple graph and return. r"""Convert the graph to a bi-directional simple graph and return.
...@@ -785,21 +853,34 @@ def to_bidirected(g, copy_ndata=False, readonly=None): ...@@ -785,21 +853,34 @@ def to_bidirected(g, copy_ndata=False, readonly=None):
(tensor([1, 1, 2]), tensor([1, 2, 1])) (tensor([1, 1, 2]), tensor([1, 2, 1]))
""" """
if readonly is not None: if readonly is not None:
dgl_warning("Parameter readonly is deprecated" \ dgl_warning(
"There will be no difference between readonly and non-readonly DGLGraph") "Parameter readonly is deprecated"
"There will be no difference between readonly and non-readonly DGLGraph"
)
for c_etype in g.canonical_etypes: for c_etype in g.canonical_etypes:
if c_etype[0] != c_etype[2]: if c_etype[0] != c_etype[2]:
assert False, "to_bidirected is not well defined for " \ assert False, (
"unidirectional bipartite graphs" \ "to_bidirected is not well defined for "
"unidirectional bipartite graphs"
", but {} is unidirectional bipartite".format(c_etype) ", but {} is unidirectional bipartite".format(c_etype)
)
g = add_reverse_edges(g, copy_ndata=copy_ndata, copy_edata=False) g = add_reverse_edges(g, copy_ndata=copy_ndata, copy_edata=False)
g = to_simple(g, return_counts=None, copy_ndata=copy_ndata, copy_edata=False) g = to_simple(
g, return_counts=None, copy_ndata=copy_ndata, copy_edata=False
)
return g return g
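# Illustrative usage sketch for to_bidirected: direction and multi-edges collapse.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 2, 2])))
bg = dgl.to_bidirected(g)
# bg keeps exactly one edge per direction per connected pair: (0,1), (1,0), (1,2), (2,1)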
def add_reverse_edges(g, readonly=None, copy_ndata=True,
copy_edata=False, ignore_bipartite=False, exclude_self=True): def add_reverse_edges(
g,
readonly=None,
copy_ndata=True,
copy_edata=False,
ignore_bipartite=False,
exclude_self=True,
):
r"""Add a reversed edge for each edge in the input graph and return a new graph. r"""Add a reversed edge for each edge in the input graph and return a new graph.
For a graph with edges :math:`(i_1, j_1), \cdots, (i_n, j_n)`, this For a graph with edges :math:`(i_1, j_1), \cdots, (i_n, j_n)`, this
...@@ -897,8 +978,10 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True, ...@@ -897,8 +978,10 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
th.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) th.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
""" """
if readonly is not None: if readonly is not None:
dgl_warning("Parameter readonly is deprecated" \ dgl_warning(
"There will be no difference between readonly and non-readonly DGLGraph") "Parameter readonly is deprecated"
"There will be no difference between readonly and non-readonly DGLGraph"
)
# get node cnt for each ntype # get node cnt for each ntype
num_nodes_dict = {} num_nodes_dict = {}
...@@ -911,7 +994,7 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True, ...@@ -911,7 +994,7 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
rev_eids = {} rev_eids = {}
def add_for_etype(etype): def add_for_etype(etype):
u, v = g.edges(form='uv', order='eid', etype=etype) u, v = g.edges(form="uv", order="eid", etype=etype)
rev_u, rev_v = v, u rev_u, rev_v = v, u
eid = F.copy_to(F.arange(0, g.num_edges(etype)), g.device) eid = F.copy_to(F.arange(0, g.num_edges(etype)), g.device)
if exclude_self: if exclude_self:
...@@ -929,16 +1012,18 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True, ...@@ -929,16 +1012,18 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
if ignore_bipartite is False: if ignore_bipartite is False:
for c_etype in canonical_etypes: for c_etype in canonical_etypes:
if c_etype[0] != c_etype[2]: if c_etype[0] != c_etype[2]:
assert False, "add_reverse_edges is not well defined for " \ assert False, (
"unidirectional bipartite graphs" \ "add_reverse_edges is not well defined for "
"unidirectional bipartite graphs"
", but {} is unidirectional bipartite".format(c_etype) ", but {} is unidirectional bipartite".format(c_etype)
)
add_for_etype(c_etype) add_for_etype(c_etype)
new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict) new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict)
else: else:
for c_etype in canonical_etypes: for c_etype in canonical_etypes:
if c_etype[0] != c_etype[2]: if c_etype[0] != c_etype[2]:
u, v = g.edges(form='uv', order='eid', etype=c_etype) u, v = g.edges(form="uv", order="eid", etype=c_etype)
subgs[c_etype] = (u, v) subgs[c_etype] = (u, v)
else: else:
add_for_etype(c_etype) add_for_etype(c_etype)
...@@ -955,7 +1040,11 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True, ...@@ -955,7 +1040,11 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
eids = [] eids = []
for c_etype in canonical_etypes: for c_etype in canonical_etypes:
if c_etype[0] != c_etype[2]: if c_etype[0] != c_etype[2]:
eids.append(F.copy_to(F.arange(0, g.number_of_edges(c_etype)), new_g.device)) eids.append(
F.copy_to(
F.arange(0, g.number_of_edges(c_etype)), new_g.device
)
)
else: else:
eids.append(rev_eids[c_etype]) eids.append(rev_eids[c_etype])
...@@ -964,6 +1053,7 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True, ...@@ -964,6 +1053,7 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
return new_g return new_g
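# Illustrative usage sketch for add_reverse_edges: reversed edges reuse the original edata.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 0]), torch.tensor([3, 1])))
g.edata['w'] = torch.tensor([0.1, 0.2])
rg = dgl.add_reverse_edges(g, copy_edata=True)
# rg.edges()    -> (tensor([0, 0, 3, 1]), tensor([3, 1, 0, 0]))
# rg.edata['w'] -> tensor([0.1000, 0.2000, 0.1000, 0.2000])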
def line_graph(g, backtracking=True, shared=False): def line_graph(g, backtracking=True, shared=False):
"""Return the line graph of this graph. """Return the line graph of this graph.
...@@ -1025,11 +1115,12 @@ def line_graph(g, backtracking=True, shared=False): ...@@ -1025,11 +1115,12 @@ def line_graph(g, backtracking=True, shared=False):
>>> lg.edges() >>> lg.edges()
(tensor([0, 1, 2, 4]), tensor([4, 0, 3, 1])) (tensor([0, 1, 2, 4]), tensor([4, 0, 3, 1]))
""" """
assert g.is_homogeneous, \ assert g.is_homogeneous, "only homogeneous graph is supported"
'only homogeneous graph is supported'
dev = g.device dev = g.device
lg = DGLGraph(_CAPI_DGLHeteroLineGraph(g._graph.copy_to(nd.cpu()), backtracking)) lg = DGLGraph(
_CAPI_DGLHeteroLineGraph(g._graph.copy_to(nd.cpu()), backtracking)
)
lg = lg.to(dev) lg = lg.to(dev)
if shared: if shared:
new_frames = utils.extract_edge_subframes(g, None) new_frames = utils.extract_edge_subframes(g, None)
...@@ -1037,8 +1128,10 @@ def line_graph(g, backtracking=True, shared=False): ...@@ -1037,8 +1128,10 @@ def line_graph(g, backtracking=True, shared=False):
return lg return lg
DGLGraph.line_graph = utils.alias_func(line_graph) DGLGraph.line_graph = utils.alias_func(line_graph)
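# Illustrative usage sketch for line_graph: nodes of lg are the edges of g.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 0, 2])))
lg = dgl.line_graph(g)                          # e1 -> e2 in lg iff e1's head is e2's tail
lg_nb = dgl.line_graph(g, backtracking=False)   # also drops u->v immediately followed by v->u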
def khop_adj(g, k): def khop_adj(g, k):
"""Return the matrix of :math:`A^k` where :math:`A` is the adjacency matrix of the graph """Return the matrix of :math:`A^k` where :math:`A` is the adjacency matrix of the graph
:math:`g`. :math:`g`.
...@@ -1074,11 +1167,11 @@ def khop_adj(g, k): ...@@ -1074,11 +1167,11 @@ def khop_adj(g, k):
[1., 3., 3., 1., 0.], [1., 3., 3., 1., 0.],
[0., 1., 3., 3., 1.]]) [0., 1., 3., 3., 1.]])
""" """
assert g.is_homogeneous, \ assert g.is_homogeneous, "only homogeneous graph is supported"
'only homogeneous graph is supported' adj_k = g.adj(transpose=True, scipy_fmt=g.formats()["created"][0]) ** k
adj_k = g.adj(transpose=True, scipy_fmt=g.formats()['created'][0]) ** k
return F.tensor(adj_k.todense().astype(np.float32)) return F.tensor(adj_k.todense().astype(np.float32))
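# Illustrative usage sketch for khop_adj: dense matrix of k-hop walk counts.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))   # directed 3-cycle
A3 = dgl.khop_adj(g, 3)      # A ** 3 as a dense tensor; the identity for the 3-cycle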
def khop_graph(g, k, copy_ndata=True): def khop_graph(g, k, copy_ndata=True):
"""Return the graph whose edges connect the :attr:`k`-hop neighbors of the original graph. """Return the graph whose edges connect the :attr:`k`-hop neighbors of the original graph.
...@@ -1143,17 +1236,18 @@ def khop_graph(g, k, copy_ndata=True): ...@@ -1143,17 +1236,18 @@ def khop_graph(g, k, copy_ndata=True):
ndata_schemes={} ndata_schemes={}
edata_schemes={}) edata_schemes={})
""" """
assert g.is_homogeneous, \ assert g.is_homogeneous, "only homogeneous graph is supported"
'only homogeneous graph is supported'
n = g.number_of_nodes() n = g.number_of_nodes()
adj_k = g.adj(transpose=False, scipy_fmt=g.formats()['created'][0]) ** k adj_k = g.adj(transpose=False, scipy_fmt=g.formats()["created"][0]) ** k
adj_k = adj_k.tocoo() adj_k = adj_k.tocoo()
multiplicity = adj_k.data multiplicity = adj_k.data
row = np.repeat(adj_k.row, multiplicity) row = np.repeat(adj_k.row, multiplicity)
col = np.repeat(adj_k.col, multiplicity) col = np.repeat(adj_k.col, multiplicity)
# TODO(zihao): we should support creating multi-graph from scipy sparse matrix # TODO(zihao): we should support creating multi-graph from scipy sparse matrix
# in the future. # in the future.
new_g = convert.graph((row, col), num_nodes=n, idtype=g.idtype, device=g.device) new_g = convert.graph(
(row, col), num_nodes=n, idtype=g.idtype, device=g.device
)
# handle ndata # handle ndata
if copy_ndata: if copy_ndata:
...@@ -1162,7 +1256,10 @@ def khop_graph(g, k, copy_ndata=True): ...@@ -1162,7 +1256,10 @@ def khop_graph(g, k, copy_ndata=True):
return new_g return new_g
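# Illustrative usage sketch for khop_graph: one parallel edge per distinct k-hop walk.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))   # path 0->1->2->3
g2 = dgl.khop_graph(g, 2)    # expected edges: 0->2 and 1->3, the 2-hop successors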
def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_edata=None):
def reverse(
g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_edata=None
):
r"""Return a new graph with every edges being the reverse ones in the input graph. r"""Return a new graph with every edges being the reverse ones in the input graph.
The reverse (also called converse, transpose) of a graph with edges The reverse (also called converse, transpose) of a graph with edges
...@@ -1261,15 +1358,15 @@ def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_eda ...@@ -1261,15 +1358,15 @@ def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_eda
{} {}
""" """
if share_ndata is not None: if share_ndata is not None:
dgl_warning('share_ndata argument has been renamed to copy_ndata.') dgl_warning("share_ndata argument has been renamed to copy_ndata.")
copy_ndata = share_ndata copy_ndata = share_ndata
if share_edata is not None: if share_edata is not None:
dgl_warning('share_edata argument has been renamed to copy_edata.') dgl_warning("share_edata argument has been renamed to copy_edata.")
copy_edata = share_edata copy_edata = share_edata
if g.is_block: if g.is_block:
# TODO(0.5 release, xiangsx) need to handle BLOCK # TODO(0.5 release, xiangsx) need to handle BLOCK
# currently reversing a block results in undefined behavior # currently reversing a block results in undefined behavior
raise DGLError('Reversing a block graph is not supported.') raise DGLError("Reversing a block graph is not supported.")
gidx = g._graph.reverse() gidx = g._graph.reverse()
new_g = DGLGraph(gidx, g.ntypes, g.etypes) new_g = DGLGraph(gidx, g.ntypes, g.etypes)
...@@ -1284,12 +1381,15 @@ def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_eda ...@@ -1284,12 +1381,15 @@ def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_eda
# for each etype # for each etype
for utype, etype, vtype in g.canonical_etypes: for utype, etype, vtype in g.canonical_etypes:
new_g.edges[vtype, etype, utype].data.update( new_g.edges[vtype, etype, utype].data.update(
g.edges[utype, etype, vtype].data) g.edges[utype, etype, vtype].data
)
return new_g return new_g
DGLGraph.reverse = utils.alias_func(reverse) DGLGraph.reverse = utils.alias_func(reverse)
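# Illustrative usage sketch for reverse: flip all edges, optionally carrying edata over.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.edata['h'] = torch.tensor([[1.0], [2.0], [3.0]])
rg = dgl.reverse(g, copy_edata=True)
# rg.edges() -> (tensor([1, 2, 0]), tensor([0, 1, 2])); rg.edata['h'] is carried over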
def to_simple_graph(g): def to_simple_graph(g):
"""Convert the graph to a simple graph with no multi-edge. """Convert the graph to a simple graph with no multi-edge.
...@@ -1313,9 +1413,10 @@ def to_simple_graph(g): ...@@ -1313,9 +1413,10 @@ def to_simple_graph(g):
and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph
to maintain the information. to maintain the information.
""" """
dgl_warning('dgl.to_simple_graph is renamed to dgl.to_simple in v0.5.') dgl_warning("dgl.to_simple_graph is renamed to dgl.to_simple in v0.5.")
return to_simple(g) return to_simple(g)
def laplacian_lambda_max(g): def laplacian_lambda_max(g):
"""Return the largest eigenvalue of the normalized symmetric Laplacian of a graph. """Return the largest eigenvalue of the normalized symmetric Laplacian of a graph.
...@@ -1349,13 +1450,21 @@ def laplacian_lambda_max(g): ...@@ -1349,13 +1450,21 @@ def laplacian_lambda_max(g):
rst = [] rst = []
for g_i in g_arr: for g_i in g_arr:
n = g_i.number_of_nodes() n = g_i.number_of_nodes()
adj = g_i.adj(transpose=True, scipy_fmt=g_i.formats()['created'][0]).astype(float) adj = g_i.adj(
norm = sparse.diags(F.asnumpy(g_i.in_degrees()).clip(1) ** -0.5, dtype=float) transpose=True, scipy_fmt=g_i.formats()["created"][0]
).astype(float)
norm = sparse.diags(
F.asnumpy(g_i.in_degrees()).clip(1) ** -0.5, dtype=float
)
laplacian = sparse.eye(n) - norm * adj * norm laplacian = sparse.eye(n) - norm * adj * norm
rst.append(scipy.sparse.linalg.eigs( rst.append(
laplacian, 1, which='LM', return_eigenvectors=False)[0].real) scipy.sparse.linalg.eigs(
laplacian, 1, which="LM", return_eigenvectors=False
)[0].real
)
return rst return rst
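# Illustrative usage sketch: the loop above returns, for each graph in the batch, the
# largest eigenvalue of the normalized Laplacian I - D^{-1/2} A D^{-1/2}, which is <= 2.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2, 3, 4, 5]), torch.tensor([1, 2, 3, 4, 5, 0])))
lmax = dgl.laplacian_lambda_max(g)    # a Python list with one scalar per graph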
def metapath_reachable_graph(g, metapath): def metapath_reachable_graph(g, metapath):
"""Return a graph where the successors of any node ``u`` are nodes reachable from ``u`` by """Return a graph where the successors of any node ``u`` are nodes reachable from ``u`` by
the given metapath. the given metapath.
...@@ -1405,14 +1514,17 @@ def metapath_reachable_graph(g, metapath): ...@@ -1405,14 +1514,17 @@ def metapath_reachable_graph(g, metapath):
""" """
adj = 1 adj = 1
for etype in metapath: for etype in metapath:
adj = adj * g.adj(etype=etype, scipy_fmt='csr', transpose=False) adj = adj * g.adj(etype=etype, scipy_fmt="csr", transpose=False)
adj = (adj != 0).tocsr() adj = (adj != 0).tocsr()
srctype = g.to_canonical_etype(metapath[0])[0] srctype = g.to_canonical_etype(metapath[0])[0]
dsttype = g.to_canonical_etype(metapath[-1])[2] dsttype = g.to_canonical_etype(metapath[-1])[2]
new_g = convert.heterograph({(srctype, '_E', dsttype): adj.nonzero()}, new_g = convert.heterograph(
{srctype: adj.shape[0], dsttype: adj.shape[1]}, {(srctype, "_E", dsttype): adj.nonzero()},
idtype=g.idtype, device=g.device) {srctype: adj.shape[0], dsttype: adj.shape[1]},
idtype=g.idtype,
device=g.device,
)
# copy srcnode features # copy srcnode features
new_g.nodes[srctype].data.update(g.nodes[srctype].data) new_g.nodes[srctype].data.update(g.nodes[srctype].data)
...@@ -1422,6 +1534,7 @@ def metapath_reachable_graph(g, metapath): ...@@ -1422,6 +1534,7 @@ def metapath_reachable_graph(g, metapath):
return new_g return new_g
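# Illustrative sketch on a small hypothetical heterograph: users reachable from users
# through the metapath ['plays', 'played-by'].
import dgl
import torch

hg = dgl.heterograph({
    ('user', 'plays', 'game'): (torch.tensor([0, 1]), torch.tensor([0, 0])),
    ('game', 'played-by', 'user'): (torch.tensor([0, 0]), torch.tensor([0, 1])),
})
ug = dgl.metapath_reachable_graph(hg, ['plays', 'played-by'])
# ug is a ('user', '_E', 'user') graph; both users reach each other through game 0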
def add_nodes(g, num, data=None, ntype=None): def add_nodes(g, num, data=None, ntype=None):
r"""Add the given number of nodes to the graph and return a new graph. r"""Add the given number of nodes to the graph and return a new graph.
...@@ -1515,6 +1628,7 @@ def add_nodes(g, num, data=None, ntype=None): ...@@ -1515,6 +1628,7 @@ def add_nodes(g, num, data=None, ntype=None):
g.add_nodes(num, data=data, ntype=ntype) g.add_nodes(num, data=data, ntype=ntype)
return g return g
def add_edges(g, u, v, data=None, etype=None): def add_edges(g, u, v, data=None, etype=None):
r"""Add the edges to the graph and return a new graph. r"""Add the edges to the graph and return a new graph.
...@@ -1625,6 +1739,7 @@ def add_edges(g, u, v, data=None, etype=None): ...@@ -1625,6 +1739,7 @@ def add_edges(g, u, v, data=None, etype=None):
g.add_edges(u, v, data=data, etype=etype) g.add_edges(u, v, data=data, etype=etype)
return g return g
def remove_edges(g, eids, etype=None, store_ids=False): def remove_edges(g, eids, etype=None, store_ids=False):
r"""Remove the specified edges and return a new graph. r"""Remove the specified edges and return a new graph.
...@@ -1779,7 +1894,8 @@ def remove_nodes(g, nids, ntype=None, store_ids=False): ...@@ -1779,7 +1894,8 @@ def remove_nodes(g, nids, ntype=None, store_ids=False):
g.remove_nodes(nids, ntype=ntype, store_ids=store_ids) g.remove_nodes(nids, ntype=ntype, store_ids=store_ids)
return g return g
def add_self_loop(g, edge_feat_names=None, fill_data=1., etype=None):
def add_self_loop(g, edge_feat_names=None, fill_data=1.0, etype=None):
r"""Add self-loops for each node in the graph and return a new graph. r"""Add self-loops for each node in the graph and return a new graph.
Parameters Parameters
...@@ -1861,38 +1977,54 @@ def add_self_loop(g, edge_feat_names=None, fill_data=1., etype=None): ...@@ -1861,38 +1977,54 @@ def add_self_loop(g, edge_feat_names=None, fill_data=1., etype=None):
""" """
etype = g.to_canonical_etype(etype) etype = g.to_canonical_etype(etype)
data = {} data = {}
reduce_funcs = {'sum': function.sum, reduce_funcs = {
'mean': function.mean, "sum": function.sum,
'max': function.max, "mean": function.mean,
'min': function.min} "max": function.max,
"min": function.min,
}
if edge_feat_names is None: if edge_feat_names is None:
edge_feat_names = g.edges[etype].data.keys() edge_feat_names = g.edges[etype].data.keys()
if etype[0] != etype[2]: if etype[0] != etype[2]:
raise DGLError( raise DGLError(
'add_self_loop does not support unidirectional bipartite graphs: {}. ' \ "add_self_loop does not support unidirectional bipartite graphs: {}. "
'Please make sure the types of head node and tail node are identical.' \ "Please make sure the types of head node and tail node are identical."
''.format(etype)) "".format(etype)
)
for feat_name in edge_feat_names: for feat_name in edge_feat_names:
if isinstance(fill_data, (int, float)): if isinstance(fill_data, (int, float)):
dtype = g.edges[etype].data[feat_name].dtype dtype = g.edges[etype].data[feat_name].dtype
dshape = g.edges[etype].data[feat_name].shape dshape = g.edges[etype].data[feat_name].shape
tmp_fill_data = F.copy_to(F.astype(F.tensor([fill_data]), dtype), g.device) tmp_fill_data = F.copy_to(
F.astype(F.tensor([fill_data]), dtype), g.device
)
if len(dshape) > 1: if len(dshape) > 1:
data[feat_name] = F.zeros((g.num_nodes(etype[0]), *dshape[1:]), dtype, data[feat_name] = (
g.device) + tmp_fill_data F.zeros(
(g.num_nodes(etype[0]), *dshape[1:]), dtype, g.device
)
+ tmp_fill_data
)
else: else:
data[feat_name] = F.zeros((g.num_nodes(etype[0]),), dtype, g.device) + tmp_fill_data data[feat_name] = (
F.zeros((g.num_nodes(etype[0]),), dtype, g.device)
+ tmp_fill_data
)
elif isinstance(fill_data, str): elif isinstance(fill_data, str):
if fill_data not in reduce_funcs.keys(): if fill_data not in reduce_funcs.keys():
raise DGLError('Unsupported aggregation: {}'.format(fill_data)) raise DGLError("Unsupported aggregation: {}".format(fill_data))
reducer = reduce_funcs[fill_data] reducer = reduce_funcs[fill_data]
with g.local_scope(): with g.local_scope():
g.update_all(function.copy_e(feat_name, "h"), reducer('h', 'h'), etype=etype) g.update_all(
data[feat_name] = g.nodes[etype[0]].data['h'] function.copy_e(feat_name, "h"),
reducer("h", "h"),
etype=etype,
)
data[feat_name] = g.nodes[etype[0]].data["h"]
nodes = g.nodes(etype[0]) nodes = g.nodes(etype[0])
if len(data): if len(data):
...@@ -1901,10 +2033,12 @@ def add_self_loop(g, edge_feat_names=None, fill_data=1., etype=None): ...@@ -1901,10 +2033,12 @@ def add_self_loop(g, edge_feat_names=None, fill_data=1., etype=None):
new_g = add_edges(g, nodes, nodes, etype=etype) new_g = add_edges(g, nodes, nodes, etype=etype)
return new_g return new_g
DGLGraph.add_self_loop = utils.alias_func(add_self_loop) DGLGraph.add_self_loop = utils.alias_func(add_self_loop)
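# Illustrative usage sketch for add_self_loop: fill the new self-loop features either
# with a constant or by aggregating each node's incoming edge features.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g.edata['w'] = torch.tensor([[0.5], [1.5]])
g1 = dgl.add_self_loop(g, fill_data=1.0)       # self-loop edges get w = 1.0
g2 = dgl.add_self_loop(g, fill_data='mean')    # self-loop edges get the mean incoming w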
def remove_self_loop(g, etype=None): def remove_self_loop(g, etype=None):
r""" Remove self-loops for each node in the graph and return a new graph. r"""Remove self-loops for each node in the graph and return a new graph.
Parameters Parameters
---------- ----------
...@@ -1968,17 +2102,22 @@ def remove_self_loop(g, etype=None): ...@@ -1968,17 +2102,22 @@ def remove_self_loop(g, etype=None):
etype = g.to_canonical_etype(etype) etype = g.to_canonical_etype(etype)
if etype[0] != etype[2]: if etype[0] != etype[2]:
raise DGLError( raise DGLError(
'remove_self_loop does not support unidirectional bipartite graphs: {}. ' \ "remove_self_loop does not support unidirectional bipartite graphs: {}. "
'Please make sure the types of head node and tail node are identical.' \ "Please make sure the types of head node and tail node are identical."
''.format(etype)) "".format(etype)
u, v = g.edges(form='uv', order='eid', etype=etype) )
u, v = g.edges(form="uv", order="eid", etype=etype)
self_loop_eids = F.tensor(F.nonzero_1d(u == v), dtype=F.dtype(u)) self_loop_eids = F.tensor(F.nonzero_1d(u == v), dtype=F.dtype(u))
new_g = remove_edges(g, self_loop_eids, etype=etype) new_g = remove_edges(g, self_loop_eids, etype=etype)
return new_g return new_g
DGLGraph.remove_self_loop = utils.alias_func(remove_self_loop) DGLGraph.remove_self_loop = utils.alias_func(remove_self_loop)
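# Illustrative usage sketch for remove_self_loop: only the self-loop edges are dropped.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 0, 1]), torch.tensor([0, 1, 1])))
g2 = dgl.remove_self_loop(g)
# g2.edges() -> (tensor([0]), tensor([1])); the loops on nodes 0 and 1 are removed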
def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=True):
def compact_graphs(
graphs, always_preserve=None, copy_ndata=True, copy_edata=True
):
"""Given a list of graphs with the same set of nodes, find and eliminate the common """Given a list of graphs with the same set of nodes, find and eliminate the common
isolated nodes across all graphs. isolated nodes across all graphs.
...@@ -2096,7 +2235,7 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru ...@@ -2096,7 +2235,7 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru
if len(graphs) == 0: if len(graphs) == 0:
return [] return []
if graphs[0].is_block: if graphs[0].is_block:
raise DGLError('Compacting a block graph is not allowed.') raise DGLError("Compacting a block graph is not allowed.")
# Ensure the node types are ordered the same. # Ensure the node types are ordered the same.
# TODO(BarclayII): we ideally need to remove this constraint. # TODO(BarclayII): we ideally need to remove this constraint.
...@@ -2104,23 +2243,34 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru ...@@ -2104,23 +2243,34 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru
idtype = graphs[0].idtype idtype = graphs[0].idtype
device = graphs[0].device device = graphs[0].device
for g in graphs: for g in graphs:
assert ntypes == g.ntypes, \ assert ntypes == g.ntypes, (
("All graphs should have the same node types in the same order, got %s and %s" % "All graphs should have the same node types in the same order, got %s and %s"
(ntypes, g.ntypes)) % (ntypes, g.ntypes)
assert idtype == g.idtype, "Expect graph data type to be {}, but got {}".format(
idtype, g.idtype) )
assert device == g.device, "All graphs must be on the same device. " \ assert (
idtype == g.idtype
), "Expect graph data type to be {}, but got {}".format(
idtype, g.idtype
)
assert device == g.device, (
"All graphs must be on the same devices."
"Expect graph device to be {}, but got {}".format(device, g.device) "Expect graph device to be {}, but got {}".format(device, g.device)
)
# Process the dictionary or tensor of "always preserve" nodes # Process the dictionary or tensor of "always preserve" nodes
if always_preserve is None: if always_preserve is None:
always_preserve = {} always_preserve = {}
elif not isinstance(always_preserve, Mapping): elif not isinstance(always_preserve, Mapping):
if len(ntypes) > 1: if len(ntypes) > 1:
raise ValueError("Node type must be given if multiple node types exist.") raise ValueError(
"Node type must be given if multiple node types exist."
)
always_preserve = {ntypes[0]: always_preserve} always_preserve = {ntypes[0]: always_preserve}
always_preserve = utils.prepare_tensor_dict(graphs[0], always_preserve, 'always_preserve') always_preserve = utils.prepare_tensor_dict(
graphs[0], always_preserve, "always_preserve"
)
always_preserve_nd = [] always_preserve_nd = []
for ntype in ntypes: for ntype in ntypes:
nodes = always_preserve.get(ntype, None) nodes = always_preserve.get(ntype, None)
...@@ -2130,12 +2280,14 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru ...@@ -2130,12 +2280,14 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru
# Compact and construct heterographs # Compact and construct heterographs
new_graph_indexes, induced_nodes = _CAPI_DGLCompactGraphs( new_graph_indexes, induced_nodes = _CAPI_DGLCompactGraphs(
[g._graph for g in graphs], always_preserve_nd) [g._graph for g in graphs], always_preserve_nd
)
induced_nodes = [F.from_dgl_nd(nodes) for nodes in induced_nodes] induced_nodes = [F.from_dgl_nd(nodes) for nodes in induced_nodes]
new_graphs = [ new_graphs = [
DGLGraph(new_graph_index, graph.ntypes, graph.etypes) DGLGraph(new_graph_index, graph.ntypes, graph.etypes)
for new_graph_index, graph in zip(new_graph_indexes, graphs)] for new_graph_index, graph in zip(new_graph_indexes, graphs)
]
if copy_ndata: if copy_ndata:
for g, new_g in zip(graphs, new_graphs): for g, new_g in zip(graphs, new_graphs):
...@@ -2151,6 +2303,7 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru ...@@ -2151,6 +2303,7 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru
return new_graphs return new_graphs
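# Illustrative usage sketch for compact_graphs: node IDs isolated in *both* graphs are
# dropped, and the surviving original IDs are stored in ndata[dgl.NID].
import dgl
import torch

g1 = dgl.graph((torch.tensor([1]), torch.tensor([3])), num_nodes=10)
g2 = dgl.graph((torch.tensor([2]), torch.tensor([3])), num_nodes=10)
new_g1, new_g2 = dgl.compact_graphs([g1, g2])
# new_g1.ndata[dgl.NID] lists the kept original IDs (1, 2 and 3), shared by both outputs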
def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None): def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None):
"""Convert a graph into a bipartite-structured *block* for message passing. """Convert a graph into a bipartite-structured *block* for message passing.
...@@ -2301,23 +2454,29 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None): ...@@ -2301,23 +2454,29 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None):
for etype in g.canonical_etypes: for etype in g.canonical_etypes:
_, dst = g.edges(etype=etype) _, dst = g.edges(etype=etype)
dst_nodes[etype[2]].append(dst) dst_nodes[etype[2]].append(dst)
dst_nodes = {ntype: F.unique(F.cat(values, 0)) for ntype, values in dst_nodes.items()} dst_nodes = {
ntype: F.unique(F.cat(values, 0))
for ntype, values in dst_nodes.items()
}
elif not isinstance(dst_nodes, Mapping): elif not isinstance(dst_nodes, Mapping):
# dst_nodes is a Tensor, check if the g has only one type. # dst_nodes is a Tensor, check if the g has only one type.
if len(g.ntypes) > 1: if len(g.ntypes) > 1:
raise DGLError( raise DGLError(
'Graph has more than one node type; please specify a dict for dst_nodes.') "Graph has more than one node type; please specify a dict for dst_nodes."
)
dst_nodes = {g.ntypes[0]: dst_nodes} dst_nodes = {g.ntypes[0]: dst_nodes}
dst_node_ids = [ dst_node_ids = [
utils.toindex(dst_nodes.get(ntype, []), g._idtype_str).tousertensor( utils.toindex(dst_nodes.get(ntype, []), g._idtype_str).tousertensor(
ctx=F.to_backend_ctx(g._graph.ctx)) ctx=F.to_backend_ctx(g._graph.ctx)
for ntype in g.ntypes] )
for ntype in g.ntypes
]
dst_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in dst_node_ids] dst_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in dst_node_ids]
for d in dst_node_ids_nd: for d in dst_node_ids_nd:
if g._graph.ctx != d.ctx: if g._graph.ctx != d.ctx:
raise ValueError('g and dst_nodes need to have the same context.') raise ValueError("g and dst_nodes need to have the same context.")
src_node_ids = None src_node_ids = None
src_node_ids_nd = None src_node_ids_nd = None
...@@ -2325,23 +2484,30 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None): ...@@ -2325,23 +2484,30 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None):
# src_nodes is a Tensor, check if the g has only one type. # src_nodes is a Tensor, check if the g has only one type.
if len(g.ntypes) > 1: if len(g.ntypes) > 1:
raise DGLError( raise DGLError(
'Graph has more than one node type; please specify a dict for src_nodes.') "Graph has more than one node type; please specify a dict for src_nodes."
)
src_nodes = {g.ntypes[0]: src_nodes} src_nodes = {g.ntypes[0]: src_nodes}
src_node_ids = [ src_node_ids = [
F.copy_to(F.tensor(src_nodes.get(ntype, []), dtype=g.idtype), \ F.copy_to(
F.to_backend_ctx(g._graph.ctx)) \ F.tensor(src_nodes.get(ntype, []), dtype=g.idtype),
for ntype in g.ntypes] F.to_backend_ctx(g._graph.ctx),
)
for ntype in g.ntypes
]
src_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in src_node_ids] src_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in src_node_ids]
for d in src_node_ids_nd: for d in src_node_ids_nd:
if g._graph.ctx != d.ctx: if g._graph.ctx != d.ctx:
raise ValueError('g and src_nodes need to have the same context.') raise ValueError(
"g and src_nodes need to have the same context."
)
else: else:
# use an empty list to signal we need to generate it # use an empty list to signal we need to generate it
src_node_ids_nd = [] src_node_ids_nd = []
new_graph_index, src_nodes_ids_nd, induced_edges_nd = _CAPI_DGLToBlock( new_graph_index, src_nodes_ids_nd, induced_edges_nd = _CAPI_DGLToBlock(
g._graph, dst_node_ids_nd, include_dst_in_src, src_node_ids_nd) g._graph, dst_node_ids_nd, include_dst_in_src, src_node_ids_nd
)
# The new graph duplicates the original node types to SRC and DST sets. # The new graph duplicates the original node types to SRC and DST sets.
new_ntypes = (g.ntypes, g.ntypes) new_ntypes = (g.ntypes, g.ntypes)
...@@ -2351,12 +2517,17 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None): ...@@ -2351,12 +2517,17 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None):
src_node_ids = [F.from_dgl_nd(src) for src in src_nodes_ids_nd] src_node_ids = [F.from_dgl_nd(src) for src in src_nodes_ids_nd]
edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges_nd] edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges_nd]
node_frames = utils.extract_node_subframes_for_block(g, src_node_ids, dst_node_ids) node_frames = utils.extract_node_subframes_for_block(
g, src_node_ids, dst_node_ids
)
edge_frames = utils.extract_edge_subframes(g, edge_ids) edge_frames = utils.extract_edge_subframes(g, edge_ids)
utils.set_new_frames(new_graph, node_frames=node_frames, edge_frames=edge_frames) utils.set_new_frames(
new_graph, node_frames=node_frames, edge_frames=edge_frames
)
return new_graph return new_graph
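# Illustrative usage sketch for to_block: DST nodes are the given seeds, SRC nodes their
# incoming frontier; original IDs live in srcdata/dstdata under dgl.NID.
import dgl
import torch

g = dgl.graph((torch.tensor([1, 2]), torch.tensor([2, 3])))
block = dgl.to_block(g, dst_nodes=torch.tensor([3, 2]))
# block.dstdata[dgl.NID] holds [3, 2]; block.srcdata[dgl.NID] starts with the DST nodes
# because include_dst_in_src defaults to True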
def _coalesce_edge_frame(g, edge_maps, counts, aggregator): def _coalesce_edge_frame(g, edge_maps, counts, aggregator):
r"""Coalesce edge features of duplicate edges via given aggregator in g. r"""Coalesce edge features of duplicate edges via given aggregator in g.
...@@ -2377,7 +2548,7 @@ def _coalesce_edge_frame(g, edge_maps, counts, aggregator): ...@@ -2377,7 +2548,7 @@ def _coalesce_edge_frame(g, edge_maps, counts, aggregator):
List[Frame] List[Frame]
The frames corresponding to each edge type. The frames corresponding to each edge type.
""" """
if aggregator == 'arbitrary': if aggregator == "arbitrary":
eids = [] eids = []
for i in range(len(g.canonical_etypes)): for i in range(len(g.canonical_etypes)):
feat_idx = F.asnumpy(edge_maps[i]) feat_idx = F.asnumpy(edge_maps[i])
...@@ -2385,7 +2556,7 @@ def _coalesce_edge_frame(g, edge_maps, counts, aggregator): ...@@ -2385,7 +2556,7 @@ def _coalesce_edge_frame(g, edge_maps, counts, aggregator):
eids.append(F.zerocopy_from_numpy(indices)) eids.append(F.zerocopy_from_numpy(indices))
edge_frames = utils.extract_edge_subframes(g, eids) edge_frames = utils.extract_edge_subframes(g, eids)
elif aggregator in ['sum', 'mean']: elif aggregator in ["sum", "mean"]:
edge_frames = [] edge_frames = []
for i in range(len(g.canonical_etypes)): for i in range(len(g.canonical_etypes)):
feat_idx = edge_maps[i] feat_idx = edge_maps[i]
...@@ -2395,25 +2566,32 @@ def _coalesce_edge_frame(g, edge_maps, counts, aggregator): ...@@ -2395,25 +2566,32 @@ def _coalesce_edge_frame(g, edge_maps, counts, aggregator):
for key, col in g._edge_frames[i]._columns.items(): for key, col in g._edge_frames[i]._columns.items():
data = col.data data = col.data
new_data = F.scatter_add(data, feat_idx, _num_rows) new_data = F.scatter_add(data, feat_idx, _num_rows)
if aggregator == 'mean': if aggregator == "mean":
norm = F.astype(counts[i], F.dtype(data)) norm = F.astype(counts[i], F.dtype(data))
norm = F.reshape(norm, (F.shape(norm)[0],) + (1,) * (F.ndim(data) - 1)) norm = F.reshape(
norm, (F.shape(norm)[0],) + (1,) * (F.ndim(data) - 1)
)
new_data /= norm new_data /= norm
_data[key] = new_data _data[key] = new_data
newf = Frame(data=_data, num_rows=_num_rows) newf = Frame(data=_data, num_rows=_num_rows)
edge_frames.append(newf) edge_frames.append(newf)
else: else:
raise DGLError("Aggregator {} not regonized, cannot coalesce edge feature in the " raise DGLError(
"specified way".format(aggregator)) "Aggregator {} not regonized, cannot coalesce edge feature in the "
"specified way".format(aggregator)
)
return edge_frames return edge_frames
def to_simple(g,
return_counts='count', def to_simple(
writeback_mapping=False, g,
copy_ndata=True, return_counts="count",
copy_edata=False, writeback_mapping=False,
aggregator='arbitrary'): copy_ndata=True,
copy_edata=False,
aggregator="arbitrary",
):
r"""Convert a graph to a simple graph without parallel edges and return. r"""Convert a graph to a simple graph without parallel edges and return.
For a heterogeneous graph with multiple edge types, DGL treats edges with the same For a heterogeneous graph with multiple edge types, DGL treats edges with the same
...@@ -2554,9 +2732,9 @@ def to_simple(g, ...@@ -2554,9 +2732,9 @@ def to_simple(g,
{('user', 'wins', 'user'): tensor([1, 2, 1, 1]) {('user', 'wins', 'user'): tensor([1, 2, 1, 1])
('user', 'plays', 'game'): tensor([1, 1, 1])} ('user', 'plays', 'game'): tensor([1, 1, 1])}
""" """
assert g.device == F.cpu(), 'the graph must be on CPU' assert g.device == F.cpu(), "the graph must be on CPU"
if g.is_block: if g.is_block:
raise DGLError('Cannot convert a block graph to a simple graph.') raise DGLError("Cannot convert a block graph to a simple graph.")
simple_graph_index, counts, edge_maps = _CAPI_DGLToSimpleHetero(g._graph) simple_graph_index, counts, edge_maps = _CAPI_DGLToSimpleHetero(g._graph)
simple_graph = DGLGraph(simple_graph_index, g.ntypes, g.etypes) simple_graph = DGLGraph(simple_graph_index, g.ntypes, g.etypes)
counts = [F.from_dgl_nd(count) for count in counts] counts = [F.from_dgl_nd(count) for count in counts]
...@@ -2586,8 +2764,10 @@ def to_simple(g, ...@@ -2586,8 +2764,10 @@ def to_simple(g,
return simple_graph return simple_graph
DGLGraph.to_simple = utils.alias_func(to_simple) DGLGraph.to_simple = utils.alias_func(to_simple)
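# Illustrative usage sketch for to_simple: deduplicate parallel edges, keep multiplicities
# and a mapping from original edge IDs to the kept edge IDs.
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 2, 2])))
sg, wm = dgl.to_simple(g, return_counts='count', writeback_mapping=True)
# sg.edata['count'] gives each kept edge's multiplicity (here 1 for 0->1 and 2 for 1->2);
# wm maps every original edge ID to its edge ID in sg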
def _unitgraph_less_than_int32(g): def _unitgraph_less_than_int32(g):
"""Check if a graph with only one edge type has more than 2 ** 31 - 1 """Check if a graph with only one edge type has more than 2 ** 31 - 1
nodes or edges. nodes or edges.
...@@ -2596,7 +2776,8 @@ def _unitgraph_less_than_int32(g): ...@@ -2596,7 +2776,8 @@ def _unitgraph_less_than_int32(g):
num_nodes = max(g.num_nodes(g.ntypes[0]), g.num_nodes(g.ntypes[-1])) num_nodes = max(g.num_nodes(g.ntypes[0]), g.num_nodes(g.ntypes[-1]))
return max(num_nodes, num_edges) <= (1 << 31) - 1 return max(num_nodes, num_edges) <= (1 << 31) - 1
def adj_product_graph(A, B, weight_name, etype='_E'):
def adj_product_graph(A, B, weight_name, etype="_E"):
r"""Create a weighted graph whose adjacency matrix is the product of r"""Create a weighted graph whose adjacency matrix is the product of
the adjacency matrices of the given two graphs. the adjacency matrices of the given two graphs.
...@@ -2718,23 +2899,37 @@ def adj_product_graph(A, B, weight_name, etype='_E'): ...@@ -2718,23 +2899,37 @@ def adj_product_graph(A, B, weight_name, etype='_E'):
ntypes = [srctype] if num_vtypes == 1 else [srctype, dsttype] ntypes = [srctype] if num_vtypes == 1 else [srctype, dsttype]
if A.device != F.cpu(): if A.device != F.cpu():
if not (_unitgraph_less_than_int32(A) and _unitgraph_less_than_int32(B)): if not (
_unitgraph_less_than_int32(A) and _unitgraph_less_than_int32(B)
):
raise ValueError( raise ValueError(
'For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1.') "For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1."
)
C_gidx, C_weights = F.csrmm( C_gidx, C_weights = F.csrmm(
A._graph, A.edata[weight_name], B._graph, B.edata[weight_name], num_vtypes) A._graph,
num_nodes_dict = {srctype: A.num_nodes(srctype), dsttype: B.num_nodes(dsttype)} A.edata[weight_name],
C_metagraph, ntypes, etypes, _ = \ B._graph,
create_metagraph_index(ntypes, [(srctype, etype, dsttype)]) B.edata[weight_name],
num_vtypes,
)
num_nodes_dict = {
srctype: A.num_nodes(srctype),
dsttype: B.num_nodes(dsttype),
}
C_metagraph, ntypes, etypes, _ = create_metagraph_index(
ntypes, [(srctype, etype, dsttype)]
)
num_nodes_per_type = [num_nodes_dict[ntype] for ntype in ntypes] num_nodes_per_type = [num_nodes_dict[ntype] for ntype in ntypes]
C_gidx = create_heterograph_from_relations( C_gidx = create_heterograph_from_relations(
C_metagraph, [C_gidx], utils.toindex(num_nodes_per_type)) C_metagraph, [C_gidx], utils.toindex(num_nodes_per_type)
)
C = DGLGraph(C_gidx, ntypes, etypes) C = DGLGraph(C_gidx, ntypes, etypes)
C.edata[weight_name] = C_weights C.edata[weight_name] = C_weights
return C return C
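The product graph's weight on edge (u, w) is the sum over v of A[u, v] * B[v, w], i.e. an ordinary sparse matrix product. A minimal SciPy sketch of that computation on hypothetical toy matrices (for intuition only, not the F.csrmm code path used above):

import numpy as np
from scipy.sparse import csr_matrix

# Toy weighted adjacency matrices: rows are source nodes, columns are destinations.
A = csr_matrix(np.array([[0.0, 2.0], [1.0, 0.0]]))
B = csr_matrix(np.array([[0.0, 3.0], [4.0, 0.0]]))

# Edge (u, w) of the product graph carries sum_v A[u, v] * B[v, w].
C = A @ B
print(C.toarray())  # [[8. 0.]
                    #  [0. 3.]]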
def adj_sum_graph(graphs, weight_name): def adj_sum_graph(graphs, weight_name):
r"""Create a weighted graph whose adjacency matrix is the sum of the r"""Create a weighted graph whose adjacency matrix is the sum of the
adjacency matrices of the given graphs, whose rows represent source nodes adjacency matrices of the given graphs, whose rows represent source nodes
...@@ -2818,15 +3013,20 @@ def adj_sum_graph(graphs, weight_name): ...@@ -2818,15 +3013,20 @@ def adj_sum_graph(graphs, weight_name):
tensor([1., 1., 1., 1., 1., 1.]) tensor([1., 1., 1., 1., 1., 1.])
""" """
if len(graphs) == 0: if len(graphs) == 0:
raise ValueError('The list of graphs must not be empty.') raise ValueError("The list of graphs must not be empty.")
if graphs[0].device != F.cpu(): if graphs[0].device != F.cpu():
if not all(_unitgraph_less_than_int32(A) for A in graphs): if not all(_unitgraph_less_than_int32(A) for A in graphs):
raise ValueError( raise ValueError(
'For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1.') "For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1."
)
metagraph = graphs[0]._graph.metagraph metagraph = graphs[0]._graph.metagraph
num_nodes = utils.toindex( num_nodes = utils.toindex(
[graphs[0]._graph.number_of_nodes(i) for i in range(graphs[0]._graph.number_of_ntypes())]) [
graphs[0]._graph.number_of_nodes(i)
for i in range(graphs[0]._graph.number_of_ntypes())
]
)
weights = [A.edata[weight_name] for A in graphs] weights = [A.edata[weight_name] for A in graphs]
gidxs = [A._graph for A in graphs] gidxs = [A._graph for A in graphs]
C_gidx, C_weights = F.csrsum(gidxs, weights) C_gidx, C_weights = F.csrsum(gidxs, weights)
...@@ -2837,7 +3037,7 @@ def adj_sum_graph(graphs, weight_name): ...@@ -2837,7 +3037,7 @@ def adj_sum_graph(graphs, weight_name):
return C return C
def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET', tag_type='node'): def sort_csr_by_tag(g, tag, tag_offset_name="_TAG_OFFSET", tag_type="node"):
r"""Return a new graph whose CSR matrix is sorted by the given tag. r"""Return a new graph whose CSR matrix is sorted by the given tag.
Sort the internal CSR matrix of the graph so that the adjacency list of each node Sort the internal CSR matrix of the graph so that the adjacency list of each node
...@@ -2947,20 +3147,25 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET', tag_type='node'): ...@@ -2947,20 +3147,25 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET', tag_type='node'):
""" """
if len(g.etypes) > 1: if len(g.etypes) > 1:
raise DGLError("Only support homograph and bipartite graph") raise DGLError("Only support homograph and bipartite graph")
assert tag_type in ['node', 'edge'], "tag_type should be either 'node' or 'edge'." assert tag_type in [
if tag_type == 'node': "node",
"edge",
], "tag_type should be either 'node' or 'edge'."
if tag_type == "node":
_, dst = g.edges() _, dst = g.edges()
tag = F.gather_row(tag, F.tensor(dst)) tag = F.gather_row(tag, F.tensor(dst))
assert len(tag) == g.num_edges() assert len(tag) == g.num_edges()
num_tags = int(F.asnumpy(F.max(tag, 0))) + 1 num_tags = int(F.asnumpy(F.max(tag, 0))) + 1
tag_arr = F.zerocopy_to_dgl_ndarray(tag) tag_arr = F.zerocopy_to_dgl_ndarray(tag)
new_g = g.clone() new_g = g.clone()
new_g._graph, tag_pos_arr = _CAPI_DGLHeteroSortOutEdges(g._graph, tag_arr, num_tags) new_g._graph, tag_pos_arr = _CAPI_DGLHeteroSortOutEdges(
g._graph, tag_arr, num_tags
)
new_g.srcdata[tag_offset_name] = F.from_dgl_nd(tag_pos_arr) new_g.srcdata[tag_offset_name] = F.from_dgl_nd(tag_pos_arr)
return new_g return new_g
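The effect of the sort is easiest to see on a single adjacency list: a node's out-neighbors are grouped by tag, and the stored tag offsets mark where each tag's segment begins. A small NumPy sketch with hypothetical data (the real work happens in the C API call above):

import numpy as np

neighbors = np.array([5, 3, 7, 2])   # out-neighbors of one node
tags = np.array([1, 0, 1, 0])        # tag of each neighbor, num_tags = 2

order = np.argsort(tags, kind="stable")                   # stable sort keeps ties in order
sorted_neighbors = neighbors[order]                       # [3, 2, 5, 7]
tag_offset = np.searchsorted(tags[order], np.arange(3))   # [0, 2, 4]
print(sorted_neighbors, tag_offset)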
def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET', tag_type='node'): def sort_csc_by_tag(g, tag, tag_offset_name="_TAG_OFFSET", tag_type="node"):
r"""Return a new graph whose CSC matrix is sorted by the given tag. r"""Return a new graph whose CSC matrix is sorted by the given tag.
Sort the internal CSC matrix of the graph so that the adjacency list of each node Sort the internal CSC matrix of the graph so that the adjacency list of each node
...@@ -3068,21 +3273,31 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET', tag_type='node'): ...@@ -3068,21 +3273,31 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET', tag_type='node'):
""" """
if len(g.etypes) > 1: if len(g.etypes) > 1:
raise DGLError("Only support homograph and bipartite graph") raise DGLError("Only support homograph and bipartite graph")
assert tag_type in ['node', 'edge'], "tag_type should be either 'node' or 'edge'." assert tag_type in [
if tag_type == 'node': "node",
"edge",
], "tag_type should be either 'node' or 'edge'."
if tag_type == "node":
src, _ = g.edges() src, _ = g.edges()
tag = F.gather_row(tag, F.tensor(src)) tag = F.gather_row(tag, F.tensor(src))
assert len(tag) == g.num_edges() assert len(tag) == g.num_edges()
num_tags = int(F.asnumpy(F.max(tag, 0))) + 1 num_tags = int(F.asnumpy(F.max(tag, 0))) + 1
tag_arr = F.zerocopy_to_dgl_ndarray(tag) tag_arr = F.zerocopy_to_dgl_ndarray(tag)
new_g = g.clone() new_g = g.clone()
new_g._graph, tag_pos_arr = _CAPI_DGLHeteroSortInEdges(g._graph, tag_arr, num_tags) new_g._graph, tag_pos_arr = _CAPI_DGLHeteroSortInEdges(
g._graph, tag_arr, num_tags
)
new_g.dstdata[tag_offset_name] = F.from_dgl_nd(tag_pos_arr) new_g.dstdata[tag_offset_name] = F.from_dgl_nd(tag_pos_arr)
return new_g return new_g
def reorder_graph(g, node_permute_algo=None, edge_permute_algo='src', def reorder_graph(
store_ids=True, permute_config=None): g,
node_permute_algo=None,
edge_permute_algo="src",
store_ids=True,
permute_config=None,
):
r"""Return a new graph with nodes and edges re-ordered/re-labeled r"""Return a new graph with nodes and edges re-ordered/re-labeled
according to the specified permute algorithm. according to the specified permute algorithm.
...@@ -3253,36 +3468,51 @@ def reorder_graph(g, node_permute_algo=None, edge_permute_algo='src', ...@@ -3253,36 +3468,51 @@ def reorder_graph(g, node_permute_algo=None, edge_permute_algo='src',
# sanity checks # sanity checks
if not g.is_homogeneous: if not g.is_homogeneous:
raise DGLError("Only homogeneous graphs are supported.") raise DGLError("Only homogeneous graphs are supported.")
expected_node_algo = ['rcmk', 'metis', 'custom'] expected_node_algo = ["rcmk", "metis", "custom"]
if node_permute_algo is not None and node_permute_algo not in expected_node_algo: if (
raise DGLError("Unexpected node_permute_algo is specified: {}. Expected algos: {}".format( node_permute_algo is not None
node_permute_algo, expected_node_algo)) and node_permute_algo not in expected_node_algo
expected_edge_algo = ['src', 'dst', 'custom'] ):
raise DGLError(
"Unexpected node_permute_algo is specified: {}. Expected algos: {}".format(
node_permute_algo, expected_node_algo
)
)
expected_edge_algo = ["src", "dst", "custom"]
if edge_permute_algo not in expected_edge_algo: if edge_permute_algo not in expected_edge_algo:
raise DGLError("Unexpected edge_permute_algo is specified: {}. Expected algos: {}".format( raise DGLError(
edge_permute_algo, expected_edge_algo)) "Unexpected edge_permute_algo is specified: {}. Expected algos: {}".format(
edge_permute_algo, expected_edge_algo
)
)
g.edata['__orig__'] = F.arange(0, g.num_edges(), g.idtype, g.device) g.edata["__orig__"] = F.arange(0, g.num_edges(), g.idtype, g.device)
# reorder nodes # reorder nodes
if node_permute_algo == 'rcmk': if node_permute_algo == "rcmk":
nodes_perm = rcmk_perm(g) nodes_perm = rcmk_perm(g)
rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False) rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)
elif node_permute_algo == 'metis': elif node_permute_algo == "metis":
if permute_config is None or 'k' not in permute_config: if permute_config is None or "k" not in permute_config:
raise DGLError( raise DGLError(
"Partition parts 'k' is required for metis. Please specify in permute_config.") "Partition parts 'k' is required for metis. Please specify in permute_config."
nodes_perm = metis_perm(g, permute_config['k']) )
nodes_perm = metis_perm(g, permute_config["k"])
rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False) rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)
elif node_permute_algo == 'custom': elif node_permute_algo == "custom":
if permute_config is None or 'nodes_perm' not in permute_config: if permute_config is None or "nodes_perm" not in permute_config:
raise DGLError( raise DGLError(
"node_permute_algo is specified as custom, but no 'nodes_perm' is specified in \ "node_permute_algo is specified as custom, but no 'nodes_perm' is specified in \
permute_config.") permute_config."
nodes_perm = permute_config['nodes_perm'] )
nodes_perm = permute_config["nodes_perm"]
if len(nodes_perm) != g.num_nodes(): if len(nodes_perm) != g.num_nodes():
raise DGLError("Length of 'nodes_perm' ({}) does not \ raise DGLError(
match graph num_nodes ({}).".format(len(nodes_perm), g.num_nodes())) "Length of 'nodes_perm' ({}) does not \
match graph num_nodes ({}).".format(
len(nodes_perm), g.num_nodes()
)
)
rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False) rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)
else: else:
nodes_perm = F.arange(0, g.num_nodes(), g.idtype, g.device) nodes_perm = F.arange(0, g.num_nodes(), g.idtype, g.device)
...@@ -3291,32 +3521,38 @@ def reorder_graph(g, node_permute_algo=None, edge_permute_algo='src', ...@@ -3291,32 +3521,38 @@ def reorder_graph(g, node_permute_algo=None, edge_permute_algo='src',
if store_ids: if store_ids:
rg.ndata[NID] = F.copy_to(F.tensor(nodes_perm, g.idtype), g.device) rg.ndata[NID] = F.copy_to(F.tensor(nodes_perm, g.idtype), g.device)
g.edata.pop('__orig__') g.edata.pop("__orig__")
# reorder edges # reorder edges
if edge_permute_algo == 'src': if edge_permute_algo == "src":
edges_perm = np.argsort(F.asnumpy(rg.edges()[0])) edges_perm = np.argsort(F.asnumpy(rg.edges()[0]))
rg = subgraph.edge_subgraph( rg = subgraph.edge_subgraph(
rg, edges_perm, relabel_nodes=False, store_ids=False) rg, edges_perm, relabel_nodes=False, store_ids=False
elif edge_permute_algo == 'dst': )
elif edge_permute_algo == "dst":
edges_perm = np.argsort(F.asnumpy(rg.edges()[1])) edges_perm = np.argsort(F.asnumpy(rg.edges()[1]))
rg = subgraph.edge_subgraph( rg = subgraph.edge_subgraph(
rg, edges_perm, relabel_nodes=False, store_ids=False) rg, edges_perm, relabel_nodes=False, store_ids=False
elif edge_permute_algo == 'custom': )
if permute_config is None or 'edges_perm' not in permute_config: elif edge_permute_algo == "custom":
if permute_config is None or "edges_perm" not in permute_config:
raise DGLError( raise DGLError(
"edge_permute_algo is specified as custom, but no 'edges_perm' is specified in \ "edge_permute_algo is specified as custom, but no 'edges_perm' is specified in \
permute_config.") permute_config."
edges_perm = permute_config['edges_perm'] )
edges_perm = permute_config["edges_perm"]
# First revert the edge reorder caused by node reorder and then # First revert the edge reorder caused by node reorder and then
# apply user-provided edge permutation # apply user-provided edge permutation
rev_id = F.argsort(rg.edata['__orig__'], 0, False) rev_id = F.argsort(rg.edata["__orig__"], 0, False)
edges_perm = F.astype(F.gather_row(rev_id, F.tensor(edges_perm)), rg.idtype) edges_perm = F.astype(
F.gather_row(rev_id, F.tensor(edges_perm)), rg.idtype
)
rg = subgraph.edge_subgraph( rg = subgraph.edge_subgraph(
rg, edges_perm, relabel_nodes=False, store_ids=False) rg, edges_perm, relabel_nodes=False, store_ids=False
)
if store_ids: if store_ids:
rg.edata[EID] = rg.edata.pop('__orig__') rg.edata[EID] = rg.edata.pop("__orig__")
return rg return rg
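A short usage sketch of the wrapper above, assuming a DGL build where it is exported at the top level as dgl.reorder_graph (store_ids=True keeps the original ids):

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))

# Relabel nodes with reverse Cuthill-McKee, then sort edges by source id.
rg = dgl.reorder_graph(g, node_permute_algo="rcmk", edge_permute_algo="src")
print(rg.ndata[dgl.NID])  # original node ids
print(rg.edata[dgl.EID])  # original edge ids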
...@@ -3342,7 +3578,8 @@ def metis_perm(g, k): ...@@ -3342,7 +3578,8 @@ def metis_perm(g, k):
The nodes permutation. The nodes permutation.
""" """
pids = metis_partition_assignment( pids = metis_partition_assignment(
g if g.device == F.cpu() else g.to(F.cpu()), k) g if g.device == F.cpu() else g.to(F.cpu()), k
)
pids = F.asnumpy(pids) pids = F.asnumpy(pids)
return np.argsort(pids).copy() return np.argsort(pids).copy()
...@@ -3362,7 +3599,7 @@ def rcmk_perm(g): ...@@ -3362,7 +3599,7 @@ def rcmk_perm(g):
iterable[int] iterable[int]
The nodes permutation. The nodes permutation.
""" """
fmat = 'csr' fmat = "csr"
allowed_fmats = sum(g.formats().values(), []) allowed_fmats = sum(g.formats().values(), [])
if fmat not in allowed_fmats: if fmat not in allowed_fmats:
g = g.formats(allowed_fmats + [fmat]) g = g.formats(allowed_fmats + [fmat])
...@@ -3400,16 +3637,23 @@ def norm_by_dst(g, etype=None): ...@@ -3400,16 +3637,23 @@ def norm_by_dst(g, etype=None):
>>> print(dgl.norm_by_dst(g)) >>> print(dgl.norm_by_dst(g))
tensor([0.5000, 0.5000, 1.0000]) tensor([0.5000, 0.5000, 1.0000])
""" """
_, v, _ = g.edges(form='all', etype=etype) _, v, _ = g.edges(form="all", etype=etype)
_, inv_index, count = F.unique(v, return_inverse=True, return_counts=True) _, inv_index, count = F.unique(v, return_inverse=True, return_counts=True)
deg = F.astype(count[inv_index], F.float32) deg = F.astype(count[inv_index], F.float32)
norm = 1. / deg norm = 1.0 / deg
norm = F.replace_inf_with_zero(norm) norm = F.replace_inf_with_zero(norm)
return norm return norm
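The normalizer is simply one over the in-degree of each edge's destination. A NumPy sketch that mirrors the unique/count trick used above, with hypothetical destination ids chosen to match the docstring example:

import numpy as np

dst = np.array([1, 1, 2])  # destination node of each edge
_, inv, counts = np.unique(dst, return_inverse=True, return_counts=True)
norm = 1.0 / counts[inv]   # per-edge 1 / in-degree(dst)
print(norm)                # [0.5 0.5 1. ]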
def radius_graph(x, r, p=2, self_loop=False,
compute_mode='donot_use_mm_for_euclid_dist', get_distances=False): def radius_graph(
x,
r,
p=2,
self_loop=False,
compute_mode="donot_use_mm_for_euclid_dist",
get_distances=False,
):
r"""Construct a graph from a set of points with neighbors within given distance. r"""Construct a graph from a set of points with neighbors within given distance.
The function transforms the coordinates/features of a point set The function transforms the coordinates/features of a point set
...@@ -3521,6 +3765,7 @@ def radius_graph(x, r, p=2, self_loop=False, ...@@ -3521,6 +3765,7 @@ def radius_graph(x, r, p=2, self_loop=False,
return g return g
def random_walk_pe(g, k, eweight_name=None): def random_walk_pe(g, k, eweight_name=None):
r"""Random Walk Positional Encoding, as introduced in r"""Random Walk Positional Encoding, as introduced in
`Graph Neural Networks with Learnable Structural and Positional Representations `Graph Neural Networks with Learnable Structural and Positional Representations
...@@ -3553,28 +3798,29 @@ def random_walk_pe(g, k, eweight_name=None): ...@@ -3553,28 +3798,29 @@ def random_walk_pe(g, k, eweight_name=None):
tensor([[0.0000, 0.5000], tensor([[0.0000, 0.5000],
[0.5000, 0.7500]]) [0.5000, 0.7500]])
""" """
N = g.num_nodes() # number of nodes N = g.num_nodes() # number of nodes
M = g.num_edges() # number of edges M = g.num_edges() # number of edges
A = g.adj(scipy_fmt='csr') # adjacency matrix A = g.adj(scipy_fmt="csr") # adjacency matrix
if eweight_name is not None: if eweight_name is not None:
# add edge weights if required # add edge weights if required
W = sparse.csr_matrix( W = sparse.csr_matrix(
(g.edata[eweight_name].squeeze(), g.find_edges(list(range(M)))), (g.edata[eweight_name].squeeze(), g.find_edges(list(range(M)))),
shape = (N, N) shape=(N, N),
) )
A = A.multiply(W) A = A.multiply(W)
RW = np.array(A / (A.sum(1) + 1e-30)) # 1-step transition probability RW = np.array(A / (A.sum(1) + 1e-30)) # 1-step transition probability
# Iterate for k steps # Iterate for k steps
PE = [F.astype(F.tensor(RW.diagonal()), F.float32)] PE = [F.astype(F.tensor(RW.diagonal()), F.float32)]
RW_power = RW RW_power = RW
for _ in range(k-1): for _ in range(k - 1):
RW_power = RW_power @ RW RW_power = RW_power @ RW
PE.append(F.astype(F.tensor(RW_power.diagonal()), F.float32)) PE.append(F.astype(F.tensor(RW_power.diagonal()), F.float32))
PE = F.stack(PE,dim=-1) PE = F.stack(PE, dim=-1)
return PE return PE
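The encoding of a node is the diagonal entry of the i-step transition matrix for i = 1..k. A dense NumPy sketch of the same recurrence on a hypothetical two-node graph that yields the numbers shown in the docstring above:

import numpy as np

A = np.array([[0.0, 1.0],
              [1.0, 1.0]])               # toy adjacency, node 1 has a self-loop
RW = A / A.sum(1, keepdims=True)         # 1-step transition probabilities
k = 2
PE, P = [], RW
for _ in range(k):
    PE.append(np.diag(P))
    P = P @ RW
print(np.stack(PE, axis=-1))
# [[0.   0.5 ]
#  [0.5  0.75]]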
def laplacian_pe(g, k, padding=False, return_eigval=False): def laplacian_pe(g, k, padding=False, return_eigval=False):
r"""Laplacian Positional Encoding, as introduced in r"""Laplacian Positional Encoding, as introduced in
`Benchmarking Graph Neural Networks `Benchmarking Graph Neural Networks
...@@ -3632,38 +3878,45 @@ def laplacian_pe(g, k, padding=False, return_eigval=False): ...@@ -3632,38 +3878,45 @@ def laplacian_pe(g, k, padding=False, return_eigval=False):
# check for the "k < n" constraint # check for the "k < n" constraint
n = g.num_nodes() n = g.num_nodes()
if not padding and n <= k: if not padding and n <= k:
assert "the number of eigenvectors k must be smaller than the number of nodes n, " + \ assert (
f"{k} and {n} detected." "the number of eigenvectors k must be smaller than the number of nodes n, "
+ f"{k} and {n} detected."
)
# get laplacian matrix as I - D^-0.5 * A * D^-0.5 # get laplacian matrix as I - D^-0.5 * A * D^-0.5
A = g.adj(scipy_fmt='csr') # adjacency matrix A = g.adj(scipy_fmt="csr") # adjacency matrix
N = sparse.diags(F.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) # D^-1/2 N = sparse.diags(
F.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float
) # D^-1/2
L = sparse.eye(g.num_nodes()) - N * A * N L = sparse.eye(g.num_nodes()) - N * A * N
# select eigenvectors with smaller eigenvalues O(n + klogk) # select eigenvectors with smaller eigenvalues O(n + klogk)
EigVal, EigVec = np.linalg.eig(L.toarray()) EigVal, EigVec = np.linalg.eig(L.toarray())
max_freqs = min(n-1, k) max_freqs = min(n - 1, k)
kpartition_indices = np.argpartition(EigVal, max_freqs)[:max_freqs+1] kpartition_indices = np.argpartition(EigVal, max_freqs)[: max_freqs + 1]
topk_eigvals = EigVal[kpartition_indices] topk_eigvals = EigVal[kpartition_indices]
topk_indices = kpartition_indices[topk_eigvals.argsort()][1:] topk_indices = kpartition_indices[topk_eigvals.argsort()][1:]
topk_EigVec = EigVec[:, topk_indices] topk_EigVec = EigVec[:, topk_indices]
eigvals = F.tensor(EigVal[topk_indices], dtype=F.float32) eigvals = F.tensor(EigVal[topk_indices], dtype=F.float32)
# get random flip signs # get random flip signs
rand_sign = 2 * (np.random.rand(max_freqs) > 0.5) - 1. rand_sign = 2 * (np.random.rand(max_freqs) > 0.5) - 1.0
PE = F.astype(F.tensor(rand_sign * topk_EigVec), F.float32) PE = F.astype(F.tensor(rand_sign * topk_EigVec), F.float32)
# add paddings # add paddings
if n <= k: if n <= k:
temp_EigVec = F.zeros([n, k-n+1], dtype=F.float32, ctx=F.context(PE)) temp_EigVec = F.zeros(
[n, k - n + 1], dtype=F.float32, ctx=F.context(PE)
)
PE = F.cat([PE, temp_EigVec], dim=1) PE = F.cat([PE, temp_EigVec], dim=1)
temp_EigVal = F.tensor(np.full(k-n+1, np.nan), F.float32) temp_EigVal = F.tensor(np.full(k - n + 1, np.nan), F.float32)
eigvals = F.cat([eigvals, temp_EigVal], dim=0) eigvals = F.cat([eigvals, temp_EigVal], dim=0)
if return_eigval: if return_eigval:
return PE, eigvals return PE, eigvals
return PE return PE
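The selection step keeps the eigenvectors of the symmetrically normalized Laplacian with the k smallest non-trivial eigenvalues. A dense NumPy sketch for a symmetric toy graph (ignoring the random sign flips and the padding branch; eigh suffices here because the toy Laplacian is symmetric):

import numpy as np

A = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)      # 4-cycle
deg = A.sum(1)
D_inv_sqrt = np.diag(deg.clip(1) ** -0.5)
L = np.eye(4) - D_inv_sqrt @ A @ D_inv_sqrt    # I - D^-1/2 A D^-1/2

eigval, eigvec = np.linalg.eigh(L)             # ascending eigenvalues
k = 2
PE = eigvec[:, 1 : k + 1]                      # drop the trivial 0-eigenvalue vector
print(eigval[: k + 1])                         # approximately [0., 1., 1.]
print(PE.shape)                                # (4, 2)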
def to_half(g): def to_half(g):
r"""Cast this graph to use float16 (half-precision) for any r"""Cast this graph to use float16 (half-precision) for any
floating-point edge and node feature data. floating-point edge and node feature data.
...@@ -3681,6 +3934,7 @@ def to_half(g): ...@@ -3681,6 +3934,7 @@ def to_half(g):
ret._node_frames = [frame.half() for frame in ret._node_frames] ret._node_frames = [frame.half() for frame in ret._node_frames]
return ret return ret
def to_float(g): def to_float(g):
r"""Cast this graph to use float32 (single-precision) for any r"""Cast this graph to use float32 (single-precision) for any
floating-point edge and node feature data. floating-point edge and node feature data.
...@@ -3698,6 +3952,7 @@ def to_float(g): ...@@ -3698,6 +3952,7 @@ def to_float(g):
ret._node_frames = [frame.float() for frame in ret._node_frames] ret._node_frames = [frame.float() for frame in ret._node_frames]
return ret return ret
def to_double(g): def to_double(g):
r"""Cast this graph to use float64 (double-precision) for any r"""Cast this graph to use float64 (double-precision) for any
floating-point edge and node feature data. floating-point edge and node feature data.
...@@ -3715,6 +3970,7 @@ def to_double(g): ...@@ -3715,6 +3970,7 @@ def to_double(g):
ret._node_frames = [frame.double() for frame in ret._node_frames] ret._node_frames = [frame.double() for frame in ret._node_frames]
return ret return ret
def double_radius_node_labeling(g, src, dst): def double_radius_node_labeling(g, src, dst):
r"""Double Radius Node Labeling, as introduced in `Link Prediction r"""Double Radius Node Labeling, as introduced in `Link Prediction
Based on Graph Neural Networks <https://arxiv.org/abs/1802.09691>`__. Based on Graph Neural Networks <https://arxiv.org/abs/1802.09691>`__.
...@@ -3754,7 +4010,7 @@ def double_radius_node_labeling(g, src, dst): ...@@ -3754,7 +4010,7 @@ def double_radius_node_labeling(g, src, dst):
>>> dgl.double_radius_node_labeling(g, 0, 1) >>> dgl.double_radius_node_labeling(g, 0, 1)
tensor([1, 1, 3, 2, 3, 7, 0]) tensor([1, 1, 3, 2, 3, 7, 0])
""" """
adj = g.adj(scipy_fmt='csr') adj = g.adj(scipy_fmt="csr")
src, dst = (dst, src) if src > dst else (src, dst) src, dst = (dst, src) if src > dst else (src, dst)
idx = list(range(src)) + list(range(src + 1, adj.shape[0])) idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
...@@ -3764,22 +4020,27 @@ def double_radius_node_labeling(g, src, dst): ...@@ -3764,22 +4020,27 @@ def double_radius_node_labeling(g, src, dst):
adj_wo_dst = adj[idx, :][:, idx] adj_wo_dst = adj[idx, :][:, idx]
# distance to the source node # distance to the source node
ds = sparse.csgraph.shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src) ds = sparse.csgraph.shortest_path(
adj_wo_dst, directed=False, unweighted=True, indices=src
)
ds = np.insert(ds, dst, 0, axis=0) ds = np.insert(ds, dst, 0, axis=0)
# distance to the destination node # distance to the destination node
dt = sparse.csgraph.shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst-1) dt = sparse.csgraph.shortest_path(
adj_wo_src, directed=False, unweighted=True, indices=dst - 1
)
dt = np.insert(dt, src, 0, axis=0) dt = np.insert(dt, src, 0, axis=0)
d = ds + dt d = ds + dt
# suppress invalid value (nan) warnings # suppress invalid value (nan) warnings
with np.errstate(invalid='ignore'): with np.errstate(invalid="ignore"):
z = 1 + np.stack([ds, dt]).min(axis=0) + d//2 * (d//2 + d%2 - 1) z = 1 + np.stack([ds, dt]).min(axis=0) + d // 2 * (d // 2 + d % 2 - 1)
z[src] = 1 z[src] = 1
z[dst] = 1 z[dst] = 1
z[np.isnan(z)] = 0 # unreachable nodes z[np.isnan(z)] = 0 # unreachable nodes
return F.tensor(z, F.int64) return F.tensor(z, F.int64)
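Away from the two endpoints, the label depends only on the distances ds and dt to them: z = 1 + min(ds, dt) + (d // 2) * (d // 2 + d % 2 - 1) with d = ds + dt. A tiny worked check of that formula (the values 2, 3 and 7 are the ones appearing in the docstring output above):

def drnl_label(ds, dt):
    # Double Radius Node Labeling formula from the SEAL paper, as used above.
    d = ds + dt
    return 1 + min(ds, dt) + (d // 2) * (d // 2 + d % 2 - 1)

print(drnl_label(1, 1))  # 2: a common neighbor of src and dst
print(drnl_label(1, 2))  # 3
print(drnl_label(2, 3))  # 7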
def shortest_dist(g, root=None, return_paths=False): def shortest_dist(g, root=None, return_paths=False):
r"""Compute shortest distance and paths on the given graph. r"""Compute shortest distance and paths on the given graph.
...@@ -3859,13 +4120,18 @@ def shortest_dist(g, root=None, return_paths=False): ...@@ -3859,13 +4120,18 @@ def shortest_dist(g, root=None, return_paths=False):
""" """
if root is None: if root is None:
dist, pred = sparse.csgraph.shortest_path( dist, pred = sparse.csgraph.shortest_path(
g.adj(scipy_fmt='csr'), return_predecessors=True, unweighted=True, g.adj(scipy_fmt="csr"),
directed=True return_predecessors=True,
unweighted=True,
directed=True,
) )
else: else:
dist, pred = sparse.csgraph.dijkstra( dist, pred = sparse.csgraph.dijkstra(
g.adj(scipy_fmt='csr'), directed=True, indices=root, g.adj(scipy_fmt="csr"),
return_predecessors=True, unweighted=True, directed=True,
indices=root,
return_predecessors=True,
unweighted=True,
) )
dist[np.isinf(dist)] = -1 dist[np.isinf(dist)] = -1
...@@ -3901,7 +4167,7 @@ def shortest_dist(g, root=None, return_paths=False): ...@@ -3901,7 +4167,7 @@ def shortest_dist(g, root=None, return_paths=False):
u.extend(nodes[:-1]) u.extend(nodes[:-1])
v.extend(nodes[1:]) v.extend(nodes[1:])
if nodes: if nodes:
masks_i[j, :len(nodes) - 1] = True masks_i[j, : len(nodes) - 1] = True
masks.append(masks_i) masks.append(masks_i)
masks = np.stack(masks, axis=0) masks = np.stack(masks, axis=0)
...@@ -3911,8 +4177,9 @@ def shortest_dist(g, root=None, return_paths=False): ...@@ -3911,8 +4177,9 @@ def shortest_dist(g, root=None, return_paths=False):
if root is not None: if root is not None:
paths = paths[0] paths = paths[0]
return F.copy_to(F.tensor(dist, dtype=F.int64), g.device), \ return F.copy_to(F.tensor(dist, dtype=F.int64), g.device), F.copy_to(
F.copy_to(F.tensor(paths, dtype=F.int64), g.device) F.tensor(paths, dtype=F.int64), g.device
)
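A short usage sketch, assuming the function is exported at the top level as dgl.shortest_dist like the other routines in this file:

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))  # path 0 -> 1 -> 2 -> 3
print(dgl.shortest_dist(g))                    # all-pairs matrix, -1 for unreachable pairs
dist, paths = dgl.shortest_dist(g, root=0, return_paths=True)
print(dist)                                    # tensor([0, 1, 2, 3])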
def svd_pe(g, k, padding=False, random_flip=True): def svd_pe(g, k, padding=False, random_flip=True):
...@@ -3962,8 +4229,7 @@ def svd_pe(g, k, padding=False, random_flip=True): ...@@ -3962,8 +4229,7 @@ def svd_pe(g, k, padding=False, random_flip=True):
if not padding and n < k: if not padding and n < k:
raise ValueError( raise ValueError(
"The number of singular values k must be no greater than the " "The number of singular values k must be no greater than the "
"number of nodes n, but " + "number of nodes n, but " + f"got {k} and {n} respectively."
f"got {k} and {n} respectively."
) )
a = g.adj(ctx=g.device, scipy_fmt="coo").toarray() a = g.adj(ctx=g.device, scipy_fmt="coo").toarray()
u, d, vh = scipy.linalg.svd(a) u, d, vh = scipy.linalg.svd(a)
......
...@@ -18,12 +18,9 @@ ...@@ -18,12 +18,9 @@
from scipy.linalg import expm from scipy.linalg import expm
from .. import convert from .. import backend as F, convert, function as fn, utils
from .. import backend as F
from .. import function as fn
from ..base import DGLError from ..base import DGLError
from . import functional from . import functional
from .. import utils
try: try:
import torch import torch
...@@ -32,32 +29,33 @@ except ImportError: ...@@ -32,32 +29,33 @@ except ImportError:
pass pass
__all__ = [ __all__ = [
'BaseTransform', "BaseTransform",
'RowFeatNormalizer', "RowFeatNormalizer",
'FeatMask', "FeatMask",
'RandomWalkPE', "RandomWalkPE",
'LaplacianPE', "LaplacianPE",
'AddSelfLoop', "AddSelfLoop",
'RemoveSelfLoop', "RemoveSelfLoop",
'AddReverse', "AddReverse",
'ToSimple', "ToSimple",
'LineGraph', "LineGraph",
'KHopGraph', "KHopGraph",
'AddMetaPaths', "AddMetaPaths",
'Compose', "Compose",
'GCNNorm', "GCNNorm",
'PPR', "PPR",
'HeatKernel', "HeatKernel",
'GDC', "GDC",
'NodeShuffle', "NodeShuffle",
'DropNode', "DropNode",
'DropEdge', "DropEdge",
'AddEdge', "AddEdge",
'SIGNDiffusion', "SIGNDiffusion",
'ToLevi', "ToLevi",
'SVDPE' "SVDPE",
] ]
def update_graph_structure(g, data_dict, copy_edata=True): def update_graph_structure(g, data_dict, copy_edata=True):
r"""Update the structure of a graph. r"""Update the structure of a graph.
...@@ -82,8 +80,9 @@ def update_graph_structure(g, data_dict, copy_edata=True): ...@@ -82,8 +80,9 @@ def update_graph_structure(g, data_dict, copy_edata=True):
for ntype in g.ntypes: for ntype in g.ntypes:
num_nodes_dict[ntype] = g.num_nodes(ntype) num_nodes_dict[ntype] = g.num_nodes(ntype)
new_g = convert.heterograph(data_dict, num_nodes_dict=num_nodes_dict, new_g = convert.heterograph(
idtype=idtype, device=device) data_dict, num_nodes_dict=num_nodes_dict, idtype=idtype, device=device
)
# Copy features # Copy features
for ntype in g.ntypes: for ntype in g.ntypes:
...@@ -97,13 +96,16 @@ def update_graph_structure(g, data_dict, copy_edata=True): ...@@ -97,13 +96,16 @@ def update_graph_structure(g, data_dict, copy_edata=True):
return new_g return new_g
class BaseTransform: class BaseTransform:
r"""An abstract class for writing transforms.""" r"""An abstract class for writing transforms."""
def __call__(self, g): def __call__(self, g):
raise NotImplementedError raise NotImplementedError
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '()' return self.__class__.__name__ + "()"
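Any graph-to-graph callable fits this interface; a user-defined transform only needs to implement __call__ (and optionally __repr__). A minimal hypothetical example, assuming the class is available as dgl.transforms.BaseTransform:

import torch
import dgl
from dgl.transforms import BaseTransform

class AddInDegree(BaseTransform):
    """Hypothetical transform: store each node's in-degree as a feature."""

    def __init__(self, feat_name="deg"):
        self.feat_name = feat_name

    def __call__(self, g):
        # Works on a homogeneous graph; heterographs would need per-type handling.
        g.ndata[self.feat_name] = g.in_degrees().float()
        return g

g = dgl.graph((torch.tensor([0, 0, 1]), torch.tensor([1, 2, 2])))
print(AddInDegree()(g).ndata["deg"])  # tensor([0., 1., 2.])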
class RowFeatNormalizer(BaseTransform): class RowFeatNormalizer(BaseTransform):
r""" r"""
...@@ -172,9 +174,16 @@ class RowFeatNormalizer(BaseTransform): ...@@ -172,9 +174,16 @@ class RowFeatNormalizer(BaseTransform):
... g.edata['w'][('player', 'plays', 'game')].sum(1)) ... g.edata['w'][('player', 'plays', 'game')].sum(1))
tensor([1., 1.]) tensor([1., 1.]) tensor([1., 1.]) tensor([1., 1.])
""" """
def __init__(self, subtract_min=False, node_feat_names=None, edge_feat_names=None):
self.node_feat_names = [] if node_feat_names is None else node_feat_names def __init__(
self.edge_feat_names = [] if edge_feat_names is None else edge_feat_names self, subtract_min=False, node_feat_names=None, edge_feat_names=None
):
self.node_feat_names = (
[] if node_feat_names is None else node_feat_names
)
self.edge_feat_names = (
[] if edge_feat_names is None else edge_feat_names
)
self.subtract_min = subtract_min self.subtract_min = subtract_min
def row_normalize(self, feat): def row_normalize(self, feat):
...@@ -196,28 +205,35 @@ class RowFeatNormalizer(BaseTransform): ...@@ -196,28 +205,35 @@ class RowFeatNormalizer(BaseTransform):
""" """
if self.subtract_min: if self.subtract_min:
feat = feat - feat.min() feat = feat - feat.min()
feat.div_(feat.sum(dim=-1, keepdim=True).clamp_(min=1.)) feat.div_(feat.sum(dim=-1, keepdim=True).clamp_(min=1.0))
return feat return feat
def __call__(self, g): def __call__(self, g):
for node_feat_name in self.node_feat_names: for node_feat_name in self.node_feat_names:
if isinstance(g.ndata[node_feat_name], torch.Tensor): if isinstance(g.ndata[node_feat_name], torch.Tensor):
g.ndata[node_feat_name] = self.row_normalize(g.ndata[node_feat_name]) g.ndata[node_feat_name] = self.row_normalize(
g.ndata[node_feat_name]
)
else: else:
for ntype in g.ndata[node_feat_name].keys(): for ntype in g.ndata[node_feat_name].keys():
g.nodes[ntype].data[node_feat_name] = \ g.nodes[ntype].data[node_feat_name] = self.row_normalize(
self.row_normalize(g.nodes[ntype].data[node_feat_name]) g.nodes[ntype].data[node_feat_name]
)
for edge_feat_name in self.edge_feat_names: for edge_feat_name in self.edge_feat_names:
if isinstance(g.edata[edge_feat_name], torch.Tensor): if isinstance(g.edata[edge_feat_name], torch.Tensor):
g.edata[edge_feat_name] = self.row_normalize(g.edata[edge_feat_name]) g.edata[edge_feat_name] = self.row_normalize(
g.edata[edge_feat_name]
)
else: else:
for etype in g.edata[edge_feat_name].keys(): for etype in g.edata[edge_feat_name].keys():
g.edges[etype].data[edge_feat_name] = \ g.edges[etype].data[edge_feat_name] = self.row_normalize(
self.row_normalize(g.edges[etype].data[edge_feat_name]) g.edges[etype].data[edge_feat_name]
)
return g return g
class FeatMask(BaseTransform): class FeatMask(BaseTransform):
r"""Randomly mask columns of the node and edge feature tensors, as described in `Graph r"""Randomly mask columns of the node and edge feature tensors, as described in `Graph
Contrastive Learning with Augmentations <https://arxiv.org/abs/2010.13902>`__. Contrastive Learning with Augmentations <https://arxiv.org/abs/2010.13902>`__.
...@@ -290,10 +306,15 @@ class FeatMask(BaseTransform): ...@@ -290,10 +306,15 @@ class FeatMask(BaseTransform):
tensor([[0., 1., 0., 1., 0.], tensor([[0., 1., 0., 1., 0.],
[0., 1., 0., 1., 0.]]) [0., 1., 0., 1., 0.]])
""" """
def __init__(self, p=0.5, node_feat_names=None, edge_feat_names=None): def __init__(self, p=0.5, node_feat_names=None, edge_feat_names=None):
self.p = p self.p = p
self.node_feat_names = [] if node_feat_names is None else node_feat_names self.node_feat_names = (
self.edge_feat_names = [] if edge_feat_names is None else edge_feat_names [] if node_feat_names is None else node_feat_names
)
self.edge_feat_names = (
[] if edge_feat_names is None else edge_feat_names
)
self.dist = Bernoulli(p) self.dist = Bernoulli(p)
def __call__(self, g): def __call__(self, g):
...@@ -303,27 +324,56 @@ class FeatMask(BaseTransform): ...@@ -303,27 +324,56 @@ class FeatMask(BaseTransform):
for node_feat_name in self.node_feat_names: for node_feat_name in self.node_feat_names:
if isinstance(g.ndata[node_feat_name], torch.Tensor): if isinstance(g.ndata[node_feat_name], torch.Tensor):
feat_mask = self.dist.sample(torch.Size([g.ndata[node_feat_name].shape[-1], ])) feat_mask = self.dist.sample(
torch.Size(
[
g.ndata[node_feat_name].shape[-1],
]
)
)
g.ndata[node_feat_name][:, feat_mask.bool().to(g.device)] = 0 g.ndata[node_feat_name][:, feat_mask.bool().to(g.device)] = 0
else: else:
for ntype in g.ndata[node_feat_name].keys(): for ntype in g.ndata[node_feat_name].keys():
mask_shape = g.ndata[node_feat_name][ntype].shape[-1] mask_shape = g.ndata[node_feat_name][ntype].shape[-1]
feat_mask = self.dist.sample(torch.Size([mask_shape, ])) feat_mask = self.dist.sample(
g.ndata[node_feat_name][ntype][:, feat_mask.bool().to(g.device)] = 0 torch.Size(
[
mask_shape,
]
)
)
g.ndata[node_feat_name][ntype][
:, feat_mask.bool().to(g.device)
] = 0
for edge_feat_name in self.edge_feat_names: for edge_feat_name in self.edge_feat_names:
if isinstance(g.edata[edge_feat_name], torch.Tensor): if isinstance(g.edata[edge_feat_name], torch.Tensor):
feat_mask = self.dist.sample(torch.Size([g.edata[edge_feat_name].shape[-1], ])) feat_mask = self.dist.sample(
torch.Size(
[
g.edata[edge_feat_name].shape[-1],
]
)
)
g.edata[edge_feat_name][:, feat_mask.bool().to(g.device)] = 0 g.edata[edge_feat_name][:, feat_mask.bool().to(g.device)] = 0
else: else:
for etype in g.edata[edge_feat_name].keys(): for etype in g.edata[edge_feat_name].keys():
mask_shape = g.edata[edge_feat_name][etype].shape[-1] mask_shape = g.edata[edge_feat_name][etype].shape[-1]
feat_mask = self.dist.sample(torch.Size([mask_shape, ])) feat_mask = self.dist.sample(
g.edata[edge_feat_name][etype][:, feat_mask.bool().to(g.device)] = 0 torch.Size(
[
mask_shape,
]
)
)
g.edata[edge_feat_name][etype][
:, feat_mask.bool().to(g.device)
] = 0
return g return g
class RandomWalkPE(BaseTransform): class RandomWalkPE(BaseTransform):
r"""Random Walk Positional Encoding, as introduced in r"""Random Walk Positional Encoding, as introduced in
`Graph Neural Networks with Learnable Structural and Positional Representations `Graph Neural Networks with Learnable Structural and Positional Representations
...@@ -354,17 +404,21 @@ class RandomWalkPE(BaseTransform): ...@@ -354,17 +404,21 @@ class RandomWalkPE(BaseTransform):
tensor([[0.0000, 0.5000], tensor([[0.0000, 0.5000],
[0.5000, 0.7500]]) [0.5000, 0.7500]])
""" """
def __init__(self, k, feat_name='PE', eweight_name=None):
def __init__(self, k, feat_name="PE", eweight_name=None):
self.k = k self.k = k
self.feat_name = feat_name self.feat_name = feat_name
self.eweight_name = eweight_name self.eweight_name = eweight_name
def __call__(self, g): def __call__(self, g):
PE = functional.random_walk_pe(g, k=self.k, eweight_name=self.eweight_name) PE = functional.random_walk_pe(
g, k=self.k, eweight_name=self.eweight_name
)
g.ndata[self.feat_name] = F.copy_to(PE, g.device) g.ndata[self.feat_name] = F.copy_to(PE, g.device)
return g return g
class LaplacianPE(BaseTransform): class LaplacianPE(BaseTransform):
r"""Laplacian Positional Encoding, as introduced in r"""Laplacian Positional Encoding, as introduced in
`Benchmarking Graph Neural Networks `Benchmarking Graph Neural Networks
...@@ -425,7 +479,8 @@ class LaplacianPE(BaseTransform): ...@@ -425,7 +479,8 @@ class LaplacianPE(BaseTransform):
[-0.5117, 0.4508, -0.3938, 0.6295, 0.0000], [-0.5117, 0.4508, -0.3938, 0.6295, 0.0000],
[ 0.1954, 0.5612, 0.0278, -0.5454, 0.0000]]) [ 0.1954, 0.5612, 0.0278, -0.5454, 0.0000]])
""" """
def __init__(self, k, feat_name='PE', eigval_name=None, padding=False):
def __init__(self, k, feat_name="PE", eigval_name=None, padding=False):
self.k = k self.k = k
self.feat_name = feat_name self.feat_name = feat_name
self.eigval_name = eigval_name self.eigval_name = eigval_name
...@@ -433,9 +488,10 @@ class LaplacianPE(BaseTransform): ...@@ -433,9 +488,10 @@ class LaplacianPE(BaseTransform):
def __call__(self, g): def __call__(self, g):
if self.eigval_name: if self.eigval_name:
PE, eigval = functional.laplacian_pe(g, k=self.k, padding=self.padding, PE, eigval = functional.laplacian_pe(
return_eigval=True) g, k=self.k, padding=self.padding, return_eigval=True
eigval = F.repeat(F.reshape(eigval, [1,-1]), g.num_nodes(), dim=0) )
eigval = F.repeat(F.reshape(eigval, [1, -1]), g.num_nodes(), dim=0)
g.ndata[self.eigval_name] = F.copy_to(eigval, g.device) g.ndata[self.eigval_name] = F.copy_to(eigval, g.device)
else: else:
PE = functional.laplacian_pe(g, k=self.k, padding=self.padding) PE = functional.laplacian_pe(g, k=self.k, padding=self.padding)
...@@ -443,83 +499,90 @@ class LaplacianPE(BaseTransform): ...@@ -443,83 +499,90 @@ class LaplacianPE(BaseTransform):
return g return g
class AddSelfLoop(BaseTransform): class AddSelfLoop(BaseTransform):
r"""Add self-loops for each node in the graph and return a new graph. r"""Add self-loops for each node in the graph and return a new graph.
For heterogeneous graphs, self-loops are added only for edge types with same For heterogeneous graphs, self-loops are added only for edge types with same
source and destination node types. source and destination node types.
Parameters Parameters
---------- ----------
allow_duplicate : bool, optional allow_duplicate : bool, optional
If False, it will first remove self-loops to prevent duplicate self-loops. If False, it will first remove self-loops to prevent duplicate self-loops.
new_etypes : bool, optional new_etypes : bool, optional
If True, it will add an edge type 'self' per node type, which holds self-loops. If True, it will add an edge type 'self' per node type, which holds self-loops.
edge_feat_names : list[str], optional edge_feat_names : list[str], optional
The names of the self-loop features to apply `fill_data`. If None, it will apply `fill_data` The names of the self-loop features to apply `fill_data`. If None, it
to all self-loop features. Default: None. will apply `fill_data` to all self-loop features. Default: None.
fill_data : int, float or str, optional fill_data : int, float or str, optional
The value to fill the self-loop features. Default: 1. The value to fill the self-loop features. Default: 1.
* If ``fill_data`` is ``int`` or ``float``, self-loop features will be directly given by * If ``fill_data`` is ``int`` or ``float``, self-loop features will be directly given by
``fill_data``. ``fill_data``.
* if ``fill_data`` is ``str``, self-loop features will be generated by aggregating the * if ``fill_data`` is ``str``, self-loop features will be generated by aggregating the
features of the incoming edges of the corresponding nodes. The supported aggregation are: features of the incoming edges of the corresponding nodes. The supported aggregation are:
``'mean'``, ``'sum'``, ``'max'``, ``'min'``. ``'mean'``, ``'sum'``, ``'max'``, ``'min'``.
Example Example
------- -------
>>> import dgl >>> import dgl
>>> from dgl import AddSelfLoop >>> from dgl import AddSelfLoop
Case1: Add self-loops for a homogeneous graph Case1: Add self-loops for a homogeneous graph
>>> transform = AddSelfLoop(fill_data='sum') >>> transform = AddSelfLoop(fill_data='sum')
>>> g = dgl.graph(([0, 0, 2], [2, 1, 0])) >>> g = dgl.graph(([0, 0, 2], [2, 1, 0]))
>>> g.edata['he'] = torch.arange(3).float().reshape(-1, 1) >>> g.edata['he'] = torch.arange(3).float().reshape(-1, 1)
>>> new_g = transform(g) >>> new_g = transform(g)
>>> print(new_g.edges()) >>> print(new_g.edges())
(tensor([1, 0, 1, 2]), tensor([2, 0, 1, 2])) (tensor([1, 0, 1, 2]), tensor([2, 0, 1, 2]))
    >>> print(new_g.edata['he']) >>> print(new_g.edata['he'])
tensor([[0.], tensor([[0.],
[1.], [1.],
[2.], [2.],
[2.], [2.],
[1.], [1.],
[0.]]) [0.]])
Case2: Add self-loops for a heterogeneous graph Case2: Add self-loops for a heterogeneous graph
>>> transform = AddSelfLoop(fill_data='sum') >>> transform = AddSelfLoop(fill_data='sum')
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
... ('user', 'follows', 'user'): (torch.tensor([1, 2]), ... ('user', 'follows', 'user'): (torch.tensor([1, 2]),
... torch.tensor([0, 1])), ... torch.tensor([0, 1])),
... ('user', 'plays', 'game'): (torch.tensor([0, 1]), ... ('user', 'plays', 'game'): (torch.tensor([0, 1]),
... torch.tensor([0, 1]))}) ... torch.tensor([0, 1]))})
>>> g.edata['feat'] = {('user', 'follows', 'user'): torch.randn(2, 5), >>> g.edata['feat'] = {('user', 'follows', 'user'): torch.randn(2, 5),
... ('user', 'plays', 'game'): torch.randn(2, 5)} ... ('user', 'plays', 'game'): torch.randn(2, 5)}
>>> g.edata['feat1'] = {('user', 'follows', 'user'): torch.randn(2, 15), >>> g.edata['feat1'] = {('user', 'follows', 'user'): torch.randn(2, 15),
... ('user', 'plays', 'game'): torch.randn(2, 15)} ... ('user', 'plays', 'game'): torch.randn(2, 15)}
>>> new_g = transform(g) >>> new_g = transform(g)
>>> print(new_g.edges(etype='plays')) >>> print(new_g.edges(etype='plays'))
(tensor([0, 1]), tensor([0, 1])) (tensor([0, 1]), tensor([0, 1]))
>>> print(new_g.edges(etype='follows')) >>> print(new_g.edges(etype='follows'))
(tensor([1, 2]), tensor([0, 1])) (tensor([1, 2]), tensor([0, 1]))
>>> print(new_g.edata['feat'][('user', 'follows', 'user')].shape) >>> print(new_g.edata['feat'][('user', 'follows', 'user')].shape)
torch.Size([5, 5]) torch.Size([5, 5])
Case3: Add self-etypes for a heterogeneous graph Case3: Add self-etypes for a heterogeneous graph
>>> transform = AddSelfLoop(new_etypes=True) >>> transform = AddSelfLoop(new_etypes=True)
>>> new_g = transform(g) >>> new_g = transform(g)
>>> print(new_g.edges(etype='follows')) >>> print(new_g.edges(etype='follows'))
(tensor([1, 0, 1, 2]), tensor([2, 0, 1, 2])) (tensor([1, 0, 1, 2]), tensor([2, 0, 1, 2]))
>>> print(new_g.edges(etype=('game', 'self', 'game'))) >>> print(new_g.edges(etype=('game', 'self', 'game')))
(tensor([0, 1]), tensor([0, 1])) (tensor([0, 1]), tensor([0, 1]))
""" """
def __init__(self, allow_duplicate=False, new_etypes=False, edge_feat_names=None, fill_data=1.): def __init__(
self,
allow_duplicate=False,
new_etypes=False,
edge_feat_names=None,
fill_data=1.0,
):
self.allow_duplicate = allow_duplicate self.allow_duplicate = allow_duplicate
self.new_etypes = new_etypes self.new_etypes = new_etypes
self.edge_feat_names = edge_feat_names self.edge_feat_names = edge_feat_names
...@@ -550,8 +613,12 @@ class AddSelfLoop(BaseTransform): ...@@ -550,8 +613,12 @@ class AddSelfLoop(BaseTransform):
if not self.allow_duplicate: if not self.allow_duplicate:
g = functional.remove_self_loop(g, etype=c_etype) g = functional.remove_self_loop(g, etype=c_etype)
return functional.add_self_loop(g, edge_feat_names=self.edge_feat_names, return functional.add_self_loop(
fill_data=self.fill_data, etype=c_etype) g,
edge_feat_names=self.edge_feat_names,
fill_data=self.fill_data,
etype=c_etype,
)
def __call__(self, g): def __call__(self, g):
for c_etype in g.canonical_etypes: for c_etype in g.canonical_etypes:
...@@ -565,7 +632,7 @@ class AddSelfLoop(BaseTransform): ...@@ -565,7 +632,7 @@ class AddSelfLoop(BaseTransform):
# Add self etypes # Add self etypes
for ntype in g.ntypes: for ntype in g.ntypes:
nids = F.arange(0, g.num_nodes(ntype), idtype, device) nids = F.arange(0, g.num_nodes(ntype), idtype, device)
data_dict[(ntype, 'self', ntype)] = (nids, nids) data_dict[(ntype, "self", ntype)] = (nids, nids)
# Copy edges # Copy edges
for c_etype in g.canonical_etypes: for c_etype in g.canonical_etypes:
...@@ -575,6 +642,7 @@ class AddSelfLoop(BaseTransform): ...@@ -575,6 +642,7 @@ class AddSelfLoop(BaseTransform):
return g return g
class RemoveSelfLoop(BaseTransform): class RemoveSelfLoop(BaseTransform):
r"""Remove self-loops for each node in the graph and return a new graph. r"""Remove self-loops for each node in the graph and return a new graph.
...@@ -607,6 +675,7 @@ class RemoveSelfLoop(BaseTransform): ...@@ -607,6 +675,7 @@ class RemoveSelfLoop(BaseTransform):
>>> print(new_g.edges(etype='follows')) >>> print(new_g.edges(etype='follows'))
(tensor([1]), tensor([2])) (tensor([1]), tensor([2]))
""" """
def transform_etype(self, c_etype, g): def transform_etype(self, c_etype, g):
r"""Transform the graph corresponding to a canonical edge type. r"""Transform the graph corresponding to a canonical edge type.
...@@ -632,6 +701,7 @@ class RemoveSelfLoop(BaseTransform): ...@@ -632,6 +701,7 @@ class RemoveSelfLoop(BaseTransform):
g = self.transform_etype(c_etype, g) g = self.transform_etype(c_etype, g)
return g return g
class AddReverse(BaseTransform): class AddReverse(BaseTransform):
r"""Add a reverse edge :math:`(i,j)` for each edge :math:`(j,i)` in the input graph and r"""Add a reverse edge :math:`(i,j)` for each edge :math:`(j,i)` in the input graph and
return a new graph. return a new graph.
...@@ -692,6 +762,7 @@ class AddReverse(BaseTransform): ...@@ -692,6 +762,7 @@ class AddReverse(BaseTransform):
>>> print(new_g.edges(etype='follows')) >>> print(new_g.edges(etype='follows'))
(tensor([1, 2, 2, 2]), tensor([2, 2, 1, 2])) (tensor([1, 2, 2, 2]), tensor([2, 2, 1, 2]))
""" """
def __init__(self, copy_edata=False, sym_new_etype=False): def __init__(self, copy_edata=False, sym_new_etype=False):
self.copy_edata = copy_edata self.copy_edata = copy_edata
self.sym_new_etype = sym_new_etype self.sym_new_etype = sym_new_etype
...@@ -729,10 +800,12 @@ class AddReverse(BaseTransform): ...@@ -729,10 +800,12 @@ class AddReverse(BaseTransform):
""" """
utype, etype, vtype = c_etype utype, etype, vtype = c_etype
src, dst = g.edges(etype=c_etype) src, dst = g.edges(etype=c_etype)
data_dict.update({ data_dict.update(
c_etype: (src, dst), {
(vtype, 'rev_{}'.format(etype), utype): (dst, src) c_etype: (src, dst),
}) (vtype, "rev_{}".format(etype), utype): (dst, src),
}
)
def transform_etype(self, c_etype, g, data_dict): def transform_etype(self, c_etype, g, data_dict):
r"""Transform the graph corresponding to a canonical edge type. r"""Transform the graph corresponding to a canonical edge type.
...@@ -762,19 +835,27 @@ class AddReverse(BaseTransform): ...@@ -762,19 +835,27 @@ class AddReverse(BaseTransform):
for c_etype in g.canonical_etypes: for c_etype in g.canonical_etypes:
utype, etype, vtype = c_etype utype, etype, vtype = c_etype
if utype != vtype or self.sym_new_etype: if utype != vtype or self.sym_new_etype:
rev_c_etype = (vtype, 'rev_{}'.format(etype), utype) rev_c_etype = (vtype, "rev_{}".format(etype), utype)
for key, feat in g.edges[c_etype].data.items(): for key, feat in g.edges[c_etype].data.items():
new_g.edges[c_etype].data[key] = feat new_g.edges[c_etype].data[key] = feat
if self.copy_edata: if self.copy_edata:
new_g.edges[rev_c_etype].data[key] = feat new_g.edges[rev_c_etype].data[key] = feat
else: else:
for key, feat in g.edges[c_etype].data.items(): for key, feat in g.edges[c_etype].data.items():
new_feat = feat if self.copy_edata else F.zeros( new_feat = (
F.shape(feat), F.dtype(feat), F.context(feat)) feat
new_g.edges[c_etype].data[key] = F.cat([feat, new_feat], dim=0) if self.copy_edata
else F.zeros(
F.shape(feat), F.dtype(feat), F.context(feat)
)
)
new_g.edges[c_etype].data[key] = F.cat(
[feat, new_feat], dim=0
)
return new_g return new_g
class ToSimple(BaseTransform): class ToSimple(BaseTransform):
r"""Convert a graph to a simple graph without parallel edges and return a new graph. r"""Convert a graph to a simple graph without parallel edges and return a new graph.
...@@ -823,15 +904,19 @@ class ToSimple(BaseTransform): ...@@ -823,15 +904,19 @@ class ToSimple(BaseTransform):
>>> print(sg.edges(etype='plays')) >>> print(sg.edges(etype='plays'))
(tensor([0, 1]), tensor([1, 1])) (tensor([0, 1]), tensor([1, 1]))
""" """
def __init__(self, return_counts='count', aggregator='arbitrary'):
def __init__(self, return_counts="count", aggregator="arbitrary"):
self.return_counts = return_counts self.return_counts = return_counts
self.aggregator = aggregator self.aggregator = aggregator
def __call__(self, g): def __call__(self, g):
return functional.to_simple(g, return functional.to_simple(
return_counts=self.return_counts, g,
copy_edata=True, return_counts=self.return_counts,
aggregator=self.aggregator) copy_edata=True,
aggregator=self.aggregator,
)
class LineGraph(BaseTransform): class LineGraph(BaseTransform):
r"""Return the line graph of the input graph. r"""Return the line graph of the input graph.
...@@ -880,11 +965,15 @@ class LineGraph(BaseTransform): ...@@ -880,11 +965,15 @@ class LineGraph(BaseTransform):
>>> print(new_g.edges()) >>> print(new_g.edges())
(tensor([0]), tensor([2])) (tensor([0]), tensor([2]))
""" """
def __init__(self, backtracking=True): def __init__(self, backtracking=True):
self.backtracking = backtracking self.backtracking = backtracking
def __call__(self, g): def __call__(self, g):
return functional.line_graph(g, backtracking=self.backtracking, shared=True) return functional.line_graph(
g, backtracking=self.backtracking, shared=True
)
class KHopGraph(BaseTransform): class KHopGraph(BaseTransform):
r"""Return the graph whose edges connect the :math:`k`-hop neighbors of the original graph. r"""Return the graph whose edges connect the :math:`k`-hop neighbors of the original graph.
...@@ -908,12 +997,14 @@ class KHopGraph(BaseTransform): ...@@ -908,12 +997,14 @@ class KHopGraph(BaseTransform):
>>> print(new_g.edges()) >>> print(new_g.edges())
(tensor([0]), tensor([2])) (tensor([0]), tensor([2]))
""" """
def __init__(self, k): def __init__(self, k):
self.k = k self.k = k
def __call__(self, g): def __call__(self, g):
return functional.khop_graph(g, self.k) return functional.khop_graph(g, self.k)
class AddMetaPaths(BaseTransform): class AddMetaPaths(BaseTransform):
r"""Add new edges to an input graph based on given metapaths, as described in r"""Add new edges to an input graph based on given metapaths, as described in
`Heterogeneous Graph Attention Network <https://arxiv.org/abs/1903.07293>`__. `Heterogeneous Graph Attention Network <https://arxiv.org/abs/1903.07293>`__.
...@@ -959,6 +1050,7 @@ class AddMetaPaths(BaseTransform): ...@@ -959,6 +1050,7 @@ class AddMetaPaths(BaseTransform):
>>> print(new_g.edges(etype=('person', 'rejected', 'venue'))) >>> print(new_g.edges(etype=('person', 'rejected', 'venue')))
(tensor([0, 1]), tensor([1, 1])) (tensor([0, 1]), tensor([1, 1]))
""" """
def __init__(self, metapaths, keep_orig_edges=True): def __init__(self, metapaths, keep_orig_edges=True):
self.metapaths = metapaths self.metapaths = metapaths
self.keep_orig_edges = keep_orig_edges self.keep_orig_edges = keep_orig_edges
...@@ -981,6 +1073,7 @@ class AddMetaPaths(BaseTransform): ...@@ -981,6 +1073,7 @@ class AddMetaPaths(BaseTransform):
return new_g return new_g
class Compose(BaseTransform): class Compose(BaseTransform):
r"""Create a transform composed of multiple transforms in sequence. r"""Create a transform composed of multiple transforms in sequence.
...@@ -1002,6 +1095,7 @@ class Compose(BaseTransform): ...@@ -1002,6 +1095,7 @@ class Compose(BaseTransform):
>>> print(new_g.edges()) >>> print(new_g.edges())
(tensor([0, 1]), tensor([1, 0])) (tensor([0, 1]), tensor([1, 0]))
""" """
def __init__(self, transforms): def __init__(self, transforms):
self.transforms = transforms self.transforms = transforms
...@@ -1011,8 +1105,9 @@ class Compose(BaseTransform): ...@@ -1011,8 +1105,9 @@ class Compose(BaseTransform):
return g return g
def __repr__(self): def __repr__(self):
args = [' ' + str(transform) for transform in self.transforms] args = [" " + str(transform) for transform in self.transforms]
return self.__class__.__name__ + '([\n' + ',\n'.join(args) + '\n])' return self.__class__.__name__ + "([\n" + ",\n".join(args) + "\n])"
class GCNNorm(BaseTransform): class GCNNorm(BaseTransform):
r"""Apply symmetric adjacency normalization to an input graph and save the result edge r"""Apply symmetric adjacency normalization to an input graph and save the result edge
...@@ -1049,7 +1144,8 @@ class GCNNorm(BaseTransform): ...@@ -1049,7 +1144,8 @@ class GCNNorm(BaseTransform):
>>> print(g.edata['w']) >>> print(g.edata['w'])
tensor([0.3333, 0.6667, 0.0000]) tensor([0.3333, 0.6667, 0.0000])
""" """
def __init__(self, eweight_name='w'):
def __init__(self, eweight_name="w"):
self.eweight_name = eweight_name self.eweight_name = eweight_name
def calc_etype(self, c_etype, g): def calc_etype(self, c_etype, g):
...@@ -1062,18 +1158,30 @@ class GCNNorm(BaseTransform): ...@@ -1062,18 +1158,30 @@ class GCNNorm(BaseTransform):
ntype = c_etype[0] ntype = c_etype[0]
with g.local_scope(): with g.local_scope():
if self.eweight_name in g.edges[c_etype].data: if self.eweight_name in g.edges[c_etype].data:
g.update_all(fn.copy_e(self.eweight_name, 'm'), fn.sum('m', 'deg'), etype=c_etype) g.update_all(
deg_inv_sqrt = 1. / F.sqrt(g.nodes[ntype].data['deg']) fn.copy_e(self.eweight_name, "m"),
g.nodes[ntype].data['w'] = F.replace_inf_with_zero(deg_inv_sqrt) fn.sum("m", "deg"),
g.apply_edges(lambda edge: {'w': edge.src['w'] * edge.data[self.eweight_name] * etype=c_etype,
edge.dst['w']}, )
etype=c_etype) deg_inv_sqrt = 1.0 / F.sqrt(g.nodes[ntype].data["deg"])
g.nodes[ntype].data["w"] = F.replace_inf_with_zero(deg_inv_sqrt)
g.apply_edges(
lambda edge: {
"w": edge.src["w"]
* edge.data[self.eweight_name]
* edge.dst["w"]
},
etype=c_etype,
)
else: else:
deg = g.in_degrees(etype=c_etype) deg = g.in_degrees(etype=c_etype)
deg_inv_sqrt = 1. / F.sqrt(F.astype(deg, F.float32)) deg_inv_sqrt = 1.0 / F.sqrt(F.astype(deg, F.float32))
g.nodes[ntype].data['w'] = F.replace_inf_with_zero(deg_inv_sqrt) g.nodes[ntype].data["w"] = F.replace_inf_with_zero(deg_inv_sqrt)
g.apply_edges(lambda edges: {'w': edges.src['w'] * edges.dst['w']}, etype=c_etype) g.apply_edges(
return g.edges[c_etype].data['w'] lambda edges: {"w": edges.src["w"] * edges.dst["w"]},
etype=c_etype,
)
return g.edges[c_etype].data["w"]
def __call__(self, g): def __call__(self, g):
result = dict() result = dict()
...@@ -1086,6 +1194,7 @@ class GCNNorm(BaseTransform): ...@@ -1086,6 +1194,7 @@ class GCNNorm(BaseTransform):
g.edges[c_etype].data[self.eweight_name] = eweight g.edges[c_etype].data[self.eweight_name] = eweight
return g return g
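In the unweighted branch each edge (u, v) ends up with weight 1 / sqrt(deg(u) * deg(v)) computed from in-degrees, with the infinities from zero-degree nodes replaced by 0. A dense NumPy sketch on hypothetical edges:

import numpy as np

src = np.array([0, 0, 1, 2])
dst = np.array([1, 2, 2, 1])
deg = np.bincount(dst, minlength=3).astype(float)   # in-degrees: [0. 2. 2.]

d_inv_sqrt = np.zeros_like(deg)
d_inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5          # replace_inf_with_zero equivalent
print(d_inv_sqrt[src] * d_inv_sqrt[dst])            # [0.  0.  0.5 0.5]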
class PPR(BaseTransform): class PPR(BaseTransform):
r"""Apply personalized PageRank (PPR) to an input graph for diffusion, as introduced in r"""Apply personalized PageRank (PPR) to an input graph for diffusion, as introduced in
`The pagerank citation ranking: Bringing order to the web `The pagerank citation ranking: Bringing order to the web
...@@ -1128,19 +1237,19 @@ class PPR(BaseTransform): ...@@ -1128,19 +1237,19 @@ class PPR(BaseTransform):
tensor([0.1500, 0.1500, 0.1500, 0.0255, 0.0163, 0.1500, 0.0638, 0.0383, 0.1500, tensor([0.1500, 0.1500, 0.1500, 0.0255, 0.0163, 0.1500, 0.0638, 0.0383, 0.1500,
0.0510, 0.0217, 0.1500]) 0.0510, 0.0217, 0.1500])
""" """
def __init__(self, alpha=0.15, eweight_name='w', eps=None, avg_degree=5):
def __init__(self, alpha=0.15, eweight_name="w", eps=None, avg_degree=5):
self.alpha = alpha self.alpha = alpha
self.eweight_name = eweight_name self.eweight_name = eweight_name
self.eps = eps self.eps = eps
self.avg_degree = avg_degree self.avg_degree = avg_degree
def get_eps(self, num_nodes, mat): def get_eps(self, num_nodes, mat):
r"""Get the threshold for graph sparsification. r"""Get the threshold for graph sparsification."""
"""
if self.eps is None: if self.eps is None:
# Infer from self.avg_degree # Infer from self.avg_degree
if self.avg_degree > num_nodes: if self.avg_degree > num_nodes:
return float('-inf') return float("-inf")
sorted_weights = torch.sort(mat.flatten(), descending=True).values sorted_weights = torch.sort(mat.flatten(), descending=True).values
return sorted_weights[self.avg_degree * num_nodes - 1] return sorted_weights[self.avg_degree * num_nodes - 1]
else: else:
...@@ -1150,8 +1259,9 @@ class PPR(BaseTransform): ...@@ -1150,8 +1259,9 @@ class PPR(BaseTransform):
# Step1: PPR diffusion # Step1: PPR diffusion
# (α - 1) A # (α - 1) A
device = g.device device = g.device
eweight = (self.alpha - 1) * g.edata.get(self.eweight_name, F.ones( eweight = (self.alpha - 1) * g.edata.get(
(g.num_edges(),), F.float32, device)) self.eweight_name, F.ones((g.num_edges(),), F.float32, device)
)
num_nodes = g.num_nodes() num_nodes = g.num_nodes()
mat = F.zeros((num_nodes, num_nodes), F.float32, device) mat = F.zeros((num_nodes, num_nodes), F.float32, device)
src, dst = g.edges() src, dst = g.edges()
...@@ -1173,6 +1283,7 @@ class PPR(BaseTransform): ...@@ -1173,6 +1283,7 @@ class PPR(BaseTransform):
return new_g return new_g
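The (alpha - 1) * A construction above is the first step of the closed-form PPR diffusion S = alpha * (I - (1 - alpha) * A)^-1 from the GDC paper. A dense NumPy sketch of that closed form, under the assumption that A is the (weighted) adjacency the transform operates on and before the get_eps sparsification:

import numpy as np

alpha = 0.15
A = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [1.0, 0.0, 0.0]])   # toy adjacency / transition weights

S = alpha * np.linalg.inv(np.eye(3) - (1 - alpha) * A)
print(S.round(4))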
def is_bidirected(g): def is_bidirected(g):
"""Return whether the graph is a bidirected graph. """Return whether the graph is a bidirected graph.
...@@ -1194,6 +1305,7 @@ def is_bidirected(g): ...@@ -1194,6 +1305,7 @@ def is_bidirected(g):
return F.allclose(src1, dst2) and F.allclose(src2, dst1) return F.allclose(src1, dst2) and F.allclose(src2, dst1)
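Usage sketch (toy graphs, illustrative only): a 2-cycle is bidirected, a lone directed edge is not.

import dgl

g1 = dgl.graph(([0, 1], [1, 0]))
g2 = dgl.graph(([0], [1]))
# is_bidirected(g1) is expected to be True, is_bidirected(g2) False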
# pylint: disable=C0103 # pylint: disable=C0103
class HeatKernel(BaseTransform): class HeatKernel(BaseTransform):
r"""Apply heat kernel to an input graph for diffusion, as introduced in r"""Apply heat kernel to an input graph for diffusion, as introduced in
...@@ -1237,19 +1349,19 @@ class HeatKernel(BaseTransform): ...@@ -1237,19 +1349,19 @@ class HeatKernel(BaseTransform):
tensor([0.1353, 0.1353, 0.1353, 0.0541, 0.0406, 0.1353, 0.1353, 0.0812, 0.1353, tensor([0.1353, 0.1353, 0.1353, 0.0541, 0.0406, 0.1353, 0.1353, 0.0812, 0.1353,
0.1083, 0.0541, 0.1353]) 0.1083, 0.0541, 0.1353])
""" """
def __init__(self, t=2., eweight_name='w', eps=None, avg_degree=5):
def __init__(self, t=2.0, eweight_name="w", eps=None, avg_degree=5):
self.t = t self.t = t
self.eweight_name = eweight_name self.eweight_name = eweight_name
self.eps = eps self.eps = eps
self.avg_degree = avg_degree self.avg_degree = avg_degree
def get_eps(self, num_nodes, mat): def get_eps(self, num_nodes, mat):
r"""Get the threshold for graph sparsification. r"""Get the threshold for graph sparsification."""
"""
if self.eps is None: if self.eps is None:
# Infer from self.avg_degree # Infer from self.avg_degree
if self.avg_degree > num_nodes: if self.avg_degree > num_nodes:
return float('-inf') return float("-inf")
sorted_weights = torch.sort(mat.flatten(), descending=True).values sorted_weights = torch.sort(mat.flatten(), descending=True).values
return sorted_weights[self.avg_degree * num_nodes - 1] return sorted_weights[self.avg_degree * num_nodes - 1]
else: else:
...@@ -1259,8 +1371,9 @@ class HeatKernel(BaseTransform): ...@@ -1259,8 +1371,9 @@ class HeatKernel(BaseTransform):
# Step1: heat kernel diffusion # Step1: heat kernel diffusion
# t A # t A
device = g.device device = g.device
eweight = self.t * g.edata.get(self.eweight_name, F.ones( eweight = self.t * g.edata.get(
(g.num_edges(),), F.float32, device)) self.eweight_name, F.ones((g.num_edges(),), F.float32, device)
)
num_nodes = g.num_nodes() num_nodes = g.num_nodes()
mat = F.zeros((num_nodes, num_nodes), F.float32, device) mat = F.zeros((num_nodes, num_nodes), F.float32, device)
src, dst = g.edges() src, dst = g.edges()
...@@ -1271,7 +1384,7 @@ class HeatKernel(BaseTransform): ...@@ -1271,7 +1384,7 @@ class HeatKernel(BaseTransform):
mat[nids, nids] = mat[nids, nids] - self.t mat[nids, nids] = mat[nids, nids] - self.t
if is_bidirected(g): if is_bidirected(g):
e, V = torch.linalg.eigh(mat, UPLO='U') e, V = torch.linalg.eigh(mat, UPLO="U")
diff_mat = V @ torch.diag(e.exp()) @ V.t() diff_mat = V @ torch.diag(e.exp()) @ V.t()
else: else:
diff_mat_np = expm(mat.cpu().numpy()) diff_mat_np = expm(mat.cpu().numpy())
...@@ -1287,6 +1400,7 @@ class HeatKernel(BaseTransform): ...@@ -1287,6 +1400,7 @@ class HeatKernel(BaseTransform):
return new_g return new_g
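A hedged toy sketch of the heat-kernel diffusion S = exp(t * (A - I)), using the same symmetric eigendecomposition shortcut the transform applies to bidirected graphs (the matrix below is illustrative):

import torch

t = 2.0
A = torch.tensor([[0., 1., 1.],
                  [1., 0., 0.],
                  [1., 0., 0.]])          # symmetric toy adjacency
M = t * (A - torch.eye(3))
e, V = torch.linalg.eigh(M)               # valid because M is symmetric
S = V @ torch.diag(e.exp()) @ V.T         # matrix exponential of M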
class GDC(BaseTransform): class GDC(BaseTransform):
r"""Apply graph diffusion convolution (GDC) to an input graph, as introduced in r"""Apply graph diffusion convolution (GDC) to an input graph, as introduced in
`Diffusion Improves Graph Learning <https://www.in.tum.de/daml/gdc/>`__. `Diffusion Improves Graph Learning <https://www.in.tum.de/daml/gdc/>`__.
...@@ -1328,7 +1442,8 @@ class GDC(BaseTransform): ...@@ -1328,7 +1442,8 @@ class GDC(BaseTransform):
tensor([0.3000, 0.3000, 0.0200, 0.3000, 0.0400, 0.3000, 0.1000, 0.0600, 0.3000, tensor([0.3000, 0.3000, 0.0200, 0.3000, 0.0400, 0.3000, 0.1000, 0.0600, 0.3000,
0.0800, 0.0200, 0.3000]) 0.0800, 0.0200, 0.3000])
""" """
def __init__(self, coefs, eweight_name='w', eps=None, avg_degree=5):
def __init__(self, coefs, eweight_name="w", eps=None, avg_degree=5):
self.coefs = coefs self.coefs = coefs
self.eweight_name = eweight_name self.eweight_name = eweight_name
self.eps = eps self.eps = eps
...@@ -1339,7 +1454,7 @@ class GDC(BaseTransform): ...@@ -1339,7 +1454,7 @@ class GDC(BaseTransform):
if self.eps is None: if self.eps is None:
# Infer from self.avg_degree # Infer from self.avg_degree
if self.avg_degree > num_nodes: if self.avg_degree > num_nodes:
return float('-inf') return float("-inf")
sorted_weights = torch.sort(mat.flatten(), descending=True).values sorted_weights = torch.sort(mat.flatten(), descending=True).values
return sorted_weights[self.avg_degree * num_nodes - 1] return sorted_weights[self.avg_degree * num_nodes - 1]
else: else:
...@@ -1349,8 +1464,9 @@ class GDC(BaseTransform): ...@@ -1349,8 +1464,9 @@ class GDC(BaseTransform):
# Step1: diffusion # Step1: diffusion
# A # A
device = g.device device = g.device
eweight = g.edata.get(self.eweight_name, F.ones( eweight = g.edata.get(
(g.num_edges(),), F.float32, device)) self.eweight_name, F.ones((g.num_edges(),), F.float32, device)
)
num_nodes = g.num_nodes() num_nodes = g.num_nodes()
adj = F.zeros((num_nodes, num_nodes), F.float32, device) adj = F.zeros((num_nodes, num_nodes), F.float32, device)
src, dst = g.edges() src, dst = g.edges()
...@@ -1375,6 +1491,7 @@ class GDC(BaseTransform): ...@@ -1375,6 +1491,7 @@ class GDC(BaseTransform):
return new_g return new_g
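A hedged dense sketch of the generalized diffusion that GDC accumulates, S = sum_k coefs[k] * A^k (assuming the series starts at k = 0; the toy adjacency is illustrative):

import torch

coefs = [0.3, 0.2, 0.1]
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
S = torch.zeros_like(A)
P = torch.eye(A.shape[0])                 # A^0
for c in coefs:
    S = S + c * P
    P = P @ A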
class NodeShuffle(BaseTransform): class NodeShuffle(BaseTransform):
r"""Randomly shuffle the nodes. r"""Randomly shuffle the nodes.
...@@ -1399,6 +1516,7 @@ class NodeShuffle(BaseTransform): ...@@ -1399,6 +1516,7 @@ class NodeShuffle(BaseTransform):
[ 9., 10.], [ 9., 10.],
[ 7., 8.]]) [ 7., 8.]])
""" """
def __call__(self, g): def __call__(self, g):
g = g.clone() g = g.clone()
for ntype in g.ntypes: for ntype in g.ntypes:
...@@ -1408,6 +1526,7 @@ class NodeShuffle(BaseTransform): ...@@ -1408,6 +1526,7 @@ class NodeShuffle(BaseTransform):
g.nodes[ntype].data[key] = feat[perm] g.nodes[ntype].data[key] = feat[perm]
return g return g
# pylint: disable=C0103 # pylint: disable=C0103
class DropNode(BaseTransform): class DropNode(BaseTransform):
r"""Randomly drop nodes, as described in r"""Randomly drop nodes, as described in
...@@ -1439,6 +1558,7 @@ class DropNode(BaseTransform): ...@@ -1439,6 +1558,7 @@ class DropNode(BaseTransform):
>>> print(new_g.edata['h']) >>> print(new_g.edata['h'])
tensor([0, 6, 14, 5, 17, 3, 11]) tensor([0, 6, 14, 5, 17, 3, 11])
""" """
def __init__(self, p=0.5): def __init__(self, p=0.5):
self.p = p self.p = p
self.dist = Bernoulli(p) self.dist = Bernoulli(p)
...@@ -1456,6 +1576,7 @@ class DropNode(BaseTransform): ...@@ -1456,6 +1576,7 @@ class DropNode(BaseTransform):
g.remove_nodes(nids_to_remove, ntype=ntype) g.remove_nodes(nids_to_remove, ntype=ntype)
return g return g
# pylint: disable=C0103 # pylint: disable=C0103
class DropEdge(BaseTransform): class DropEdge(BaseTransform):
r"""Randomly drop edges, as described in r"""Randomly drop edges, as described in
...@@ -1486,6 +1607,7 @@ class DropEdge(BaseTransform): ...@@ -1486,6 +1607,7 @@ class DropEdge(BaseTransform):
>>> print(new_g.edata['h']) >>> print(new_g.edata['h'])
tensor([0, 1, 3, 7, 8, 10, 11, 12, 13, 15, 18, 19]) tensor([0, 1, 3, 7, 8, 10, 11, 12, 13, 15, 18, 19])
""" """
def __init__(self, p=0.5): def __init__(self, p=0.5):
self.p = p self.p = p
self.dist = Bernoulli(p) self.dist = Bernoulli(p)
...@@ -1499,10 +1621,13 @@ class DropEdge(BaseTransform): ...@@ -1499,10 +1621,13 @@ class DropEdge(BaseTransform):
for c_etype in g.canonical_etypes: for c_etype in g.canonical_etypes:
samples = self.dist.sample(torch.Size([g.num_edges(c_etype)])) samples = self.dist.sample(torch.Size([g.num_edges(c_etype)]))
eids_to_remove = g.edges(form='eid', etype=c_etype)[samples.bool().to(g.device)] eids_to_remove = g.edges(form="eid", etype=c_etype)[
samples.bool().to(g.device)
]
g.remove_edges(eids_to_remove, etype=c_etype) g.remove_edges(eids_to_remove, etype=c_etype)
return g return g
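The Bernoulli masking pattern behind both DropNode and DropEdge, as a hedged standalone sketch (toy sizes, illustrative names):

import torch
from torch.distributions import Bernoulli

p = 0.5
num_edges = 10
mask = Bernoulli(p).sample(torch.Size([num_edges])).bool()   # 1 means drop
eids_to_remove = torch.arange(num_edges)[mask]
# g.remove_edges(eids_to_remove) on a cloned graph then drops them;
# DropNode follows the same pattern with node IDs and g.remove_nodes.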
class AddEdge(BaseTransform): class AddEdge(BaseTransform):
r"""Randomly add edges, as described in `Graph Contrastive Learning with Augmentations r"""Randomly add edges, as described in `Graph Contrastive Learning with Augmentations
<https://arxiv.org/abs/2010.13902>`__. <https://arxiv.org/abs/2010.13902>`__.
...@@ -1524,12 +1649,13 @@ class AddEdge(BaseTransform): ...@@ -1524,12 +1649,13 @@ class AddEdge(BaseTransform):
>>> print(new_g.num_edges()) >>> print(new_g.num_edges())
24 24
""" """
def __init__(self, ratio=0.2): def __init__(self, ratio=0.2):
self.ratio = ratio self.ratio = ratio
def __call__(self, g): def __call__(self, g):
# Fast path # Fast path
if self.ratio == 0.: if self.ratio == 0.0:
return g return g
device = g.device device = g.device
...@@ -1538,11 +1664,24 @@ class AddEdge(BaseTransform): ...@@ -1538,11 +1664,24 @@ class AddEdge(BaseTransform):
for c_etype in g.canonical_etypes: for c_etype in g.canonical_etypes:
utype, _, vtype = c_etype utype, _, vtype = c_etype
num_edges_to_add = int(g.num_edges(c_etype) * self.ratio) num_edges_to_add = int(g.num_edges(c_etype) * self.ratio)
src = F.randint([num_edges_to_add], idtype, device, low=0, high=g.num_nodes(utype)) src = F.randint(
dst = F.randint([num_edges_to_add], idtype, device, low=0, high=g.num_nodes(vtype)) [num_edges_to_add],
idtype,
device,
low=0,
high=g.num_nodes(utype),
)
dst = F.randint(
[num_edges_to_add],
idtype,
device,
low=0,
high=g.num_nodes(vtype),
)
g.add_edges(src, dst, etype=c_etype) g.add_edges(src, dst, etype=c_etype)
return g return g
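A hedged homogeneous-graph sketch of what AddEdge does: draw ratio * num_edges uniformly random endpoint pairs and append them (toy graph, illustrative only):

import dgl
import torch

ratio = 0.2
g = dgl.rand_graph(100, 500).clone()      # 100 nodes, 500 edges
n_add = int(g.num_edges() * ratio)
src = torch.randint(0, g.num_nodes(), (n_add,))
dst = torch.randint(0, g.num_nodes(), (n_add,))
g.add_edges(src, dst)                     # now 600 edges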
class SIGNDiffusion(BaseTransform): class SIGNDiffusion(BaseTransform):
r"""The diffusion operator from `SIGN: Scalable Inception Graph Neural Networks r"""The diffusion operator from `SIGN: Scalable Inception Graph Neural Networks
<https://arxiv.org/abs/2004.11198>`__ <https://arxiv.org/abs/2004.11198>`__
...@@ -1609,13 +1748,16 @@ class SIGNDiffusion(BaseTransform): ...@@ -1609,13 +1748,16 @@ class SIGNDiffusion(BaseTransform):
'out_feat_2': Scheme(shape=(10,), dtype=torch.float32)} 'out_feat_2': Scheme(shape=(10,), dtype=torch.float32)}
edata_schemes={'w': Scheme(shape=(), dtype=torch.float32)}) edata_schemes={'w': Scheme(shape=(), dtype=torch.float32)})
""" """
def __init__(self,
k, def __init__(
in_feat_name='feat', self,
out_feat_name='out_feat', k,
eweight_name=None, in_feat_name="feat",
diffuse_op='raw', out_feat_name="out_feat",
alpha=0.2): eweight_name=None,
diffuse_op="raw",
alpha=0.2,
):
self.k = k self.k = k
self.in_feat_name = in_feat_name self.in_feat_name = in_feat_name
self.out_feat_name = out_feat_name self.out_feat_name = out_feat_name
...@@ -1623,23 +1765,27 @@ class SIGNDiffusion(BaseTransform): ...@@ -1623,23 +1765,27 @@ class SIGNDiffusion(BaseTransform):
self.diffuse_op = diffuse_op self.diffuse_op = diffuse_op
self.alpha = alpha self.alpha = alpha
if diffuse_op == 'raw': if diffuse_op == "raw":
self.diffuse = self.raw self.diffuse = self.raw
elif diffuse_op == 'rw': elif diffuse_op == "rw":
self.diffuse = self.rw self.diffuse = self.rw
elif diffuse_op == 'gcn': elif diffuse_op == "gcn":
self.diffuse = self.gcn self.diffuse = self.gcn
elif diffuse_op == 'ppr': elif diffuse_op == "ppr":
self.diffuse = self.ppr self.diffuse = self.ppr
else: else:
raise DGLError("Expect diffuse_op to be from ['raw', 'rw', 'gcn', 'ppr'], \ raise DGLError(
got {}".format(diffuse_op)) "Expect diffuse_op to be from ['raw', 'rw', 'gcn', 'ppr'], \
got {}".format(
diffuse_op
)
)
def __call__(self, g): def __call__(self, g):
feat_list = self.diffuse(g) feat_list = self.diffuse(g)
for i in range(1, self.k + 1): for i in range(1, self.k + 1):
g.ndata[self.out_feat_name + '_' + str(i)] = feat_list[i - 1] g.ndata[self.out_feat_name + "_" + str(i)] = feat_list[i - 1]
return g return g
def raw(self, g): def raw(self, g):
...@@ -1650,11 +1796,13 @@ class SIGNDiffusion(BaseTransform): ...@@ -1650,11 +1796,13 @@ class SIGNDiffusion(BaseTransform):
feat_list = [] feat_list = []
with g.local_scope(): with g.local_scope():
if use_eweight: if use_eweight:
message_func = fn.u_mul_e(self.in_feat_name, self.eweight_name, 'm') message_func = fn.u_mul_e(
self.in_feat_name, self.eweight_name, "m"
)
else: else:
message_func = fn.copy_u(self.in_feat_name, 'm') message_func = fn.copy_u(self.in_feat_name, "m")
for _ in range(self.k): for _ in range(self.k):
g.update_all(message_func, fn.sum('m', self.in_feat_name)) g.update_all(message_func, fn.sum("m", self.in_feat_name))
feat_list.append(g.ndata[self.in_feat_name]) feat_list.append(g.ndata[self.in_feat_name])
return feat_list return feat_list
...@@ -1665,28 +1813,32 @@ class SIGNDiffusion(BaseTransform): ...@@ -1665,28 +1813,32 @@ class SIGNDiffusion(BaseTransform):
feat_list = [] feat_list = []
with g.local_scope(): with g.local_scope():
g.ndata['h'] = g.ndata[self.in_feat_name] g.ndata["h"] = g.ndata[self.in_feat_name]
if use_eweight: if use_eweight:
message_func = fn.u_mul_e('h', self.eweight_name, 'm') message_func = fn.u_mul_e("h", self.eweight_name, "m")
reduce_func = fn.sum('m', 'h') reduce_func = fn.sum("m", "h")
# Compute the diagonal entries of D from the weighted A # Compute the diagonal entries of D from the weighted A
g.update_all(fn.copy_e(self.eweight_name, 'm'), fn.sum('m', 'z')) g.update_all(
fn.copy_e(self.eweight_name, "m"), fn.sum("m", "z")
)
else: else:
message_func = fn.copy_u('h', 'm') message_func = fn.copy_u("h", "m")
reduce_func = fn.mean('m', 'h') reduce_func = fn.mean("m", "h")
for _ in range(self.k): for _ in range(self.k):
g.update_all(message_func, reduce_func) g.update_all(message_func, reduce_func)
if use_eweight: if use_eweight:
g.ndata['h'] = g.ndata['h'] / F.reshape(g.ndata['z'], (g.num_nodes(), 1)) g.ndata["h"] = g.ndata["h"] / F.reshape(
feat_list.append(g.ndata['h']) g.ndata["z"], (g.num_nodes(), 1)
)
feat_list.append(g.ndata["h"])
return feat_list return feat_list
def gcn(self, g): def gcn(self, g):
feat_list = [] feat_list = []
with g.local_scope(): with g.local_scope():
if self.eweight_name is None: if self.eweight_name is None:
eweight_name = 'w' eweight_name = "w"
if eweight_name in g.edata: if eweight_name in g.edata:
g.edata.pop(eweight_name) g.edata.pop(eweight_name)
else: else:
...@@ -1696,8 +1848,10 @@ class SIGNDiffusion(BaseTransform): ...@@ -1696,8 +1848,10 @@ class SIGNDiffusion(BaseTransform):
transform(g) transform(g)
for _ in range(self.k): for _ in range(self.k):
g.update_all(fn.u_mul_e(self.in_feat_name, eweight_name, 'm'), g.update_all(
fn.sum('m', self.in_feat_name)) fn.u_mul_e(self.in_feat_name, eweight_name, "m"),
fn.sum("m", self.in_feat_name),
)
feat_list.append(g.ndata[self.in_feat_name]) feat_list.append(g.ndata[self.in_feat_name])
return feat_list return feat_list
...@@ -1705,7 +1859,7 @@ class SIGNDiffusion(BaseTransform): ...@@ -1705,7 +1859,7 @@ class SIGNDiffusion(BaseTransform):
feat_list = [] feat_list = []
with g.local_scope(): with g.local_scope():
if self.eweight_name is None: if self.eweight_name is None:
eweight_name = 'w' eweight_name = "w"
if eweight_name in g.edata: if eweight_name in g.edata:
g.edata.pop(eweight_name) g.edata.pop(eweight_name)
else: else:
...@@ -1715,13 +1869,17 @@ class SIGNDiffusion(BaseTransform): ...@@ -1715,13 +1869,17 @@ class SIGNDiffusion(BaseTransform):
in_feat = g.ndata[self.in_feat_name] in_feat = g.ndata[self.in_feat_name]
for _ in range(self.k): for _ in range(self.k):
g.update_all(fn.u_mul_e(self.in_feat_name, eweight_name, 'm'), g.update_all(
fn.sum('m', self.in_feat_name)) fn.u_mul_e(self.in_feat_name, eweight_name, "m"),
g.ndata[self.in_feat_name] = (1 - self.alpha) * g.ndata[self.in_feat_name] +\ fn.sum("m", self.in_feat_name),
self.alpha * in_feat )
g.ndata[self.in_feat_name] = (1 - self.alpha) * g.ndata[
self.in_feat_name
] + self.alpha * in_feat
feat_list.append(g.ndata[self.in_feat_name]) feat_list.append(g.ndata[self.in_feat_name])
return feat_list return feat_list
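For the 'raw' operator, a hedged end-to-end sketch of what SIGNDiffusion precomputes: k rounds of neighbor aggregation, each kept as a separate feature (the toy graph and sizes are illustrative):

import dgl
import dgl.function as fn
import torch

g = dgl.rand_graph(50, 200)
g.ndata["feat"] = torch.randn(50, 16)
k = 3
feat_list = []
with g.local_scope():
    for _ in range(k):
        g.update_all(fn.copy_u("feat", "m"), fn.sum("m", "feat"))
        feat_list.append(g.ndata["feat"])
# feat_list[i - 1] is what ends up in g.ndata["out_feat_" + str(i)]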
class ToLevi(BaseTransform): class ToLevi(BaseTransform):
r"""This function transforms the original graph to its heterogeneous Levi graph, r"""This function transforms the original graph to its heterogeneous Levi graph,
by converting edges to intermediate nodes, only support homogeneous directed graph. by converting edges to intermediate nodes, only support homogeneous directed graph.
...@@ -1777,8 +1935,10 @@ class ToLevi(BaseTransform): ...@@ -1777,8 +1935,10 @@ class ToLevi(BaseTransform):
edge_list = g.edges() edge_list = g.edges()
n2e = edge_list[0], F.arange(0, g.num_edges(), idtype, device) n2e = edge_list[0], F.arange(0, g.num_edges(), idtype, device)
e2n = F.arange(0, g.num_edges(), idtype, device), edge_list[1] e2n = F.arange(0, g.num_edges(), idtype, device), edge_list[1]
graph_data = {('node', 'n2e', 'edge'): n2e, graph_data = {
('edge', 'e2n', 'node'): e2n} ("node", "n2e", "edge"): n2e,
("edge", "e2n", "node"): e2n,
}
levi_g = convert.heterograph(graph_data, idtype=idtype, device=device) levi_g = convert.heterograph(graph_data, idtype=idtype, device=device)
# Copy ndata and edata # Copy ndata and edata
...@@ -1786,7 +1946,7 @@ class ToLevi(BaseTransform): ...@@ -1786,7 +1946,7 @@ class ToLevi(BaseTransform):
# ('edge' < 'node'), edge_frames should be in front of node_frames. # ('edge' < 'node'), edge_frames should be in front of node_frames.
node_frames = utils.extract_node_subframes(g, nodes_or_device=device) node_frames = utils.extract_node_subframes(g, nodes_or_device=device)
edge_frames = utils.extract_edge_subframes(g, edges_or_device=device) edge_frames = utils.extract_edge_subframes(g, edges_or_device=device)
utils.set_new_frames(levi_g, node_frames=edge_frames+node_frames) utils.set_new_frames(levi_g, node_frames=edge_frames + node_frames)
return levi_g return levi_g
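A hedged sketch of the Levi construction itself, outside the transform: each original edge i becomes an 'edge' node sitting between its endpoints (toy triangle graph):

import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
src, dst = g.edges()
eids = torch.arange(g.num_edges())
levi = dgl.heterograph({
    ("node", "n2e", "edge"): (src, eids),
    ("edge", "e2n", "node"): (eids, dst),
})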
...@@ -1833,6 +1993,7 @@ class SVDPE(BaseTransform): ...@@ -1833,6 +1993,7 @@ class SVDPE(BaseTransform):
[-6.3246e-01, -7.6512e-01, -6.3246e-01, 7.6512e-01], [-6.3246e-01, -7.6512e-01, -6.3246e-01, 7.6512e-01],
[ 6.3246e-01, -4.7287e-01, 6.3246e-01, 4.7287e-01]]) [ 6.3246e-01, -4.7287e-01, 6.3246e-01, 4.7287e-01]])
""" """
def __init__(self, k, feat_name="svd_pe", padding=False, random_flip=True): def __init__(self, k, feat_name="svd_pe", padding=False, random_flip=True):
self.k = k self.k = k
self.feat_name = feat_name self.feat_name = feat_name
......
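A heavily hedged sketch of the general SVD positional-encoding idea (my reading of SVD-based PE in general, not necessarily the exact recipe of SVDPE, whose body is truncated above): take the top-k singular triplets of the dense adjacency and concatenate U_k * sqrt(S_k) with V_k * sqrt(S_k) per node.

import torch

A = torch.tensor([[0., 1., 0., 1.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [1., 0., 1., 0.]])
k = 2
U, S, Vh = torch.linalg.svd(A)
pe = torch.cat([U[:, :k] * S[:k].sqrt(), Vh.T[:, :k] * S[:k].sqrt()], dim=1)   # (num_nodes, 2k)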