"vscode:/vscode.git/clone" did not exist on "599258f97992e5a47db9408e4e3622805ce1adb5"
Unverified Commit 22167f72 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Refactor] Enable new kernel in all message passing APIs (#1953)

* WIP: frame refactor

* new frame

* simple update_all builtin

* move all subgraph routines into the same file

* sddmm & spmm schedule; node & edge udf

* degree bucketing

* some tricky 0deg corner cases

* bug in frame append

* merge test_hetero_basics and test_basics

* some code rearrange

* fix test_heterograph

* add mean spmm

* enable all builtin combinations

* pass gpu test

* pass pytorch tests

* wip

* fix some pt debugging codes

* fix bug in mxnet backward

* pass all mxnet utests

* passed tf tests

* docstring

* lint

* lint

* fix broadcasting bugs

* add warning and clamp for mean reducer

* add test for zero-degree mean

* address comments

* lint

* small fix
parent 5d5436ba
...@@ -26,6 +26,7 @@ from .convert import * ...@@ -26,6 +26,7 @@ from .convert import *
from .generators import * from .generators import *
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
from .heterograph import DGLHeteroGraph as DGLGraph # pylint: disable=reimported from .heterograph import DGLHeteroGraph as DGLGraph # pylint: disable=reimported
from .subgraph import *
from .traversal import * from .traversal import *
from .transform import * from .transform import *
from .propagate import * from .propagate import *
......
"""Columnar storage for DGLGraph."""
from __future__ import absolute_import
from collections import namedtuple
from collections.abc import MutableMapping
import numpy as np
from .. import backend as F
from ..base import DGLError, dgl_warning
from ..init import zero_initializer
from .. import utils
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
    """The column scheme.

    A scheme is the per-row feature signature of a column: the feature
    shape (excluding the batch dimension) plus the data type.

    Parameters
    ----------
    shape : tuple of int
        The feature shape.
    dtype : backend-specific type object
        The feature data type.
    """
    # Pickling torch dtypes could be problematic; this is a workaround.
    # I also have to create data_type_dict and reverse_data_type_dict
    # attribute just for this bug.
    # I raised an issue in PyTorch bug tracker:
    # https://github.com/pytorch/pytorch/issues/14057
    def __reduce__(self):
        # Serialize the dtype as a backend-neutral string so the pickle
        # payload never embeds a backend dtype object.
        state = (self.shape, F.reverse_data_type_dict[self.dtype])
        return self._reconstruct_scheme, state

    @classmethod
    def _reconstruct_scheme(cls, shape, dtype_str):
        # Inverse of __reduce__: map the dtype string back to the
        # backend-specific dtype object.
        dtype = F.data_type_dict[dtype_str]
        return cls(shape, dtype)
def infer_scheme(tensor):
    """Infer column scheme from the given tensor data.

    The first dimension is treated as the row (batch) dimension and is
    therefore excluded from the feature shape.

    Parameters
    ----------
    tensor : Tensor
        The tensor data.

    Returns
    -------
    Scheme
        The column scheme.
    """
    return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor))
class Column(object):
    """A column is a compact store of features of multiple nodes/edges.

    Currently, we use one dense tensor to batch all the feature tensors
    together (along the first dimension).

    Parameters
    ----------
    data : Tensor
        The initial data of the column.
    scheme : Scheme, optional
        The scheme of the column. Will be inferred if not provided.

    Attributes
    ----------
    data : Tensor
        The data of the column.
    scheme : Scheme
        The scheme of the column.
    """
    def __init__(self, data, scheme=None):
        self.data = data
        self.scheme = scheme if scheme else infer_scheme(data)

    def __len__(self):
        """Return the column length (number of rows)."""
        return F.shape(self.data)[0]

    @property
    def shape(self):
        """Return the scheme shape (feature shape) of this column."""
        return self.scheme.shape

    def __getitem__(self, idx):
        """Return the feature data given the index.

        Parameters
        ----------
        idx : utils.Index
            The index.

        Returns
        -------
        Tensor
            The feature data
        """
        if idx.slice_data() is not None:
            # Contiguous rows: a narrow is cheaper than a full gather.
            slc = idx.slice_data()
            return F.narrow_row(self.data, slc.start, slc.stop)
        else:
            user_idx = idx.tousertensor(F.context(self.data))
            return F.gather_row(self.data, user_idx)

    def __setitem__(self, idx, feats):
        """Update the feature data given the index.

        The update is performed out-placely so it can be used in autograd mode.
        For inplace write, please use ``update``.

        Parameters
        ----------
        idx : utils.Index or slice
            The index.
        feats : Tensor
            The new features.
        """
        self.update(idx, feats, inplace=False)

    def update(self, idx, feats, inplace):
        """Update the feature data given the index.

        Parameters
        ----------
        idx : utils.Index
            The index.
        feats : Tensor
            The new features.
        inplace : bool
            If true, use inplace write.

        Raises
        ------
        DGLError
            If the scheme of ``feats`` does not match this column's scheme.
        """
        feat_scheme = infer_scheme(feats)
        if feat_scheme != self.scheme:
            # BUGFIX: the interpolated arguments were swapped — the message
            # reads "column of scheme X using feature of scheme Y", so the
            # column's own scheme must come first.
            raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
                           % (self.scheme, feat_scheme))
        if inplace:
            idx = idx.tousertensor(F.context(self.data))
            F.scatter_row_inplace(self.data, idx, feats)
        elif idx.slice_data() is not None:
            # for contiguous indices narrow+concat is usually faster than scatter row
            slc = idx.slice_data()
            parts = [feats]
            if slc.start > 0:
                parts.insert(0, F.narrow_row(self.data, 0, slc.start))
            if slc.stop < len(self):
                parts.append(F.narrow_row(self.data, slc.stop, len(self)))
            self.data = F.cat(parts, dim=0)
        else:
            idx = idx.tousertensor(F.context(self.data))
            self.data = F.scatter_row(self.data, idx, feats)

    def extend(self, feats, feat_scheme=None):
        """Extend the feature data with new rows.

        Parameters
        ----------
        feats : Tensor
            The new features.
        feat_scheme : Scheme, optional
            The scheme of ``feats``; inferred from the data if not given.

        Raises
        ------
        DGLError
            If the scheme of ``feats`` does not match this column's scheme.
        """
        if feat_scheme is None:
            feat_scheme = infer_scheme(feats)
        if feat_scheme != self.scheme:
            # BUGFIX: same argument-order correction as in ``update``.
            raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
                           % (self.scheme, feat_scheme))
        # Move the new rows to the same device as the existing data before
        # concatenating.
        feats = F.copy_to(feats, F.context(self.data))
        self.data = F.cat([self.data, feats], dim=0)

    def clone(self):
        """Return a deepcopy of this column."""
        return Column(F.clone(self.data), self.scheme)

    @staticmethod
    def create(data):
        """Create a new column using the given data.

        If ``data`` is already a Column, the new column shares the same
        underlying tensor and scheme (out-of-place updates on one will not
        be seen by the other).
        """
        if isinstance(data, Column):
            return Column(data.data, data.scheme)
        else:
            return Column(data)

    def __repr__(self):
        return repr(self.data)
class Frame(MutableMapping):
    """The columnar storage for node/edge features.

    The frame is a dictionary from feature fields to feature columns.
    All columns should have the same number of rows (i.e. the same first dimension).

    Parameters
    ----------
    data : dict-like, optional
        The frame data in dictionary. If the provided data is another frame,
        this frame will NOT share columns with the given frame. So any out-place
        update on one will not reflect to the other. The inplace update will
        be seen by both. This follows the semantic of python's container.
    num_rows : int, optional [default=0]
        The number of rows in this frame. If ``data`` is provided and is not empty,
        ``num_rows`` will be ignored and inferred from the given data.
    """
    def __init__(self, data=None, num_rows=0):
        if data is None:
            self._columns = dict()
            self._num_rows = num_rows
        else:
            # Note that we always create a new column for the given data.
            # This avoids two frames accidentally sharing the same column.
            self._columns = {k : Column.create(v) for k, v in data.items()}
            if isinstance(data, (Frame, FrameRef)):
                self._num_rows = data.num_rows
            elif len(self._columns) != 0:
                # Infer the row count from an arbitrary column.
                self._num_rows = len(next(iter(self._columns.values())))
            else:
                self._num_rows = num_rows
            # sanity check
            for name, col in self._columns.items():
                if len(col) != self._num_rows:
                    raise DGLError('Expected all columns to have same # rows (%d), '
                                   'got %d on %r.' % (self._num_rows, len(col), name))
        # Initializer for empty values. An initializer is a callable.
        # If it is None, a warning will be raised on the first use and the
        # zero initializer will be used from then on.
        self._initializers = {}  # per-column initializers
        self._remote_init_builder = None
        self._default_initializer = None

    def _set_zero_default_initializer(self):
        """Set the default initializer to be zero initializer."""
        self._default_initializer = zero_initializer

    def get_initializer(self, column=None):
        """Get the initializer for empty values for the given column.

        Falls back to the default initializer when no per-column
        initializer has been set.

        Parameters
        ----------
        column : str
            The column

        Returns
        -------
        callable
            The initializer
        """
        return self._initializers.get(column, self._default_initializer)

    def set_initializer(self, initializer, column=None):
        """Set the initializer for empty values, for a given column or all future
        columns.

        Initializer is a callable that returns a tensor given the shape and data type.

        Parameters
        ----------
        initializer : callable
            The initializer.
        column : str, optional
            The column name
        """
        if column is None:
            self._default_initializer = initializer
        else:
            self._initializers[column] = initializer

    def set_remote_init_builder(self, builder):
        """Set an initializer builder to create a remote initializer for a new column to a frame.

        NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.

        The builder is a callable that returns an initializer. The returned initializer
        is also a callable that returns a tensor given a local tensor and tensor name.

        Parameters
        ----------
        builder : callable
            The builder to construct a remote initializer.
        """
        self._remote_init_builder = builder

    def get_remote_initializer(self, name):
        """Get a remote initializer, or None if no builder has been set.

        NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.

        Parameters
        ----------
        name : string
            The column name.
        """
        if self._remote_init_builder is None:
            return None
        if self.get_initializer(name) is None:
            self._set_zero_default_initializer()
        initializer = self.get_initializer(name)
        return self._remote_init_builder(initializer, name)

    @property
    def schemes(self):
        """Return a dictionary of column name to column schemes."""
        return {k : col.scheme for k, col in self._columns.items()}

    @property
    def num_columns(self):
        """Return the number of columns in this frame."""
        return len(self._columns)

    @property
    def num_rows(self):
        """Return the number of rows in this frame."""
        return self._num_rows

    def __contains__(self, name):
        """Return true if the given column name exists."""
        return name in self._columns

    def __getitem__(self, name):
        """Return the column of the given name.

        Parameters
        ----------
        name : str
            The column name.

        Returns
        -------
        Column
            The column.
        """
        return self._columns[name]

    def __setitem__(self, name, data):
        """Update the whole column.

        Parameters
        ----------
        name : str
            The column name.
        data : Column or data convertible to Column
            The column data.
        """
        self.update_column(name, data)

    def __delitem__(self, name):
        """Delete the whole column.

        Parameters
        ----------
        name : str
            The column name.
        """
        del self._columns[name]

    def add_column(self, name, scheme, ctx):
        """Add a new column to the frame.

        The frame will be initialized by the initializer.

        Parameters
        ----------
        name : str
            The column name.
        scheme : Scheme
            The column scheme.
        ctx : DGLContext
            The column context.
        """
        if name in self:
            dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name)
            return
        # If the data is backed by a remote server, we need to move data
        # to the remote server.
        initializer = self.get_remote_initializer(name)
        if initializer is not None:
            init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx)
        else:
            if self.get_initializer(name) is None:
                self._set_zero_default_initializer()
            initializer = self.get_initializer(name)
            # Local initializers additionally receive the row range being filled.
            init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
                                    ctx, slice(0, self.num_rows))
        self._columns[name] = Column(init_data, scheme)

    def add_rows(self, num_rows):
        """Add blank rows to this frame.

        For existing fields, the rows will be extended according to their
        initializers.

        Parameters
        ----------
        num_rows : int
            The number of new rows
        """
        feat_placeholders = {}
        for key, col in self._columns.items():
            scheme = col.scheme
            ctx = F.context(col.data)
            if self.get_initializer(key) is None:
                self._set_zero_default_initializer()
            initializer = self.get_initializer(key)
            # The slice tells the initializer which rows it is filling.
            new_data = initializer((num_rows,) + scheme.shape, scheme.dtype,
                                   ctx, slice(self._num_rows, self._num_rows + num_rows))
            feat_placeholders[key] = new_data
        self._append(Frame(feat_placeholders))
        self._num_rows += num_rows

    def update_column(self, name, data):
        """Add or replace the column with the given name and data.

        Parameters
        ----------
        name : str
            The column name.
        data : Column or data convertible to Column
            The column data.
        """
        # If the data is backed by a remote server, we need to move data
        # to the remote server.
        initializer = self.get_remote_initializer(name)
        if initializer is not None:
            new_data = initializer(F.shape(data), F.dtype(data), F.context(data))
            new_data[:] = data
            data = new_data
        col = Column.create(data)
        if len(col) != self.num_rows:
            raise DGLError('Expected data to have %d rows, got %d.' %
                           (self.num_rows, len(col)))
        self._columns[name] = col

    def _append(self, other):
        # Internal helper: append the columns of ``other`` to this frame.
        # NOTE: does NOT update self._num_rows; callers (``append``,
        # ``add_rows``) are responsible for the row-count bookkeeping.
        assert self._remote_init_builder is None, \
            "We don't support append if data in the frame is mapped from a remote server."
        # NOTE: `other` can be empty.
        if self.num_rows == 0:
            # if no rows in current frame; append is equivalent to
            # directly updating columns.
            self._columns = {key: Column.create(data) for key, data in other.items()}
        else:
            # pad columns that are not provided in the other frame with initial values
            # (note: this writes the padding back into ``other``)
            for key, col in self.items():
                if key in other:
                    continue
                scheme = col.scheme
                ctx = F.context(col.data)
                if self.get_initializer(key) is None:
                    self._set_zero_default_initializer()
                initializer = self.get_initializer(key)
                new_data = initializer((other.num_rows,) + scheme.shape,
                                       scheme.dtype, ctx,
                                       slice(self._num_rows, self._num_rows + other.num_rows))
                other[key] = new_data
            # append other to self
            for key, col in other.items():
                if key not in self._columns:
                    # the column does not exist; init a new column
                    self.add_column(key, col.scheme, F.context(col.data))
                self._columns[key].extend(col.data, col.scheme)

    def append(self, other):
        """Append another frame's data into this frame.

        If the current frame is empty, it will just use the columns of the
        given frame. Otherwise, the given data should contain all the
        column keys of this frame.

        Parameters
        ----------
        other : Frame or dict-like
            The frame data to be appended.
        """
        if not isinstance(other, Frame):
            other = Frame(other)
        self._append(other)
        self._num_rows += other.num_rows

    def clear(self):
        """Clear this frame. Remove all the columns."""
        self._columns = {}
        self._num_rows = 0

    def __iter__(self):
        """Return an iterator of columns."""
        return iter(self._columns)

    def __len__(self):
        """Return the number of columns."""
        return self.num_columns

    def keys(self):
        """Return the keys."""
        return self._columns.keys()

    def values(self):
        """Return the values."""
        return self._columns.values()

    def clone(self):
        """Return a clone of this frame.

        The clone frame does not share the underlying storage with this frame,
        i.e., adding or removing columns will not be visible to each other. However,
        they still share the tensor contents so any mutable operation on the column
        tensor are visible to each other. Hence, the function does not allocate extra
        tensor memory. Use :func:`~dgl.Frame.deepclone` for cloning
        a frame that does not share any data.

        Returns
        -------
        Frame
            A cloned frame.
        """
        newframe = Frame(self._columns, self._num_rows)
        # Initializer state is shared (not copied) with the clone.
        newframe._initializers = self._initializers
        newframe._remote_init_builder = self._remote_init_builder
        newframe._default_initializer = self._default_initializer
        return newframe

    def deepclone(self):
        """Return a deep clone of this frame.

        The clone frame has an copy of this frame and any modification to the clone frame
        is not visible to this frame. The function allocate new tensors and copy the contents
        from this frame. Use :func:`~dgl.Frame.clone` for cloning a frame that does not
        allocate extra tensor memory.

        Returns
        -------
        Frame
            A deep-cloned frame.
        """
        newframe = Frame({k : col.clone() for k, col in self._columns.items()}, self._num_rows)
        # Initializer state is still shared even in a deep clone.
        newframe._initializers = self._initializers
        newframe._remote_init_builder = self._remote_init_builder
        newframe._default_initializer = self._default_initializer
        return newframe
class FrameRef(MutableMapping):
    """Reference object to a frame on a subset of rows.

    Parameters
    ----------
    frame : Frame, optional
        The underlying frame. If not given, the reference will point to a
        new empty frame.
    index : utils.Index, optional
        The rows that are referenced in the underlying frame. If not given,
        the whole frame is referenced. The index should be distinct (no
        duplication is allowed).
    """
    def __init__(self, frame=None, index=None):
        self._frame = frame if frame is not None else Frame()
        # TODO(minjie): check no duplication
        assert index is None or isinstance(index, utils.Index)
        if index is None:
            # Spanning the whole frame is represented by a slice index so
            # the fast contiguous paths below apply.
            self._index = utils.toindex(slice(0, self._frame.num_rows))
        else:
            self._index = index

    @property
    def schemes(self):
        """Return the frame schemes.

        Returns
        -------
        dict of str to Scheme
            The frame schemes.
        """
        return self._frame.schemes

    @property
    def num_columns(self):
        """Return the number of columns in the referred frame."""
        return self._frame.num_columns

    @property
    def num_rows(self):
        """Return the number of rows referred."""
        return len(self._index)

    def set_initializer(self, initializer, column=None):
        """Set the initializer for empty values.

        Initializer is a callable that returns a tensor given the shape and data type.

        Parameters
        ----------
        initializer : callable
            The initializer.
        column : str, optional
            The column name
        """
        self._frame.set_initializer(initializer, column=column)

    def set_remote_init_builder(self, builder):
        """Set an initializer builder to create a remote initializer for a new column to a frame.

        NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.

        The builder is a callable that returns an initializer. The returned initializer
        is also a callable that returns a tensor given a local tensor and tensor name.

        Parameters
        ----------
        builder : callable
            The builder to construct a remote initializer.
        """
        self._frame.set_remote_init_builder(builder)

    def get_initializer(self, column=None):
        """Get the initializer for empty values for the given column.

        Parameters
        ----------
        column : str
            The column

        Returns
        -------
        callable
            The initializer
        """
        return self._frame.get_initializer(column)

    def __contains__(self, name):
        """Return whether the column name exists."""
        return name in self._frame

    def __iter__(self):
        """Return the iterator of the columns."""
        return iter(self._frame)

    def __len__(self):
        """Return the number of columns."""
        return self.num_columns

    def keys(self):
        """Return the keys."""
        return self._frame.keys()

    def values(self):
        """Return the values."""
        return self._frame.values()

    def __getitem__(self, key):
        """Get data from the frame.

        If the provided key is string, the corresponding column data will be returned.
        If the provided key is an index or a slice, the corresponding rows will be selected.
        The returned rows are saved in a lazy dictionary so only the real selection happens
        when the explicit column name is provided.

        Examples (using pytorch)
        ------------------------
        >>> # create a frame of two columns and five rows
        >>> f = Frame({'c1' : torch.zeros([5, 2]), 'c2' : torch.ones([5, 2])})
        >>> fr = FrameRef(f)
        >>> # select the row 1 and 2, the returned `rows` is a lazy dictionary.
        >>> rows = fr[Index([1, 2])]
        >>> rows['c1']  # only select rows for 'c1' column; 'c2' column is not sliced.

        Parameters
        ----------
        key : str or utils.Index
            The key.

        Returns
        -------
        Tensor or lazy dict or tensors
            Depends on whether it is a column selection or row selection.
        """
        if not isinstance(key, (str, utils.Index)):
            raise DGLError('Argument "key" must be either str or utils.Index type.')
        if isinstance(key, str):
            return self.select_column(key)
        elif key.is_slice(0, self.num_rows):
            # shortcut for selecting all the rows
            return self
        else:
            return self.select_rows(key)

    def select_column(self, name):
        """Return the column of the given name.

        If only part of the rows are referenced, the fetching the whole column will
        also slice out the referenced rows.

        Parameters
        ----------
        name : str
            The column name.

        Returns
        -------
        Tensor
            The column data.
        """
        col = self._frame[name]
        if self.is_span_whole_column():
            # Whole column referenced; return the raw tensor without slicing.
            return col.data
        else:
            return col[self._index]

    def select_rows(self, query):
        """Return the rows given the query.

        Parameters
        ----------
        query : utils.Index or slice
            The rows to be selected.

        Returns
        -------
        utils.LazyDict
            The lazy dictionary from str to the selected data.
        """
        rows = self._getrows(query)
        # Columns are only materialized when looked up by name.
        return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys())

    def __setitem__(self, key, val):
        """Update the data in the frame. The update is done out-of-place.

        Parameters
        ----------
        key : str or utils.Index
            The key.
        val : Tensor or dict of tensors
            The value.

        See Also
        --------
        update
        """
        self.update_data(key, val, inplace=False)

    def update_data(self, key, val, inplace):
        """Update the data in the frame.

        If the provided key is string, the corresponding column data will be updated.
        The provided value should be one tensor that have the same scheme and length
        as the column.

        If the provided key is an index, the corresponding rows will be updated. The
        value provided should be a dictionary of string to the data of each column.

        All updates are performed out-placely to be work with autograd. For inplace
        update, use ``update_column`` or ``update_rows``.

        Parameters
        ----------
        key : str or utils.Index
            The key.
        val : Tensor or dict of tensors
            The value.
        inplace: bool
            If True, update will be done in place
        """
        if not isinstance(key, (str, utils.Index)):
            raise DGLError('Argument "key" must be either str or utils.Index type.')
        if isinstance(key, str):
            self.update_column(key, val, inplace=inplace)
        elif key.is_slice(0, self.num_rows):
            # shortcut for updating all the rows
            for colname, col in val.items():
                self.update_column(colname, col, inplace=inplace)
        else:
            self.update_rows(key, val, inplace=inplace)

    def update_column(self, name, data, inplace):
        """Update the column.

        If this frameref spans the whole column of the underlying frame, this is
        equivalent to update the column of the frame.

        If this frameref only points to part of the rows, then update the column
        here will correspond to update part of the column in the frame. Raise error
        if the given column name does not exist.

        Parameters
        ----------
        name : str
            The column name.
        data : Tensor
            The update data.
        inplace : bool
            True if the update is performed inplacely.
        """
        if self.is_span_whole_column():
            if self.num_columns == 0:
                # the frame is empty; the row span is defined by the first
                # column written into it
                self._index = utils.toindex(slice(0, len(data)))
            self._frame[name] = data
        else:
            if name not in self._frame:
                ctx = F.context(data)
                self._frame.add_column(name, infer_scheme(data), ctx)
            fcol = self._frame[name]
            fcol.update(self._index, data, inplace)

    def add_rows(self, num_rows):
        """Add blank rows to the underlying frame.

        For existing fields, the rows will be extended according to their
        initializers.

        Note: only available for FrameRef that spans the whole column. The row
        span will extend to new rows. Other FrameRefs referencing the same
        frame will not be affected.

        Parameters
        ----------
        num_rows : int
            Number of rows to add
        """
        if not self.is_span_whole_column():
            raise RuntimeError('FrameRef not spanning whole column.')
        self._frame.add_rows(num_rows)
        if self._index.slice_data() is not None:
            # the index is a slice; simply extend its stop bound
            slc = self._index.slice_data()
            self._index = utils.toindex(slice(slc.start, slc.stop + num_rows))
        else:
            selfidxdata = self._index.tousertensor()
            newdata = F.arange(self.num_rows, self.num_rows + num_rows)
            self._index = utils.toindex(F.cat([selfidxdata, newdata], dim=0))

    def update_rows(self, query, data, inplace):
        """Update the rows.

        If the provided data has new column, it will be added to the frame.

        See Also
        --------
        ``update_column``

        Parameters
        ----------
        query : utils.Index or slice
            The rows to be updated.
        data : dict-like
            The row data.
        inplace : bool
            True if the update is performed inplace.
        """
        rows = self._getrows(query)
        for key, col in data.items():
            if key not in self:
                # add new column via a temporary ref restricted to the
                # queried rows
                tmpref = FrameRef(self._frame, rows)
                tmpref.update_column(key, col, inplace)
            else:
                self._frame[key].update(rows, col, inplace)

    def __delitem__(self, key):
        """Delete data in the frame.

        If the provided key is a string, the corresponding column will be deleted.
        If the provided key is an index object or a slice, the corresponding rows will
        be deleted.

        Please note that "deleted" rows are not really deleted, but simply removed
        in the reference. As a result, if two FrameRefs point to the same Frame, deleting
        from one ref will not reflect on the other. However, deleting columns is real.

        Parameters
        ----------
        key : str or utils.Index
            The key.
        """
        if not isinstance(key, (str, utils.Index)):
            raise DGLError('Argument "key" must be either str or utils.Index type.')
        if isinstance(key, str):
            del self._frame[key]
        else:
            self.delete_rows(key)

    def delete_rows(self, query):
        """Delete rows.

        Please note that "deleted" rows are not really deleted, but simply removed
        in the reference. As a result, if two FrameRefs point to the same Frame, deleting
        from one ref will not reflect on the other. By contrast, deleting columns is real.

        Parameters
        ----------
        query : utils.Index
            The rows to be deleted.
        """
        # query holds local row positions; drop them from the index only
        query = query.tonumpy()
        index = self._index.tonumpy()
        self._index = utils.toindex(np.delete(index, query))

    def append(self, other):
        """Append another frame into this one.

        Parameters
        ----------
        other : dict of str to tensor
            The data to be appended.
        """
        old_nrows = self._frame.num_rows
        self._frame.append(other)
        new_nrows = self._frame.num_rows
        # update index
        if (self._index.slice_data() is not None
                and self._index.slice_data().stop == old_nrows):
            # Self index is a slice and index.stop is equal to the size of the
            # underlying frame. Can still use a slice for the new index.
            oldstart = self._index.slice_data().start
            self._index = utils.toindex(slice(oldstart, new_nrows))
        else:
            # convert it to user tensor and concat
            selfidxdata = self._index.tousertensor()
            newdata = F.arange(old_nrows, new_nrows)
            self._index = utils.toindex(F.cat([selfidxdata, newdata], dim=0))

    def clear(self):
        """Clear the frame."""
        self._frame.clear()
        self._index = utils.toindex(slice(0, 0))

    def is_contiguous(self):
        """Return whether this refers to a contiguous range of rows."""
        # NOTE: this check could have false negatives
        return self._index.slice_data() is not None

    def is_span_whole_column(self):
        """Return whether this refers to all the rows."""
        return self.is_contiguous() and self.num_rows == self._frame.num_rows

    def clone(self):
        """Return a new reference to a clone of the underlying frame.

        Returns
        -------
        FrameRef
            A cloned frame reference.

        See Also
        --------
        dgl.Frame.clone
        """
        return FrameRef(self._frame.clone(), self._index)

    def deepclone(self):
        """Return a new reference to a deep clone of the underlying frame.

        Returns
        -------
        FrameRef
            A deep-cloned frame reference.

        See Also
        --------
        dgl.Frame.deepclone
        """
        return FrameRef(self._frame.deepclone(), self._index)

    def _getrows(self, query):
        """Internal function to convert from the local row ids to the row ids of the frame.

        Parameters
        ----------
        query : utils.Index
            The query index.

        Returns
        -------
        utils.Index
            The actual index to the underlying frame.
        """
        return self._index.get_items(query)
def frame_like(other, num_rows=None):
    """Create an empty frame that has the same initializer as the given one.

    Parameters
    ----------
    other : Frame
        The given frame.
    num_rows : int
        The number of rows of the new one. If None, use other.num_rows
        (Default: None)

    Returns
    -------
    Frame
        The new frame.
    """
    if num_rows is None:
        num_rows = other.num_rows
    new_frame = Frame(num_rows=num_rows)
    # Ensure the reference frame carries a usable global initializer
    # (falling back to zeros) before copying the initializer state over.
    if other.get_initializer() is None:
        other._set_zero_default_initializer()
    sync_frame_initializer(new_frame, other)
    return new_frame
def sync_frame_initializer(new_frame, reference_frame):
    """Set the initializers of the new_frame to be the same as the reference_frame,
    for both the default initializer and per-column initializers.

    Note that the initializer state is shared, not copied.

    Parameters
    ----------
    new_frame : Frame
        The frame to set initializers
    reference_frame : Frame
        The frame to copy initializers
    """
    # Share the per-column initializer table first, then the global default.
    # TODO(minjie): hack; cannot rely on keys as the _initializers
    # now supports non-exist columns.
    new_frame._initializers = reference_frame._initializers
    new_frame._default_initializer = reference_frame._default_initializer
...@@ -12,7 +12,7 @@ import dgl ...@@ -12,7 +12,7 @@ import dgl
from ..base import ALL, NID, EID, is_all, DGLError, dgl_warning from ..base import ALL, NID, EID, is_all, DGLError, dgl_warning
from .. import backend as F from .. import backend as F
from .. import init from .. import init
from ..frame import FrameRef, Frame, Scheme, sync_frame_initializer from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
from .. import graph_index from .. import graph_index
from .runtime import ir, scheduler, Runtime, GraphAdapter from .runtime import ir, scheduler, Runtime, GraphAdapter
from .. import utils from .. import utils
......
...@@ -5,7 +5,7 @@ from .._ffi.object import register_object, ObjectBase ...@@ -5,7 +5,7 @@ from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..base import ALL, is_all, DGLError, dgl_warning from ..base import ALL, is_all, DGLError, dgl_warning
from .. import backend as F from .. import backend as F
from ..frame import Frame, FrameRef from .frame import Frame, FrameRef
from .graph import DGLBaseGraph from .graph import DGLBaseGraph
from ..graph_index import transform_ids from ..graph_index import transform_ids
from .runtime import ir, scheduler, Runtime from .runtime import ir, scheduler, Runtime
......
...@@ -5,7 +5,7 @@ from __future__ import absolute_import ...@@ -5,7 +5,7 @@ from __future__ import absolute_import
from abc import abstractmethod from abc import abstractmethod
from .... import backend as F from .... import backend as F
from ....frame import FrameRef, Frame from ...frame import FrameRef, Frame
from .... import utils from .... import utils
from .program import get_current_prog from .program import get_current_prog
......
...@@ -5,7 +5,7 @@ from ... import utils ...@@ -5,7 +5,7 @@ from ... import utils
from ..._ffi.function import _init_api from ..._ffi.function import _init_api
from ...base import DGLError from ...base import DGLError
from ... import backend as F from ... import backend as F
from ...frame import frame_like, FrameRef from ..frame import frame_like, FrameRef
from ...function.base import BuiltinFunction from ...function.base import BuiltinFunction
from ..udf import EdgeBatch, NodeBatch from ..udf import EdgeBatch, NodeBatch
from ... import ndarray as nd from ... import ndarray as nd
......
...@@ -971,58 +971,6 @@ def pack_padded_tensor(input, lengths): ...@@ -971,58 +971,6 @@ def pack_padded_tensor(input, lengths):
""" """
pass pass
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
"""Computes the sum along segments of a tensor.
Equivalent to tf.unsorted_segment_sum, but seg_id is required to be a
1D tensor.
Parameters
----------
input : Tensor
The input tensor
seg_id : 1D Tensor
The segment IDs whose values are between 0 and n_segs - 1. Should
have the same length as input.
n_segs : int
Number of distinct segments
dim : int
Dimension to sum on
Returns
-------
Tensor
The result
"""
pass
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
"""Computes the mean along segments of a tensor.
Equivalent to tf.unsorted_segment_mean, but seg_id is required to be a
1D tensor.
Note that segments never appeared in seg_id will have results of 0.
Parameters
----------
input : Tensor
The input tensor
seg_id : 1D Tensor
The segment IDs whose values are between 0 and n_segs - 1. Should
have the same length as input.
n_segs : int
Number of distinct segments
dim : int
Dimension to average on
Returns
-------
Tensor
The result
"""
pass
def boolean_mask(input, mask): def boolean_mask(input, mask):
"""Selects elements in x according to the given mask from the first """Selects elements in x according to the given mask from the first
dimension. dimension.
...@@ -1089,6 +1037,26 @@ def clone(input): ...@@ -1089,6 +1037,26 @@ def clone(input):
""" """
pass pass
def clamp(data, min_val, max_val):
"""Clamp all elements in :attr:`input` into the range [min_val, max_val]
and return a resulting tensor.
Parameters
----------
data : Tensor
Input tensor
min_val : Scalar
Min value.
max_val : Scalar
Max value.
Returns
-------
Tensor
The result.
"""
pass
############################################################################### ###############################################################################
# Tensor functions used *only* on index tensor # Tensor functions used *only* on index tensor
# ---------------- # ----------------
......
...@@ -176,10 +176,17 @@ class GSpMM(mx.autograd.Function): ...@@ -176,10 +176,17 @@ class GSpMM(mx.autograd.Function):
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
func = GSpMM(gidx, op, reduce_op) func = GSpMM(gidx, op, reduce_op)
ctx = to_backend_ctx(gidx.ctx) ctx = to_backend_ctx(gidx.ctx)
# XXX(minjie): There is a bug in MXNet's autograd system when one of the inputs
# does not require gradient. Although it still invokes the backward function,
# it does not set the gradient value to the correct buffer, resulting all the
# input gradients to be zero. Fix this by enforcing all the inputs to require
# gradients.
if lhs_data is None: if lhs_data is None:
lhs_data = nd.zeros((1,), ctx=ctx) lhs_data = nd.zeros((1,), ctx=ctx)
lhs_data.attach_grad()
if rhs_data is None: if rhs_data is None:
rhs_data = nd.zeros((1,), ctx=ctx) rhs_data = nd.zeros((1,), ctx=ctx)
rhs_data.attach_grad()
return func(lhs_data, rhs_data) return func(lhs_data, rhs_data)
......
...@@ -304,35 +304,6 @@ def pack_padded_tensor(input, lengths): ...@@ -304,35 +304,6 @@ def pack_padded_tensor(input, lengths):
index = nd.array(index, ctx=ctx) index = nd.array(index, ctx=ctx)
return gather_row(input.reshape(batch_size * max_len, -1), index) return gather_row(input.reshape(batch_size * max_len, -1), index)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
# TODO: support other dimensions
assert dim == 0, 'MXNet only supports segment sum on first dimension'
# Use SPMV to simulate segment sum
ctx = input.context
n_inputs = input.shape[0]
input_shape_suffix = input.shape[1:]
input = input.reshape(n_inputs, -1)
n_range = nd.arange(n_inputs, dtype='int64').as_in_context(input.context)
w_nnz = nd.ones(n_inputs).as_in_context(input.context)
w_nid = nd.stack(seg_id, n_range, axis=0)
w = nd.sparse.csr_matrix((w_nnz, (seg_id, n_range)), (n_segs, n_inputs))
w = w.as_in_context(input.context)
y = nd.dot(w, input)
y = nd.reshape(y, (n_segs,) + input_shape_suffix)
return y
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
# TODO: support other dimensions
assert dim == 0, 'MXNet only supports segment mean on first dimension'
n_ones = nd.ones_like(seg_id).astype(input.dtype)
w = unsorted_1d_segment_sum(n_ones, seg_id, n_segs, 0)
w = nd.clip(w, a_min=1, a_max=np.inf)
y = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
y = y / w.reshape((-1,) + (1,) * (y.ndim - 1))
return y
def boolean_mask(input, mask): def boolean_mask(input, mask):
return mx.contrib.nd.boolean_mask(input, mask) return mx.contrib.nd.boolean_mask(input, mask)
...@@ -348,6 +319,9 @@ def logical_and(input1, input2): ...@@ -348,6 +319,9 @@ def logical_and(input1, input2):
def clone(input): def clone(input):
return input.copy() return input.copy()
def clamp(data, min_val, max_val):
return nd.clip(data, min_val, max_val)
def unique(input): def unique(input):
# TODO: fallback to numpy is unfortunate # TODO: fallback to numpy is unfortunate
tmp = input.asnumpy() tmp = input.asnumpy()
......
...@@ -106,7 +106,6 @@ class GSpMM(th.autograd.Function): ...@@ -106,7 +106,6 @@ class GSpMM(th.autograd.Function):
else: # max/min else: # max/min
dY = th.zeros((Y.shape[0],) + dZ.shape[1:], dY = th.zeros((Y.shape[0],) + dZ.shape[1:],
dtype=Y.dtype, device=Y.device) dtype=Y.dtype, device=Y.device)
print(X.shape, dZ.shape)
if op in ['mul', 'div']: if op in ['mul', 'div']:
grad = _expand(X, dZ.shape[1:]).gather( grad = _expand(X, dZ.shape[1:]).gather(
0, argX.long()) * dZ 0, argX.long()) * dZ
......
...@@ -245,19 +245,6 @@ def pack_padded_tensor(input, lengths): ...@@ -245,19 +245,6 @@ def pack_padded_tensor(input, lengths):
index = th.tensor(index).to(device) index = th.tensor(index).to(device)
return gather_row(input.view(batch_size * max_len, -1), index) return gather_row(input.view(batch_size * max_len, -1), index)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
y = th.zeros(n_segs, *input.shape[1:]).to(input)
seg_id = seg_id.view((-1,) + (1,) * (input.dim() - 1)).expand_as(input)
y = y.scatter_add_(dim, seg_id, input)
return y
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
w = unsorted_1d_segment_sum(th.ones_like(seg_id), seg_id, n_segs, 0).to(input)
w = w.clamp(min=1) # remove 0 entries
y = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
y = y / w.view((-1,) + (1,) * (y.dim() - 1))
return y
def boolean_mask(input, mask): def boolean_mask(input, mask):
if 'bool' not in str(mask.dtype): if 'bool' not in str(mask.dtype):
mask = th.tensor(mask, dtype=th.bool) mask = th.tensor(mask, dtype=th.bool)
...@@ -275,6 +262,9 @@ def logical_and(input1, input2): ...@@ -275,6 +262,9 @@ def logical_and(input1, input2):
def clone(input): def clone(input):
return input.clone() return input.clone()
def clamp(data, min_val, max_val):
return th.clamp(data, min_val, max_val)
def unique(input): def unique(input):
if input.dtype == th.bool: if input.dtype == th.bool:
input = input.type(th.int8) input = input.type(th.int8)
......
...@@ -283,7 +283,11 @@ def narrow_row(x, start, stop): ...@@ -283,7 +283,11 @@ def narrow_row(x, start, stop):
def scatter_row(data, row_index, value): def scatter_row(data, row_index, value):
row_index = tf.expand_dims(row_index, 1) row_index = tf.expand_dims(row_index, 1)
return tf.tensor_scatter_nd_update(data, row_index, value) # XXX(minjie): Normally, the copy_to here is unnecessary. However, TF has this
# notorious legacy issue that int32 type data is always on CPU, which will
# crash the program since DGL requires feature data to be on the same device
# as graph structure.
return copy_to(tf.tensor_scatter_nd_update(data, row_index, value), data.device)
def index_add_inplace(data, row_idx, value): def index_add_inplace(data, row_idx, value):
...@@ -366,18 +370,6 @@ def pack_padded_tensor(input, lengths): ...@@ -366,18 +370,6 @@ def pack_padded_tensor(input, lengths):
return tf.concat(out_list, axis=0) return tf.concat(out_list, axis=0)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
assert dim == 0 # Why we need dim for 1d?
return tf.math.unsorted_segment_sum(input, seg_id, n_segs)
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
assert dim == 0 # Why we need dim for 1d?
return tf.math.unsorted_segment_mean(input, seg_id, n_segs)
# TODO: TF has unsorted_segment_max, which can accelerate _max_on on batched graph
def boolean_mask(input, mask): def boolean_mask(input, mask):
return tf.boolean_mask(input, mask) return tf.boolean_mask(input, mask)
...@@ -396,6 +388,9 @@ def clone(input): ...@@ -396,6 +388,9 @@ def clone(input):
# TF tensor is always immutable so returning the input is safe. # TF tensor is always immutable so returning the input is safe.
return input return input
def clamp(data, min_val, max_val):
return tf.clip_by_value(data, min_val, max_val)
def unique(input): def unique(input):
return tf.unique(input).y return tf.unique(input).y
......
"""Implementation for core graph computation."""
# pylint: disable=not-callable
from .base import DGLError, is_all, NID, EID, ALL
from . import backend as F
from . import function as fn
from .frame import Frame
from .udf import NodeBatch, EdgeBatch
from . import ops
def is_builtin(func):
    """Check whether ``func`` is one of DGL's built-in functions.

    Parameters
    ----------
    func : callable
        The function to inspect.

    Returns
    -------
    bool
        True if ``func`` is a DGL builtin function object, False otherwise.
    """
    return isinstance(func, fn.BuiltinFunction)
def invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None):
    """Invoke user-defined node function on the given nodes.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    nid : Tensor
        The IDs of the nodes to invoke UDF on. Can also be the ALL sentinel,
        meaning every node of the given type.
    ntype : str
        Node type.
    func : callable
        The user-defined function.
    ndata : dict[str, Tensor], optional
        If provided, apply the UDF on this ndata instead of the ndata of the graph.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    ntid = graph.get_ntype_id(ntype)
    if ndata is None:
        if is_all(nid):
            # Use the whole node frame; materialize concrete node IDs for the batch.
            ndata = graph._node_frames[ntid]
            nid = graph.nodes(ntype=ntype)
        else:
            ndata = graph._node_frames[ntid].subframe(nid)
    # When the graph is an extracted subgraph, expose the original IDs to the UDF.
    nbatch = NodeBatch(graph, nid if orig_nid is None else orig_nid, ntype, ndata)
    return func(nbatch)
def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None):
    """Run a user-defined edge function over the specified edges.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    eid : Tensor
        The IDs of the edges to invoke UDF on. Can also be the ALL sentinel,
        meaning every edge of the given type.
    etype : (str, str, str)
        Edge type.
    func : callable
        The user-defined function.
    orig_eid : Tensor, optional
        Original edge IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    etid = graph.get_etype_id(etype)
    stid, dtid = graph._graph.metagraph.find_edge(etid)
    if is_all(eid):
        # Every edge: fetch all endpoints and use the full edge frame.
        src, dst, eid = graph.edges(form='all')
        edata = graph._edge_frames[etid]
    else:
        src, dst = graph.find_edges(eid)
        edata = graph._edge_frames[etid].subframe(eid)
    srcdata = graph._node_frames[stid].subframe(src)
    dstdata = graph._node_frames[dtid].subframe(dst)
    # Expose original IDs to the UDF when working on an extracted subgraph.
    batch_eid = eid if orig_eid is None else orig_eid
    ebatch = EdgeBatch(graph, batch_eid, etype, srcdata, edata, dstdata)
    return func(ebatch)
def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
    """Invoke user-defined reduce function on all the nodes in the graph.

    It analyzes the graph, groups nodes by their degrees and applies the UDF on each
    group -- a strategy called *degree-bucketing*.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : callable
        The user-defined function.
    msgdata : dict[str, Tensor]
        Message data.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    degs = graph.in_degrees()
    nodes = graph.dstnodes()
    if orig_nid is None:
        orig_nid = nodes
    ntype = graph.dsttypes[0]
    ntid = graph.get_ntype_id_from_dst(ntype)
    dstdata = graph._node_frames[ntid]
    # Wrap the raw message dict in a Frame so it supports subframe() below.
    msgdata = Frame(msgdata)
    # degree bucketing: group destination nodes by in-degree so each bucket's
    # mailbox can be a dense tensor
    unique_degs, bucketor = _bucketing(degs)
    bkt_rsts = []
    bkt_nodes = []
    for deg, node_bkt, orig_nid_bkt in zip(unique_degs, bucketor(nodes), bucketor(orig_nid)):
        if deg == 0:
            # skip reduce function for zero-degree nodes
            continue
        bkt_nodes.append(node_bkt)
        ndata_bkt = dstdata.subframe(node_bkt)
        # NOTE(review): assumes in_edges returns edges grouped per destination
        # node so the reshape below is valid; only the count is asserted here.
        eid_bkt = graph.in_edges(node_bkt, form='eid')
        assert len(eid_bkt) == deg * len(node_bkt)
        msgdata_bkt = msgdata.subframe(eid_bkt)
        # reshape all msg tensors to (num_nodes_bkt, degree, feat_size)
        maildata = {}
        for k, msg in msgdata_bkt.items():
            newshape = (len(node_bkt), deg) + F.shape(msg)[1:]
            maildata[k] = F.reshape(msg, newshape)
        # invoke udf; pass original IDs so UDFs on subgraphs see parent-graph IDs
        nbatch = NodeBatch(graph, orig_nid_bkt, ntype, ndata_bkt, msgs=maildata)
        bkt_rsts.append(func(nbatch))
    # prepare a result frame covering all nodes (zero-degree ones get defaults);
    # carry over the initializers so uninitialized rows are filled consistently
    retf = Frame(num_rows=len(nodes))
    retf._initializers = dstdata._initializers
    retf._default_initializer = dstdata._default_initializer
    # merge bucket results and write to the result frame
    if len(bkt_rsts) != 0:  # if all the nodes have zero degree, no need to merge results.
        merged_rst = {}
        for k in bkt_rsts[0].keys():
            merged_rst[k] = F.cat([rst[k] for rst in bkt_rsts], dim=0)
        merged_nodes = F.cat(bkt_nodes, dim=0)
        retf.update_row(merged_nodes, merged_rst)
    return retf
def _bucketing(val):
    """Internal function to create groups on the values.

    Parameters
    ----------
    val : Tensor
        Value tensor.

    Returns
    -------
    unique_val : Tensor
        Unique values.
    bucketor : callable[Tensor -> list[Tensor]]
        A bucketing function that splits the given tensor data as the same
        way of how the :attr:`val` tensor is grouped.
    """
    sorted_val, perm = F.sort_1d(val)
    unique_val = F.asnumpy(F.unique(sorted_val))
    # For each distinct value, remember which positions of the original tensor
    # fall into that group.
    group_idx = [F.gather_row(perm, F.nonzero_1d(F.equal(sorted_val, v)))
                 for v in unique_val]

    def bucketor(data):
        # Split any aligned tensor using the same grouping.
        return [F.gather_row(data, idx) for idx in group_idx]

    return unique_val, bucketor
def invoke_gsddmm(graph, func):
    """Invoke g-SDDMM computation on the graph.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : dgl.function.BaseMessageFunction
        Built-in message function.

    Returns
    -------
    dict[str, Tensor]
        Results from the g-SDDMM computation.
    """
    # Feature dicts indexed by the function's target code (src/dst/edge).
    alldata = [graph.srcdata, graph.dstdata, graph.edata]
    # The op lookup is independent of the branch; hoist it out.
    op = getattr(ops, func.name)
    if isinstance(func, fn.BinaryMessageFunction):
        x = alldata[func.lhs][func.lhs_field]
        y = alldata[func.rhs][func.rhs_field]
        z = op(graph, x, y)
    else:
        x = alldata[func.target][func.in_field]
        z = op(graph, x)
    return {func.out_field : z}
def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None):
    """Invoke g-SPMM computation on the graph.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    mfunc : dgl.function.BaseMessageFunction
        Built-in message function.
    rfunc : dgl.function.BaseReduceFunction
        Built-in reduce function.
    srcdata : dict[str, Tensor], optional
        Source node feature data. If not provided, it uses ``graph.srcdata``.
    dstdata : dict[str, Tensor], optional
        Destination node feature data. If not provided, it uses ``graph.dstdata``.
    edata : dict[str, Tensor], optional
        Edge feature data. If not provided, it uses ``graph.edata``.

    Returns
    -------
    dict[str, Tensor]
        Results from the g-SPMM computation.

    Raises
    ------
    DGLError
        If the message function's output field does not match the reduce
        function's message field.
    """
    # sanity check
    if mfunc.out_field != rfunc.msg_field:
        raise DGLError('Invalid message ({}) and reduce ({}) function pairs.'
                       ' The output field of the message function must be equal to the'
                       ' message field of the reduce function.'.format(mfunc, rfunc))
    if edata is None:
        edata = graph.edata
    if srcdata is None:
        srcdata = graph.srcdata
    if dstdata is None:
        dstdata = graph.dstdata
    # Feature dicts indexed by the function's target code (src/dst/edge).
    alldata = [srcdata, dstdata, edata]
    # The fused op lookup is independent of the branch; hoist it out.
    op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
    if isinstance(mfunc, fn.BinaryMessageFunction):
        x = alldata[mfunc.lhs][mfunc.lhs_field]
        y = alldata[mfunc.rhs][mfunc.rhs_field]
        z = op(graph, x, y)
    else:
        x = alldata[mfunc.target][mfunc.in_field]
        z = op(graph, x)
    return {rfunc.out_field : z}
def message_passing(g, mfunc, rfunc, afunc):
    """Invoke message passing computation on the whole graph.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    mfunc : callable or dgl.function.BuiltinFunction
        Message function.
    rfunc : callable or dgl.function.BuiltinFunction
        Reduce function.
    afunc : callable or dgl.function.BuiltinFunction
        Apply function.

    Returns
    -------
    dict[str, Tensor]
        Results from the message passing computation.
    """
    if g.number_of_edges() == 0:
        # No edges, hence no messages; nothing is computed.
        ndata = {}
    else:
        # Probe for a fused message+reduce kernel when both are builtins.
        fused_op = None
        if is_builtin(mfunc) and is_builtin(rfunc):
            fused_op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name), None)
        if fused_op is not None:
            # invoke fused message passing
            ndata = invoke_gspmm(g, mfunc, rfunc)
        else:
            # invoke message passing in two separate steps
            # message phase
            if is_builtin(mfunc):
                msgdata = invoke_gsddmm(g, mfunc)
            else:
                msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc,
                                          orig_eid=g.edata.get(EID, None))
            # reduce phase
            if is_builtin(rfunc):
                msg = rfunc.msg_field
                ndata = invoke_gspmm(g, fn.copy_e(msg, msg), rfunc, edata=msgdata)
            else:
                ndata = invoke_udf_reduce(g, rfunc, msgdata,
                                          orig_nid=g.dstdata.get(NID, None))
    # apply phase
    if afunc is not None:
        # include original node features without overwriting reduced results
        for key, val in g.dstdata.items():
            ndata.setdefault(key, val)
        ndata = invoke_node_udf(g, ALL, g.dsttypes[0], afunc, ndata=ndata,
                                orig_nid=g.dstdata.get(NID, None))
    return ndata
"""For HeteroGraph Serialization""" """For HeteroGraph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
from ..frame import Frame, FrameRef from ..frame import Frame
from .._ffi.object import ObjectBase, register_object from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
...@@ -51,10 +51,10 @@ class HeteroGraphData(ObjectBase): ...@@ -51,10 +51,10 @@ class HeteroGraphData(ObjectBase):
eframes = [] eframes = []
for ntid, ntensor in enumerate(ntensor_list): for ntid, ntensor in enumerate(ntensor_list):
ndict = {ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i+1]) for i in range(0, len(ntensor), 2)} ndict = {ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i+1]) for i in range(0, len(ntensor), 2)}
nframes.append(FrameRef(Frame(ndict, num_rows=gidx.number_of_nodes(ntid)))) nframes.append(Frame(ndict, num_rows=gidx.number_of_nodes(ntid)))
for etid, etensor in enumerate(etensor_list): for etid, etensor in enumerate(etensor_list):
edict = {etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i+1]) for i in range(0, len(etensor), 2)} edict = {etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i+1]) for i in range(0, len(etensor), 2)}
eframes.append(FrameRef(Frame(edict, num_rows=gidx.number_of_edges(etid)))) eframes.append(Frame(edict, num_rows=gidx.number_of_edges(etid)))
return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes) return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes)
...@@ -3,7 +3,7 @@ from collections import namedtuple ...@@ -3,7 +3,7 @@ from collections import namedtuple
from .rpc import Request, Response, send_requests_to_machine, recv_responses from .rpc import Request, Response, send_requests_to_machine, recv_responses
from ..sampling import sample_neighbors as local_sample_neighbors from ..sampling import sample_neighbors as local_sample_neighbors
from ..transform import in_subgraph as local_in_subgraph from ..subgraph import in_subgraph as local_in_subgraph
from .rpc import register_service from .rpc import register_service
from ..convert import graph from ..convert import graph
from ..base import NID, EID from ..base import NID, EID
......
...@@ -4,12 +4,9 @@ from __future__ import absolute_import ...@@ -4,12 +4,9 @@ from __future__ import absolute_import
from collections import namedtuple from collections import namedtuple
from collections.abc import MutableMapping from collections.abc import MutableMapping
import numpy as np
from . import backend as F from . import backend as F
from .base import DGLError, dgl_warning from .base import DGLError, dgl_warning
from .init import zero_initializer from .init import zero_initializer
from . import utils
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
"""The column scheme. """The column scheme.
...@@ -53,108 +50,125 @@ def infer_scheme(tensor): ...@@ -53,108 +50,125 @@ def infer_scheme(tensor):
class Column(object): class Column(object):
"""A column is a compact store of features of multiple nodes/edges. """A column is a compact store of features of multiple nodes/edges.
Currently, we use one dense tensor to batch all the feature tensors It batches all the feature tensors together along the first dimension
together (along the first dimension). as one dense tensor.
The column can optionally have an index tensor I.
In this case, the i^th feature is stored in ``storage[index[i]]``.
The column class implements a Copy-On-Read semantics -- the index
select operation happens upon the first read of the feature data.
This is useful when one extracts a subset of the feature data
but wishes the actual index select happens on-demand.
Parameters Parameters
---------- ----------
data : Tensor storage : Tensor
The initial data of the column. The feature data storage.
scheme : Scheme, optional scheme : Scheme, optional
The scheme of the column. Will be inferred if not provided. The scheme of the column. Will be inferred if not provided.
index : Tensor, optional
The row index to the feature data storage. None means an
identity mapping.
Attributes Attributes
---------- ----------
storage : Tensor
The storage tensor. The storage tensor may not be the actual data
tensor of this column when the index tensor is not None.
This typically happens when the column is extracted from another
column using the `subcolumn` method.
data : Tensor data : Tensor
The data of the column. The actual data tensor of this column.
scheme : Scheme scheme : Scheme
The scheme of the column. The scheme of the column.
index : Tensor
Index tensor
""" """
def __init__(self, data, scheme=None): def __init__(self, storage, scheme=None, index=None):
self.data = data self.storage = storage
self.scheme = scheme if scheme else infer_scheme(data) self.scheme = scheme if scheme else infer_scheme(storage)
self.index = index
def __len__(self): def __len__(self):
"""The column length.""" """The number of features (number of rows) in this column."""
return F.shape(self.data)[0] if self.index is None:
return F.shape(self.storage)[0]
else:
return len(self.index)
@property @property
def shape(self): def shape(self):
"""Return the scheme shape (feature shape) of this column.""" """Return the scheme shape (feature shape) of this column."""
return self.scheme.shape return self.scheme.shape
def __getitem__(self, idx): @property
"""Return the feature data given the index. def data(self):
"""Return the feature data. Perform index selecting if needed."""
if self.index is not None:
self.storage = F.gather_row(self.storage, self.index)
self.index = None
return self.storage
@data.setter
def data(self, val):
"""Update the column data."""
self.index = None
self.storage = val
def __getitem__(self, rowids):
"""Return the feature data given the rowids.
The operation triggers index selection.
Parameters Parameters
---------- ----------
idx : utils.Index rowids : Tensor
The index. Row ID tensor.
Returns Returns
------- -------
Tensor Tensor
The feature data The feature data
""" """
        if idx.slice_data() is not None: return F.gather_row(self.data, rowids)
slc = idx.slice_data()
return F.narrow_row(self.data, slc.start, slc.stop)
else:
user_idx = idx.tousertensor(F.context(self.data))
return F.gather_row(self.data, user_idx)
def __setitem__(self, idx, feats): def __setitem__(self, rowids, feats):
"""Update the feature data given the index. """Update the feature data given the index.
The update is performed out-placely so it can be used in autograd mode. The update is performed out-placely so it can be used in autograd mode.
For inplace write, please use ``update``. The operation triggers index selection.
Parameters Parameters
---------- ----------
idx : utils.Index or slice rowids : Tensor
The index. Row IDs.
feats : Tensor feats : Tensor
The new features. New features.
""" """
        self.update(idx, feats, inplace=False) self.update(rowids, feats)
def update(self, idx, feats, inplace): def update(self, rowids, feats):
"""Update the feature data given the index. """Update the feature data given the index.
Parameters Parameters
---------- ----------
idx : utils.Index rowids : Tensor
The index. Row IDs.
feats : Tensor feats : Tensor
The new features. New features.
inplace : bool
If true, use inplace write.
""" """
feat_scheme = infer_scheme(feats) feat_scheme = infer_scheme(feats)
if feat_scheme != self.scheme: if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s." raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme)) % (feat_scheme, self.scheme))
self.data = F.scatter_row(self.data, rowids, feats)
if inplace:
idx = idx.tousertensor(F.context(self.data))
F.scatter_row_inplace(self.data, idx, feats)
elif idx.slice_data() is not None:
# for contiguous indices narrow+concat is usually faster than scatter row
slc = idx.slice_data()
parts = [feats]
if slc.start > 0:
parts.insert(0, F.narrow_row(self.data, 0, slc.start))
if slc.stop < len(self):
parts.append(F.narrow_row(self.data, slc.stop, len(self)))
self.data = F.cat(parts, dim=0)
else:
idx = idx.tousertensor(F.context(self.data))
self.data = F.scatter_row(self.data, idx, feats)
def extend(self, feats, feat_scheme=None): def extend(self, feats, feat_scheme=None):
"""Extend the feature data. """Extend the feature data.
Parameters The operation triggers index selection.
Parameters
---------- ----------
feats : Tensor feats : Tensor
The new features. The new features.
...@@ -168,18 +182,47 @@ class Column(object): ...@@ -168,18 +182,47 @@ class Column(object):
raise DGLError("Cannot update column of scheme %s using feature of scheme %s." raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme)) % (feat_scheme, self.scheme))
feats = F.copy_to(feats, F.context(self.data))
self.data = F.cat([self.data, feats], dim=0) self.data = F.cat([self.data, feats], dim=0)
def clone(self): def clone(self):
"""Return a deepcopy of this column.""" """Return a shallow copy of this column."""
return Column(self.storage, self.scheme, self.index)
def deepclone(self):
"""Return a deepcopy of this column.
The operation triggers index selection.
"""
return Column(F.clone(self.data), self.scheme) return Column(F.clone(self.data), self.scheme)
def subcolumn(self, rowids):
"""Return a subcolumn.
The resulting column will share the same storage as this column so this operation
is quite efficient. If the current column is also a sub-column (i.e., the
index tensor is not None), it slices the index tensor with the given
rowids as the index tensor of the resulting column.
Parameters
----------
rowids : Tensor
Row IDs.
Returns
-------
Column
Sub-column
"""
if self.index is None:
return Column(self.storage, self.scheme, rowids)
else:
return Column(self.storage, self.scheme, F.gather_row(self.index, rowids))
@staticmethod @staticmethod
def create(data): def create(data):
"""Create a new column using the given data.""" """Create a new column using the given data."""
if isinstance(data, Column): if isinstance(data, Column):
return Column(data.data, data.scheme) return data.clone()
else: else:
return Column(data) return Column(data)
...@@ -189,7 +232,7 @@ class Column(object): ...@@ -189,7 +232,7 @@ class Column(object):
class Frame(MutableMapping): class Frame(MutableMapping):
"""The columnar storage for node/edge features. """The columnar storage for node/edge features.
The frame is a dictionary from feature fields to feature columns. The frame is a dictionary from feature names to feature columns.
All columns should have the same number of rows (i.e. the same first dimension). All columns should have the same number of rows (i.e. the same first dimension).
Parameters Parameters
...@@ -197,36 +240,33 @@ class Frame(MutableMapping): ...@@ -197,36 +240,33 @@ class Frame(MutableMapping):
data : dict-like, optional data : dict-like, optional
The frame data in dictionary. If the provided data is another frame, The frame data in dictionary. If the provided data is another frame,
this frame will NOT share columns with the given frame. So any out-place this frame will NOT share columns with the given frame. So any out-place
update on one will not reflect to the other. The inplace update will update on one will not reflect to the other.
be seen by both. This follows the semantic of python's container. num_rows : int, optional
num_rows : int, optional [default=0]
The number of rows in this frame. If ``data`` is provided and is not empty, The number of rows in this frame. If ``data`` is provided and is not empty,
``num_rows`` will be ignored and inferred from the given data. ``num_rows`` will be ignored and inferred from the given data.
""" """
def __init__(self, data=None, num_rows=0): def __init__(self, data=None, num_rows=None):
if data is None: if data is None:
self._columns = dict() self._columns = dict()
self._num_rows = num_rows self._num_rows = 0 if num_rows is None else num_rows
else: else:
assert not isinstance(data, Frame) # sanity check for code refactor
# Note that we always create a new column for the given data. # Note that we always create a new column for the given data.
# This avoids two frames accidentally sharing the same column. # This avoids two frames accidentally sharing the same column.
self._columns = {k : Column.create(v) for k, v in data.items()} self._columns = {k : Column.create(v) for k, v in data.items()}
if isinstance(data, (Frame, FrameRef)): self._num_rows = num_rows
self._num_rows = data.num_rows # infer num_rows & sanity check
elif len(self._columns) != 0:
self._num_rows = len(next(iter(self._columns.values())))
else:
self._num_rows = num_rows
# sanity check
for name, col in self._columns.items(): for name, col in self._columns.items():
if len(col) != self._num_rows: if self._num_rows is None:
self._num_rows = len(col)
elif len(col) != self._num_rows:
raise DGLError('Expected all columns to have same # rows (%d), ' raise DGLError('Expected all columns to have same # rows (%d), '
'got %d on %r.' % (self._num_rows, len(col), name)) 'got %d on %r.' % (self._num_rows, len(col), name))
# Initializer for empty values. Initializer is a callable. # Initializer for empty values. Initializer is a callable.
# If is none, then a warning will be raised # If is none, then a warning will be raised
# in the first call and zero initializer will be used later. # in the first call and zero initializer will be used later.
self._initializers = {} # per-column initializers self._initializers = {} # per-column initializers
self._remote_init_builder = None
self._default_initializer = None self._default_initializer = None
def _set_zero_default_initializer(self): def _set_zero_default_initializer(self):
...@@ -266,39 +306,6 @@ class Frame(MutableMapping): ...@@ -266,39 +306,6 @@ class Frame(MutableMapping):
else: else:
self._initializers[column] = initializer self._initializers[column] = initializer
def set_remote_init_builder(self, builder):
"""Set an initializer builder to create a remote initializer for a new column to a frame.
NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.
The builder is a callable that returns an initializer. The returned initializer
is also a callable that returns a tensor given a local tensor and tensor name.
Parameters
----------
builder : callable
The builder to construct a remote initializer.
"""
self._remote_init_builder = builder
def get_remote_initializer(self, name):
"""Get a remote initializer.
NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.
Parameters
----------
name : string
The column name.
"""
if self._remote_init_builder is None:
return None
if self.get_initializer(name) is None:
self._set_zero_default_initializer()
initializer = self.get_initializer(name)
return self._remote_init_builder(initializer, name)
@property @property
def schemes(self): def schemes(self):
"""Return a dictionary of column name to column schemes.""" """Return a dictionary of column name to column schemes."""
...@@ -328,10 +335,10 @@ class Frame(MutableMapping): ...@@ -328,10 +335,10 @@ class Frame(MutableMapping):
Returns Returns
------- -------
Column Tensor
The column. Column data.
""" """
return self._columns[name] return self._columns[name].data
def __setitem__(self, name, data): def __setitem__(self, name, data):
"""Update the whole column. """Update the whole column.
...@@ -373,17 +380,11 @@ class Frame(MutableMapping): ...@@ -373,17 +380,11 @@ class Frame(MutableMapping):
dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name) dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name)
return return
# If the data is backed by a remote server, we need to move data if self.get_initializer(name) is None:
# to the remote server. self._set_zero_default_initializer()
initializer = self.get_remote_initializer(name) initializer = self.get_initializer(name)
if initializer is not None: init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx) ctx, slice(0, self.num_rows))
else:
if self.get_initializer(name) is None:
self._set_zero_default_initializer()
initializer = self.get_initializer(name)
init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(0, self.num_rows))
self._columns[name] = Column(init_data, scheme) self._columns[name] = Column(init_data, scheme)
def add_rows(self, num_rows): def add_rows(self, num_rows):
...@@ -420,22 +421,38 @@ class Frame(MutableMapping): ...@@ -420,22 +421,38 @@ class Frame(MutableMapping):
data : Column or data convertible to Column data : Column or data convertible to Column
The column data. The column data.
""" """
# If the data is backed by a remote server, we need to move data
# to the remote server.
initializer = self.get_remote_initializer(name)
if initializer is not None:
new_data = initializer(F.shape(data), F.dtype(data), F.context(data))
new_data[:] = data
data = new_data
col = Column.create(data) col = Column.create(data)
if len(col) != self.num_rows: if len(col) != self.num_rows:
raise DGLError('Expected data to have %d rows, got %d.' % raise DGLError('Expected data to have %d rows, got %d.' %
(self.num_rows, len(col))) (self.num_rows, len(col)))
self._columns[name] = col self._columns[name] = col
def update_row(self, rowids, data):
"""Update the feature data of the given rows.
If the data contains new keys (new columns) that do not exist in
this frame, add a new column.
The ``rowids`` shall not contain duplicates. Otherwise, the behavior
is undefined.
Parameters
----------
rowids : Tensor
Row Ids.
data : dict[str, Tensor]
Row data.
"""
for key, val in data.items():
if key not in self:
scheme = infer_scheme(val)
ctx = F.context(val)
self.add_column(key, scheme, ctx)
for key, val in data.items():
self._columns[key].update(rowids, val)
def _append(self, other): def _append(self, other):
assert self._remote_init_builder is None, \ """Append ``other`` frame to ``self`` frame."""
"We don't support append if data in the frame is mapped from a remote server."
# NOTE: `other` can be empty. # NOTE: `other` can be empty.
if self.num_rows == 0: if self.num_rows == 0:
# if no rows in current frame; append is equivalent to # if no rows in current frame; append is equivalent to
...@@ -443,7 +460,7 @@ class Frame(MutableMapping): ...@@ -443,7 +460,7 @@ class Frame(MutableMapping):
self._columns = {key: Column.create(data) for key, data in other.items()} self._columns = {key: Column.create(data) for key, data in other.items()}
else: else:
# pad columns that are not provided in the other frame with initial values # pad columns that are not provided in the other frame with initial values
for key, col in self.items(): for key, col in self._columns.items():
if key in other: if key in other:
continue continue
scheme = col.scheme scheme = col.scheme
...@@ -456,7 +473,7 @@ class Frame(MutableMapping): ...@@ -456,7 +473,7 @@ class Frame(MutableMapping):
slice(self._num_rows, self._num_rows + other.num_rows)) slice(self._num_rows, self._num_rows + other.num_rows))
other[key] = new_data other[key] = new_data
# append other to self # append other to self
for key, col in other.items(): for key, col in other._columns.items():
if key not in self._columns: if key not in self._columns:
# the column does not exist; init a new column # the column does not exist; init a new column
self.add_column(key, col.scheme, F.context(col.data)) self.add_column(key, col.scheme, F.context(col.data))
...@@ -517,7 +534,6 @@ class Frame(MutableMapping): ...@@ -517,7 +534,6 @@ class Frame(MutableMapping):
""" """
newframe = Frame(self._columns, self._num_rows) newframe = Frame(self._columns, self._num_rows)
newframe._initializers = self._initializers newframe._initializers = self._initializers
newframe._remote_init_builder = self._remote_init_builder
newframe._default_initializer = self._default_initializer newframe._default_initializer = self._default_initializer
return newframe return newframe
...@@ -534,485 +550,33 @@ class Frame(MutableMapping): ...@@ -534,485 +550,33 @@ class Frame(MutableMapping):
Frame Frame
A deep-cloned frame. A deep-cloned frame.
""" """
newframe = Frame({k : col.clone() for k, col in self._columns.items()}, self._num_rows) newframe = Frame({k : col.deepclone() for k, col in self._columns.items()},
self._num_rows)
newframe._initializers = self._initializers newframe._initializers = self._initializers
newframe._remote_init_builder = self._remote_init_builder
newframe._default_initializer = self._default_initializer newframe._default_initializer = self._default_initializer
return newframe return newframe
class FrameRef(MutableMapping): def subframe(self, rowids):
"""Reference object to a frame on a subset of rows. """Return a new frame whose columns are subcolumns of this frame.
Parameters The given row IDs should be within range [0, self.num_rows), and allow
---------- duplicate IDs.
frame : Frame, optional
The underlying frame. If not given, the reference will point to a
new empty frame.
index : utils.Index, optional
The rows that are referenced in the underlying frame. If not given,
the whole frame is referenced. The index should be distinct (no
duplication is allowed).
"""
def __init__(self, frame=None, index=None):
self._frame = frame if frame is not None else Frame()
# TODO(minjie): check no duplication
assert index is None or isinstance(index, utils.Index)
if index is None:
self._index = utils.toindex(slice(0, self._frame.num_rows))
else:
self._index = index
@property
def schemes(self):
"""Return the frame schemes.
Returns
-------
dict of str to Scheme
The frame schemes.
"""
return self._frame.schemes
@property
def num_columns(self):
"""Return the number of columns in the referred frame."""
return self._frame.num_columns
@property
def num_rows(self):
"""Return the number of rows referred."""
return len(self._index)
def set_initializer(self, initializer, column=None):
"""Set the initializer for empty values.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
column : str, optional
The column name
"""
self._frame.set_initializer(initializer, column=column)
def set_remote_init_builder(self, builder):
"""Set an initializer builder to create a remote initializer for a new column to a frame.
NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.
The builder is a callable that returns an initializer. The returned initializer
is also a callable that returns a tensor given a local tensor and tensor name.
Parameters Parameters
---------- ----------
builder : callable rowids : Tensor
The builder to construct a remote initializer. Row IDs
"""
self._frame.set_remote_init_builder(builder)
def get_initializer(self, column=None):
"""Get the initializer for empty values for the given column.
Parameters
----------
column : str
The column
Returns Returns
------- -------
callable Frame
The initializer A new subframe.
"""
return self._frame.get_initializer(column)
def __contains__(self, name):
"""Return whether the column name exists."""
return name in self._frame
def __iter__(self):
"""Return the iterator of the columns."""
return iter(self._frame)
def __len__(self):
"""Return the number of columns."""
return self.num_columns
def keys(self):
"""Return the keys."""
return self._frame.keys()
def values(self):
"""Return the values."""
return self._frame.values()
def __getitem__(self, key):
"""Get data from the frame.
If the provided key is string, the corresponding column data will be returned.
If the provided key is an index or a slice, the corresponding rows will be selected.
The returned rows are saved in a lazy dictionary so only the real selection happens
when the explicit column name is provided.
Examples (using pytorch)
------------------------
>>> # create a frame of two columns and five rows
>>> f = Frame({'c1' : torch.zeros([5, 2]), 'c2' : torch.ones([5, 2])})
>>> fr = FrameRef(f)
>>> # select the row 1 and 2, the returned `rows` is a lazy dictionary.
>>> rows = fr[Index([1, 2])]
>>> rows['c1'] # only select rows for 'c1' column; 'c2' column is not sliced.
Parameters
----------
key : str or utils.Index
The key.
Returns
-------
Tensor or lazy dict or tensors
Depends on whether it is a column selection or row selection.
"""
if not isinstance(key, (str, utils.Index)):
raise DGLError('Argument "key" must be either str or utils.Index type.')
if isinstance(key, str):
return self.select_column(key)
elif key.is_slice(0, self.num_rows):
# shortcut for selecting all the rows
return self
else:
return self.select_rows(key)
def select_column(self, name):
"""Return the column of the given name.
If only part of the rows are referenced, the fetching the whole column will
also slice out the referenced rows.
Parameters
----------
name : str
The column name.
Returns
-------
Tensor
The column data.
"""
col = self._frame[name]
if self.is_span_whole_column():
return col.data
else:
return col[self._index]
def select_rows(self, query):
"""Return the rows given the query.
Parameters
----------
query : utils.Index or slice
The rows to be selected.
Returns
-------
utils.LazyDict
The lazy dictionary from str to the selected data.
"""
rows = self._getrows(query)
return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys())
def __setitem__(self, key, val):
"""Update the data in the frame. The update is done out-of-place.
Parameters
----------
key : str or utils.Index
The key.
val : Tensor or dict of tensors
The value.
See Also
--------
update
"""
self.update_data(key, val, inplace=False)
def update_data(self, key, val, inplace):
"""Update the data in the frame.
If the provided key is string, the corresponding column data will be updated.
The provided value should be one tensor that have the same scheme and length
as the column.
If the provided key is an index, the corresponding rows will be updated. The
value provided should be a dictionary of string to the data of each column.
All updates are performed out-placely to be work with autograd. For inplace
update, use ``update_column`` or ``update_rows``.
Parameters
----------
key : str or utils.Index
The key.
val : Tensor or dict of tensors
The value.
inplace: bool
If True, update will be done in place
"""
if not isinstance(key, (str, utils.Index)):
raise DGLError('Argument "key" must be either str or utils.Index type.')
if isinstance(key, str):
self.update_column(key, val, inplace=inplace)
elif key.is_slice(0, self.num_rows):
# shortcut for updating all the rows
for colname, col in val.items():
self.update_column(colname, col, inplace=inplace)
else:
self.update_rows(key, val, inplace=inplace)
def update_column(self, name, data, inplace):
"""Update the column.
If this frameref spans the whole column of the underlying frame, this is
equivalent to update the column of the frame.
If this frameref only points to part of the rows, then update the column
here will correspond to update part of the column in the frame. Raise error
if the given column name does not exist.
Parameters
----------
name : str
The column name.
data : Tensor
The update data.
inplace : bool
True if the update is performed inplacely.
"""
if self.is_span_whole_column():
if self.num_columns == 0:
# the frame is empty
self._index = utils.toindex(slice(0, len(data)))
self._frame[name] = data
else:
if name not in self._frame:
ctx = F.context(data)
self._frame.add_column(name, infer_scheme(data), ctx)
fcol = self._frame[name]
fcol.update(self._index, data, inplace)
def add_rows(self, num_rows):
"""Add blank rows to the underlying frame.
For existing fields, the rows will be extended according to their
initializers.
Note: only available for FrameRef that spans the whole column. The row
span will extend to new rows. Other FrameRefs referencing the same
frame will not be affected.
Parameters
----------
num_rows : int
Number of rows to add
"""
if not self.is_span_whole_column():
raise RuntimeError('FrameRef not spanning whole column.')
self._frame.add_rows(num_rows)
if self._index.slice_data() is not None:
# the index is a slice
slc = self._index.slice_data()
self._index = utils.toindex(slice(slc.start, slc.stop + num_rows))
else:
selfidxdata = self._index.tousertensor()
newdata = F.arange(self.num_rows, self.num_rows + num_rows)
self._index = utils.toindex(F.cat([selfidxdata, newdata], dim=0))
def update_rows(self, query, data, inplace):
"""Update the rows.
If the provided data has new column, it will be added to the frame.
See Also
--------
``update_column``
Parameters
----------
query : utils.Index or slice
The rows to be updated.
data : dict-like
The row data.
inplace : bool
True if the update is performed inplace.
"""
rows = self._getrows(query)
for key, col in data.items():
if key not in self:
# add new column
tmpref = FrameRef(self._frame, rows)
tmpref.update_column(key, col, inplace)
else:
self._frame[key].update(rows, col, inplace)
def __delitem__(self, key):
"""Delete data in the frame.
If the provided key is a string, the corresponding column will be deleted.
If the provided key is an index object or a slice, the corresponding rows will
be deleted.
Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not reflect on the other. However, deleting columns is real.
Parameters
----------
key : str or utils.Index
The key.
"""
if not isinstance(key, (str, utils.Index)):
raise DGLError('Argument "key" must be either str or utils.Index type.')
if isinstance(key, str):
del self._frame[key]
else:
self.delete_rows(key)
def delete_rows(self, query):
"""Delete rows.
Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not reflect on the other. By contrast, deleting columns is real.
Parameters
----------
query : utils.Index
The rows to be deleted.
"""
query = query.tonumpy()
index = self._index.tonumpy()
self._index = utils.toindex(np.delete(index, query))
def append(self, other):
"""Append another frame into this one.
Parameters
----------
other : dict of str to tensor
The data to be appended.
"""
old_nrows = self._frame.num_rows
self._frame.append(other)
new_nrows = self._frame.num_rows
# update index
if (self._index.slice_data() is not None
and self._index.slice_data().stop == old_nrows):
# Self index is a slice and index.stop is equal to the size of the
# underlying frame. Can still use a slice for the new index.
oldstart = self._index.slice_data().start
self._index = utils.toindex(slice(oldstart, new_nrows))
else:
# convert it to user tensor and concat
selfidxdata = self._index.tousertensor()
newdata = F.arange(old_nrows, new_nrows)
self._index = utils.toindex(F.cat([selfidxdata, newdata], dim=0))
def clear(self):
"""Clear the frame."""
self._frame.clear()
self._index = utils.toindex(slice(0, 0))
def is_contiguous(self):
"""Return whether this refers to a contiguous range of rows."""
# NOTE: this check could have false negatives
return self._index.slice_data() is not None
def is_span_whole_column(self):
"""Return whether this refers to all the rows."""
return self.is_contiguous() and self.num_rows == self._frame.num_rows
def clone(self):
"""Return a new reference to a clone of the underlying frame.
Returns
-------
FrameRef
A cloned frame reference.
See Also
--------
dgl.Frame.clone
"""
return FrameRef(self._frame.clone(), self._index)
def deepclone(self):
"""Return a new reference to a deep clone of the underlying frame.
Returns
-------
FrameRef
A deep-cloned frame reference.
See Also
--------
dgl.Frame.deepclone
"""
return FrameRef(self._frame.deepclone(), self._index)
def _getrows(self, query):
"""Internal function to convert from the local row ids to the row ids of the frame.
Parameters
----------
query : utils.Index
The query index.
Returns
-------
utils.Index
The actual index to the underlying frame.
""" """
return self._index.get_items(query) subcols = {k : col.subcolumn(rowids) for k, col in self._columns.items()}
subf = Frame(subcols, len(rowids))
def frame_like(other, num_rows=None): subf._initializers = self._initializers
"""Create an empty frame that has the same initializer as the given one. subf._default_initializer = self._default_initializer
return subf
Parameters
----------
other : Frame
The given frame.
num_rows : int
The number of rows of the new one. If None, use other.num_rows
(Default: None)
Returns def __repr__(self):
------- return repr(dict(self))
Frame
The new frame.
"""
num_rows = other.num_rows if num_rows is None else num_rows
newf = Frame(num_rows=num_rows)
# set global initializr
if other.get_initializer() is None:
other._set_zero_default_initializer()
sync_frame_initializer(newf, other)
return newf
def sync_frame_initializer(new_frame, reference_frame):
"""Set the initializers of the new_frame to be the same as the reference_frame,
for both the default initializer and per-column initializers.
Parameters
----------
new_frame : Frame
The frame to set initializers
reference_frame : Frame
The frame to copy initializers
"""
new_frame._default_initializer = reference_frame._default_initializer
# set per-col initializer
# TODO(minjie): hack; cannot rely on keys as the _initializers
# now supports non-exist columns.
new_frame._initializers = reference_frame._initializers
...@@ -15,9 +15,9 @@ class TargetCode(object): ...@@ -15,9 +15,9 @@ class TargetCode(object):
EDGE = 2 EDGE = 2
CODE2STR = { CODE2STR = {
0: "src", 0: "u",
1: "dst", 1: "v",
2: "edge", 2: "e",
} }
......
...@@ -9,7 +9,8 @@ from .._deprecate.runtime import ir ...@@ -9,7 +9,8 @@ from .._deprecate.runtime import ir
from .._deprecate.runtime.ir import var from .._deprecate.runtime.ir import var
__all__ = ["src_mul_edge", "copy_src", "copy_edge", "copy_u", "copy_e"] __all__ = ["src_mul_edge", "copy_src", "copy_edge", "copy_u", "copy_e",
"BinaryMessageFunction", "CopyMessageFunction"]
class MessageFunction(BuiltinFunction): class MessageFunction(BuiltinFunction):
......
...@@ -87,7 +87,7 @@ __all__ = [] ...@@ -87,7 +87,7 @@ __all__ = []
def _register_builtin_reduce_func(): def _register_builtin_reduce_func():
"""Register builtin reduce functions""" """Register builtin reduce functions"""
for reduce_op in ["max", "min", "sum", "mean", "prod"]: for reduce_op in ["max", "min", "sum", "mean"]:
builtin = _gen_reduce_builtin(reduce_op) builtin = _gen_reduce_builtin(reduce_op)
setattr(sys.modules[__name__], reduce_op, builtin) setattr(sys.modules[__name__], reduce_op, builtin)
__all__.append(reduce_op) __all__.append(reduce_op)
......
...@@ -8,15 +8,15 @@ import numbers ...@@ -8,15 +8,15 @@ import numbers
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from ._ffi.function import _init_api
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
from . import core
from . import graph_index from . import graph_index
from . import heterograph_index from . import heterograph_index
from . import utils from . import utils
from . import backend as F from . import backend as F
from ._deprecate.runtime import ir, scheduler, Runtime, GraphAdapter from .frame import Frame
from .frame import Frame, FrameRef, frame_like
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
from ._ffi.function import _init_api
__all__ = ['DGLHeteroGraph', 'combine_names'] __all__ = ['DGLHeteroGraph', 'combine_names']
...@@ -181,11 +181,11 @@ class DGLHeteroGraph(object): ...@@ -181,11 +181,11 @@ class DGLHeteroGraph(object):
and its SRC node types and DST node types are given as in the pair. and its SRC node types and DST node types are given as in the pair.
etypes : list of str etypes : list of str
Edge type list. ``etypes[i]`` stores the name of edge type i. Edge type list. ``etypes[i]`` stores the name of edge type i.
node_frames : list of FrameRef, optional node_frames : list[Frame], optional
Node feature storage. If None, empty frame is created. Node feature storage. If None, empty frame is created.
Otherwise, ``node_frames[i]`` stores the node features Otherwise, ``node_frames[i]`` stores the node features
of node type i. (default: None) of node type i. (default: None)
edge_frames : list of FrameRef, optional edge_frames : list[Frame], optional
Edge feature storage. If None, empty frame is created. Edge feature storage. If None, empty frame is created.
Otherwise, ``edge_frames[i]`` stores the edge features Otherwise, ``edge_frames[i]`` stores the edge features
of edge type i. (default: None) of edge type i. (default: None)
...@@ -271,14 +271,14 @@ class DGLHeteroGraph(object): ...@@ -271,14 +271,14 @@ class DGLHeteroGraph(object):
# node and edge frame # node and edge frame
if node_frames is None: if node_frames is None:
node_frames = [None] * len(self._ntypes) node_frames = [None] * len(self._ntypes)
node_frames = [FrameRef(Frame(num_rows=self._graph.number_of_nodes(i))) node_frames = [Frame(num_rows=self._graph.number_of_nodes(i))
if frame is None else frame if frame is None else frame
for i, frame in enumerate(node_frames)] for i, frame in enumerate(node_frames)]
self._node_frames = node_frames self._node_frames = node_frames
if edge_frames is None: if edge_frames is None:
edge_frames = [None] * len(self._etypes) edge_frames = [None] * len(self._etypes)
edge_frames = [FrameRef(Frame(num_rows=self._graph.number_of_edges(i))) edge_frames = [Frame(num_rows=self._graph.number_of_edges(i))
if frame is None else frame if frame is None else frame
for i, frame in enumerate(edge_frames)] for i, frame in enumerate(edge_frames)]
self._edge_frames = edge_frames self._edge_frames = edge_frames
...@@ -2419,408 +2419,6 @@ class DGLHeteroGraph(object): ...@@ -2419,408 +2419,6 @@ class DGLHeteroGraph(object):
else: else:
return deg return deg
def _create_hetero_subgraph(self, sgi, induced_nodes, induced_edges):
"""Internal function to create a subgraph."""
node_frames = []
for i, ind_nodes in enumerate(induced_nodes):
subframe = self._node_frames[i][utils.toindex(ind_nodes, self._idtype_str)]
node_frames.append(FrameRef(Frame(subframe, num_rows=len(ind_nodes))))
edge_frames = []
for i, ind_edges in enumerate(induced_edges):
subframe = self._edge_frames[i][utils.toindex(ind_edges, self._idtype_str)]
edge_frames.append(FrameRef(Frame(subframe, num_rows=len(ind_edges))))
hsg = DGLHeteroGraph(sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames)
for ntype, induced_nid in zip(self.ntypes, induced_nodes):
ndata = hsg.nodes[ntype].data
orig_ndata = self.nodes[ntype].data
ndata[NID] = induced_nid
for key in orig_ndata:
ndata[key] = F.gather_row(orig_ndata[key], induced_nid)
for etype, induced_eid in zip(self.canonical_etypes, induced_edges):
edata = hsg.edges[etype].data
orig_edata = self.edges[etype].data
edata[EID] = induced_eid
for key in orig_edata:
edata[key] = F.gather_row(orig_edata[key], induced_eid)
return hsg
def subgraph(self, nodes):
"""Return the subgraph induced on given nodes.
The metagraph of the returned subgraph is the same as the parent graph.
Features are copied from the original graph.
Parameters
----------
nodes : list or dict[str->list or iterable]
A dictionary mapping node types to node ID array for constructing
subgraph. All nodes must exist in the graph.
If the graph only has one node type, one can just specify a list,
tensor, or any iterable of node IDs intead.
The node ID array can be either an interger tensor or a bool tensor.
When a bool tensor is used, it is automatically converted to
an interger tensor using the semantic of np.where(nodes_idx == True).
Note: When using bool tensor, only backend (torch, tensorflow, mxnet)
tensors are supported.
Returns
-------
G : DGLHeteroGraph
The subgraph.
The nodes and edges in the subgraph are relabeled using consecutive
integers from 0.
One can retrieve the mapping from subgraph node/edge ID to parent
node/edge ID via ``dgl.NID`` and ``dgl.EID`` node/edge features of the
subgraph.
Examples
--------
The following example uses PyTorch backend.
Instantiate a heterograph.
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set node features
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
Get subgraphs.
>>> g.subgraph({'user': [4, 5]})
An error occurs as these nodes do not exist.
>>> sub_g = g.subgraph({'user': [1, 2]})
>>> print(sub_g)
Graph(num_nodes={'user': 2, 'game': 0},
num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')])
Get subgraphs using boolean mask tensor.
>>> sub_g = g.subgraph({'user': th.tensor([False, True, True])})
>>> print(sub_g)
Graph(num_nodes={'user': 2, 'game': 0},
num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')])
Get the original node/edge indices.
>>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
tensor([1, 2])
>>> sub_g['follows'].edata[dgl.EID] # Get the edge indices in the raw graph
tensor([1, 2])
Get the copied node features.
>>> sub_g.nodes['user'].data['h']
tensor([[1.],
[2.]])
>>> sub_g.nodes['user'].data['h'] += 1
>>> g.nodes['user'].data['h'] # Features are not shared.
tensor([[0.],
[1.],
[2.]])
See Also
--------
edge_subgraph
"""
if self.is_block:
raise DGLError('Extracting subgraph from a block graph is not allowed.')
if not isinstance(nodes, Mapping):
assert len(self.ntypes) == 1, \
'need a dict of node type and IDs for graph with multiple node types'
nodes = {self.ntypes[0]: nodes}
def _process_nodes(ntype, v):
if F.is_tensor(v) and F.dtype(v) == F.bool:
return F.astype(F.nonzero_1d(F.copy_to(v, self.device)), self.idtype)
else:
return utils.prepare_tensor(self, v, 'nodes["{}"]'.format(ntype))
induced_nodes = [_process_nodes(ntype, nodes.get(ntype, [])) for ntype in self.ntypes]
sgi = self._graph.node_subgraph(induced_nodes)
induced_edges = sgi.induced_edges
return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)
def edge_subgraph(self, edges, preserve_nodes=False):
"""Return the subgraph induced on given edges.
The metagraph of the returned subgraph is the same as the parent graph.
Features are copied from the original graph.
Parameters
----------
edges : dict[str->list or iterable]
A dictionary mapping edge types to edge ID array for constructing
subgraph. All edges must exist in the subgraph.
The edge types are characterized by triplets of
``(src type, etype, dst type)``.
If the graph only has one edge type, one can just specify a list,
tensor, or any iterable of edge IDs intead.
The edge ID array can be either an interger tensor or a bool tensor.
When a bool tensor is used, it is automatically converted to
an interger tensor using the semantic of np.where(edges_idx == True).
Note: When using bool tensor, only backend (torch, tensorflow, mxnet)
tensors are supported.
preserve_nodes : bool
Whether to preserve all nodes or not. If false, all nodes
without edges will be removed. (Default: False)
Returns
-------
G : DGLHeteroGraph
The subgraph.
The nodes and edges are relabeled using consecutive integers from 0.
One can retrieve the mapping from subgraph node/edge ID to parent
node/edge ID via ``dgl.NID`` and ``dgl.EID`` node/edge features of the
subgraph.
Examples
--------
The following example uses PyTorch backend.
Instantiate a heterograph.
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set edge features
>>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])
Get subgraphs.
>>> g.edge_subgraph({('user', 'follows', 'user'): [5, 6]})
An error occurs as these edges do not exist.
>>> sub_g = g.edge_subgraph({('user', 'follows', 'user'): [1, 2],
>>> ('user', 'plays', 'game'): [2]})
>>> print(sub_g)
Graph(num_nodes={'user': 2, 'game': 1},
num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')])
Get subgraphs using boolean mask tensor.
>>> sub_g = g.edge_subgraph({('user', 'follows', 'user'): th.tensor([False, True, True]),
>>> ('user', 'plays', 'game'): th.tensor([False, False, True, False])})
>>> sub_g
Graph(num_nodes={'user': 2, 'game': 1},
num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')])
Get the original node/edge indices.
>>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
tensor([1, 2])
>>> sub_g['plays'].edata[dgl.EID] # Get the edge indices in the raw graph
tensor([2])
Get the copied node features.
>>> sub_g.edges['follows'].data['h']
tensor([[1.],
[2.]])
>>> sub_g.edges['follows'].data['h'] += 1
>>> g.edges['follows'].data['h'] # Features are not shared.
tensor([[0.],
[1.],
[2.]])
See Also
--------
subgraph
"""
if self.is_block:
raise DGLError('Extracting subgraph from a block graph is not allowed.')
if not isinstance(edges, Mapping):
assert len(self.canonical_etypes) == 1, \
'need a dict of edge type and IDs for graph with multiple edge types'
edges = {self.canonical_etypes[0]: edges}
def _process_edges(etype, e):
if F.is_tensor(e) and F.dtype(e) == F.bool:
return F.astype(F.nonzero_1d(F.copy_to(e, self.device)), self.idtype)
else:
return utils.prepare_tensor(self, e, 'edges["{}"]'.format(etype))
edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()}
induced_edges = [
_process_edges(cetype, edges.get(cetype, []))
for cetype in self.canonical_etypes]
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes)
induced_nodes = sgi.induced_nodes
return self._create_hetero_subgraph(sgi, induced_nodes, induced_edges)
def node_type_subgraph(self, ntypes):
"""Return the subgraph induced on given node types.
The metagraph of the returned subgraph is the subgraph of the original
metagraph induced from the node types.
Features are shared with the original graph.
Parameters
----------
ntypes : list[str]
The node types
Returns
-------
G : DGLHeteroGraph
The subgraph.
Examples
--------
The following example uses PyTorch backend.
Instantiate a heterograph.
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set node features
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
Get subgraphs.
>>> sub_g = g.node_type_subgraph(['user'])
>>> print(sub_g)
Graph(num_nodes=3, num_edges=3,
ndata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)}
edata_schemes={})
Get the shared node features.
>>> sub_g.nodes['user'].data['h']
tensor([[0.],
[1.],
[2.]])
>>> sub_g.nodes['user'].data['h'] += 1
>>> g.nodes['user'].data['h'] # Features are shared.
tensor([[1.],
[2.],
[3.]])
See Also
--------
edge_type_subgraph
"""
rel_graphs = []
meta_edges = []
induced_etypes = []
node_frames = [self._node_frames[self.get_ntype_id(ntype)] for ntype in ntypes]
edge_frames = []
num_nodes_per_type = [self.number_of_nodes(ntype) for ntype in ntypes]
ntypes_invmap = {ntype: i for i, ntype in enumerate(ntypes)}
srctype_id, dsttype_id, _ = self._graph.metagraph.edges('eid')
for i in range(len(self._etypes)):
srctype = self._ntypes[srctype_id[i]]
dsttype = self._ntypes[dsttype_id[i]]
if srctype in ntypes and dsttype in ntypes:
meta_edges.append((ntypes_invmap[srctype], ntypes_invmap[dsttype]))
rel_graphs.append(self._graph.get_relation_graph(i))
induced_etypes.append(self.etypes[i])
edge_frames.append(self._edge_frames[i])
metagraph = graph_index.from_edge_list(meta_edges, True)
# num_nodes_per_type doesn't need to be int32
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_type, "int64"))
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes,
node_frames, edge_frames)
return hg
def edge_type_subgraph(self, etypes):
    """Return the subgraph induced on given edge types.

    The metagraph of the returned subgraph is the subgraph of the original
    metagraph induced from the edge types.

    Features are shared with the original graph: node and edge frames are
    reused by reference, so in-place feature updates on the subgraph are
    visible in the parent graph (and vice versa).

    Parameters
    ----------
    etypes : list[str or tuple]
        The edge types. Each element is either an edge type name or a
        canonical ``(srctype, etype, dsttype)`` triplet accepted by
        ``get_etype_id``.

    Returns
    -------
    G : DGLHeteroGraph
        The subgraph.

    Examples
    --------
    The following example uses PyTorch backend.

    Instantiate a heterograph.

    >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
    >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
    >>> g = dgl.hetero_from_relations([plays_g, follows_g])
    >>> # Set edge features
    >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])

    Get subgraphs.

    >>> sub_g = g.edge_type_subgraph(['follows'])
    >>> print(sub_g)
    Graph(num_nodes=3, num_edges=3,
          ndata_schemes={}
          edata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)})

    Get the shared edge features.

    >>> sub_g.edges['follows'].data['h']
    tensor([[0.],
            [1.],
            [2.]])
    >>> sub_g.edges['follows'].data['h'] += 1
    >>> g.edges['follows'].data['h']  # Features are shared.
    tensor([[1.],
            [2.],
            [3.]])

    See Also
    --------
    node_type_subgraph
    """
    # Resolve each requested edge type to its integer type id.
    etype_ids = [self.get_etype_id(etype) for etype in etypes]
    # meta graph is homograph, still using int64
    # Endpoints (as node-type ids) of each selected metagraph edge.
    meta_src, meta_dst, _ = self._graph.metagraph.find_edges(utils.toindex(etype_ids, "int64"))
    # One relation graph per selected edge type, in the order of `etypes`.
    rel_graphs = [self._graph.get_relation_graph(i) for i in etype_ids]
    meta_src = meta_src.tonumpy()
    meta_dst = meta_dst.tonumpy()
    # Node types touched by the selected edge types, renumbered 0..K-1.
    # NOTE(review): the induced node-type order comes from iterating a set
    # union here; the same dict-iteration order is reused below for
    # node_frames and induced_ntypes, so all three stay consistent with the
    # remapped metagraph edges.
    ntypes_invmap = {n: i for i, n in enumerate(set(meta_src) | set(meta_dst))}
    # Metagraph edges expressed in the new (subgraph-local) node-type ids.
    mapped_meta_src = [ntypes_invmap[v] for v in meta_src]
    mapped_meta_dst = [ntypes_invmap[v] for v in meta_dst]
    # Share frames with the parent graph (no copy) — this is what makes
    # features shared between subgraph and parent.
    node_frames = [self._node_frames[i] for i in ntypes_invmap]
    edge_frames = [self._edge_frames[i] for i in etype_ids]
    induced_ntypes = [self._ntypes[i] for i in ntypes_invmap]
    induced_etypes = [self._etypes[i] for i in etype_ids]  # get the "name" of edge type
    num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes]
    # Build the induced metagraph (directed/multigraph flag True).
    metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True)
    # num_nodes_per_type should be int64
    hgidx = heterograph_index.create_heterograph_from_relations(
        metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type, "int64"))
    hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
    return hg
def adjacency_matrix(self, transpose=None, ctx=F.cpu(), scipy_fmt=None, etype=None): def adjacency_matrix(self, transpose=None, ctx=F.cpu(), scipy_fmt=None, etype=None):
"""Return the adjacency matrix of edges of the given edge type. """Return the adjacency matrix of edges of the given edge type.
...@@ -3101,7 +2699,7 @@ class DGLHeteroGraph(object): ...@@ -3101,7 +2699,7 @@ class DGLHeteroGraph(object):
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
self._edge_frames[etid].set_initializer(initializer, field) self._edge_frames[etid].set_initializer(initializer, field)
def _set_n_repr(self, ntid, u, data, inplace=False): def _set_n_repr(self, ntid, u, data):
"""Internal API to set node features. """Internal API to set node features.
`data` is a dictionary from the feature name to feature tensor. Each tensor `data` is a dictionary from the feature name to feature tensor. Each tensor
...@@ -3109,8 +2707,7 @@ class DGLHeteroGraph(object): ...@@ -3109,8 +2707,7 @@ class DGLHeteroGraph(object):
and (D1, D2, ...) be the shape of the node representation tensor. The and (D1, D2, ...) be the shape of the node representation tensor. The
length of the given node ids must match B (i.e, len(u) == B). length of the given node ids must match B (i.e, len(u) == B).
All update will be done out of place to work with autograd unless the All updates will be done out of place to work with autograd.
inplace flag is true.
Parameters Parameters
---------- ----------
...@@ -3120,9 +2717,6 @@ class DGLHeteroGraph(object): ...@@ -3120,9 +2717,6 @@ class DGLHeteroGraph(object):
The node(s). The node(s).
data : dict of tensor data : dict of tensor
Node representation. Node representation.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
(Default: False)
""" """
if is_all(u): if is_all(u):
num_nodes = self._graph.number_of_nodes(ntid) num_nodes = self._graph.number_of_nodes(ntid)
...@@ -3140,11 +2734,9 @@ class DGLHeteroGraph(object): ...@@ -3140,11 +2734,9 @@ class DGLHeteroGraph(object):
' same device.'.format(key, F.context(val), self.device)) ' same device.'.format(key, F.context(val), self.device))
if is_all(u): if is_all(u):
for key, val in data.items(): self._node_frames[ntid].update(data)
self._node_frames[ntid][key] = val
else: else:
u = utils.toindex(u, self._idtype_str) self._node_frames[ntid].update_row(u, data)
self._node_frames[ntid].update_rows(u, data, inplace=inplace)
def _get_n_repr(self, ntid, u): def _get_n_repr(self, ntid, u):
"""Get node(s) representation of a single node type. """Get node(s) representation of a single node type.
...@@ -3167,8 +2759,7 @@ class DGLHeteroGraph(object): ...@@ -3167,8 +2759,7 @@ class DGLHeteroGraph(object):
return dict(self._node_frames[ntid]) return dict(self._node_frames[ntid])
else: else:
u = utils.prepare_tensor(self, u, 'u') u = utils.prepare_tensor(self, u, 'u')
u = utils.toindex(u, self._idtype_str) return self._node_frames[ntid].subframe(u)
return self._node_frames[ntid].select_rows(u)
def _pop_n_repr(self, ntid, key): def _pop_n_repr(self, ntid, key):
"""Internal API to get and remove the specified node feature. """Internal API to get and remove the specified node feature.
...@@ -3187,15 +2778,14 @@ class DGLHeteroGraph(object): ...@@ -3187,15 +2778,14 @@ class DGLHeteroGraph(object):
""" """
return self._node_frames[ntid].pop(key) return self._node_frames[ntid].pop(key)
def _set_e_repr(self, etid, edges, data, inplace=False): def _set_e_repr(self, etid, edges, data):
"""Internal API to set edge(s) features. """Internal API to set edge(s) features.
`data` is a dictionary from the feature name to feature tensor. Each tensor `data` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated, is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor. and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out of place to work with autograd unless the All update will be done out of place to work with autograd.
inplace flag is true.
Parameters Parameters
---------- ----------
...@@ -3211,27 +2801,17 @@ class DGLHeteroGraph(object): ...@@ -3211,27 +2801,17 @@ class DGLHeteroGraph(object):
The default value is all the edges. The default value is all the edges.
data : tensor or dict of tensor data : tensor or dict of tensor
Edge representation. Edge representation.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
(Default: False)
""" """
# parse argument # parse argument
if is_all(edges): if not is_all(edges):
eid = ALL eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
elif isinstance(edges, tuple):
u, v = edges
u = utils.prepare_tensor(self, u, 'edges[0]')
v = utils.prepare_tensor(self, v, 'edges[1]')
eid = self.edge_ids(u, v, etype=self.canonical_etypes[etid])
else:
eid = utils.prepare_tensor(self, edges, 'edges')
# sanity check # sanity check
if not utils.is_dict_like(data): if not utils.is_dict_like(data):
raise DGLError('Expect dictionary type for feature data.' raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(data)) ' Got "%s" instead.' % type(data))
if is_all(eid): if is_all(edges):
num_edges = self._graph.number_of_edges(etid) num_edges = self._graph.number_of_edges(etid)
else: else:
num_edges = len(eid) num_edges = len(eid)
...@@ -3246,14 +2826,10 @@ class DGLHeteroGraph(object): ...@@ -3246,14 +2826,10 @@ class DGLHeteroGraph(object):
' same device.'.format(key, F.context(val), self.device)) ' same device.'.format(key, F.context(val), self.device))
# set # set
if is_all(eid): if is_all(edges):
# update column self._edge_frames[etid].update(data)
for key, val in data.items():
self._edge_frames[etid][key] = val
else: else:
# update row self._edge_frames[etid].update_row(eid, data)
eid = utils.toindex(eid, self._idtype_str)
self._edge_frames[etid].update_rows(eid, data, inplace=inplace)
def _get_e_repr(self, etid, edges): def _get_e_repr(self, etid, edges):
"""Internal API to get edge features. """Internal API to get edge features.
...@@ -3273,20 +2849,10 @@ class DGLHeteroGraph(object): ...@@ -3273,20 +2849,10 @@ class DGLHeteroGraph(object):
""" """
# parse argument # parse argument
if is_all(edges): if is_all(edges):
eid = ALL
elif isinstance(edges, tuple):
u, v = edges
u = utils.prepare_tensor(self, u, 'edges[0]')
v = utils.prepare_tensor(self, v, 'edges[1]')
eid = self.edge_ids(u, v, etype=self.canonical_etypes[etid])
else:
eid = utils.prepare_tensor(self, edges, 'edges')
if is_all(eid):
return dict(self._edge_frames[etid]) return dict(self._edge_frames[etid])
else: else:
eid = utils.toindex(eid, self._idtype_str) eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
return self._edge_frames[etid].select_rows(eid) return self._edge_frames[etid].subframe(eid)
def _pop_e_repr(self, etid, key): def _pop_e_repr(self, etid, key):
"""Get and remove the specified edge repr of a single edge type. """Get and remove the specified edge repr of a single edge type.
...@@ -3305,65 +2871,10 @@ class DGLHeteroGraph(object): ...@@ -3305,65 +2871,10 @@ class DGLHeteroGraph(object):
""" """
self._edge_frames[etid].pop(key) self._edge_frames[etid].pop(key)
#################################################################
# DEPRECATED: from the old DGLGraph
#################################################################
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
"""DEPRECATED: please use
``dgl.from_networkx(nx_graph, node_attrs, edge_attrs)``
which will return a new graph created from the networkx graph.
"""
raise DGLError('DGLGraph.from_networkx is deprecated. Please call the following\n\n'
'\t dgl.from_networkx(nx_graph, node_attrs, edge_attrs)\n\n'
', which creates a new DGLGraph from the networkx graph.')
def from_scipy_sparse_matrix(self, spmat, multigraph=None):
"""DEPRECATED: please use
``dgl.from_scipy(spmat)``
which will return a new graph created from the scipy matrix.
"""
raise DGLError('DGLGraph.from_scipy_sparse_matrix is deprecated. '
'Please call the following\n\n'
'\t dgl.from_scipy(spmat)\n\n'
', which creates a new DGLGraph from the scipy matrix.')
################################################################# #################################################################
# Message passing # Message passing
################################################################# #################################################################
def register_apply_node_func(self, func):
"""Deprecated: please directly call :func:`apply_nodes` with ``func``
as argument.
"""
raise DGLError('DGLGraph.register_apply_node_func is deprecated.'
' Please directly call apply_nodes with func as the argument.')
def register_apply_edge_func(self, func):
"""Deprecated: please directly call :func:`apply_edges` with ``func``
as argument.
"""
raise DGLError('DGLGraph.register_apply_edge_func is deprecated.'
' Please directly call apply_edges with func as the argument.')
def register_message_func(self, func):
"""Deprecated: please directly call :func:`update_all` with ``func``
as argument.
"""
raise DGLError('DGLGraph.register_message_func is deprecated.'
' Please directly call update_all with func as the argument.')
def register_reduce_func(self, func):
"""Deprecated: please directly call :func:`update_all` with ``func``
as argument.
"""
raise DGLError('DGLGraph.register_reduce_func is deprecated.'
' Please directly call update_all with func as the argument.')
def apply_nodes(self, func, v=ALL, ntype=None, inplace=False): def apply_nodes(self, func, v=ALL, ntype=None, inplace=False):
"""Apply the function on the nodes with the same type to update their """Apply the function on the nodes with the same type to update their
features. features.
...@@ -3381,7 +2892,7 @@ class DGLHeteroGraph(object): ...@@ -3381,7 +2892,7 @@ class DGLHeteroGraph(object):
The node type. Can be omitted if there is only one node type The node type. Can be omitted if there is only one node type
in the graph. (Default: None) in the graph. (Default: None)
inplace : bool, optional inplace : bool, optional
If True, update will be done in place, but autograd will break. **DEPRECATED**. If True, update will be done in place, but autograd will break.
(Default: False) (Default: False)
Examples Examples
...@@ -3398,16 +2909,16 @@ class DGLHeteroGraph(object): ...@@ -3398,16 +2909,16 @@ class DGLHeteroGraph(object):
-------- --------
apply_edges apply_edges
""" """
if inplace:
raise DGLError('The `inplace` option is removed in v0.5.')
ntid = self.get_ntype_id(ntype) ntid = self.get_ntype_id(ntype)
ntype = self.ntypes[ntid]
if is_all(v): if is_all(v):
v = F.arange(0, self.number_of_nodes(ntype), self.idtype) v = self.nodes(ntype)
else: else:
v = utils.prepare_tensor(self, v, 'v') v = utils.prepare_tensor(self, v, 'v')
with ir.prog() as prog: ndata = core.invoke_node_udf(self, v, ntype, func, orig_nid=v)
v_ntype = utils.toindex(v, self._idtype_str) self._set_n_repr(ntid, v, ndata)
scheduler.schedule_apply_nodes(v_ntype, func, self._node_frames[ntid],
inplace=inplace, ntype=self._ntypes[ntid])
Runtime.run(prog)
def apply_edges(self, func, edges=ALL, etype=None, inplace=False): def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
"""Apply the function on the edges with the same type to update their """Apply the function on the edges with the same type to update their
...@@ -3417,7 +2928,7 @@ class DGLHeteroGraph(object): ...@@ -3417,7 +2928,7 @@ class DGLHeteroGraph(object):
Parameters Parameters
---------- ----------
func : callable or None func : callable
Apply function on the edge. The function should be Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`. an :mod:`Edge UDF <dgl.udf>`.
edges : optional edges : optional
...@@ -3427,8 +2938,7 @@ class DGLHeteroGraph(object): ...@@ -3427,8 +2938,7 @@ class DGLHeteroGraph(object):
The edge type. Can be omitted if there is only one edge type The edge type. Can be omitted if there is only one edge type
in the graph. (Default: None) in the graph. (Default: None)
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. **DEPRECATED**. Must be False.
(Default: False)
Examples Examples
-------- --------
...@@ -3444,149 +2954,39 @@ class DGLHeteroGraph(object): ...@@ -3444,149 +2954,39 @@ class DGLHeteroGraph(object):
See Also See Also
-------- --------
apply_nodes apply_nodes
group_apply_edges
""" """
if inplace:
raise DGLError('The `inplace` option is removed in v0.5.')
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) etype = self.canonical_etypes[etid]
g = self if etype is None else self[etype]
if is_all(edges): if is_all(edges):
u, v, _ = self.edges(etype=etype, form='all') eid = ALL
# TODO(minjie): temporary hack
eid = utils.toindex(slice(0, self.number_of_edges(etype)), self._idtype_str)
elif isinstance(edges, tuple):
u, v = edges
u = utils.prepare_tensor(self, u, 'edges[0]')
v = utils.prepare_tensor(self, v, 'edges[1]')
eid = self.edge_ids(u, v, etype=etype)
else: else:
eid = utils.prepare_tensor(self, edges, 'edges') eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
u, v = self.find_edges(eid, etype=etype) if core.is_builtin(func):
if not is_all(eid):
with ir.prog() as prog: g = g.edge_subgraph(eid, preserve_nodes=True)
u = utils.toindex(u, self._idtype_str) edata = core.invoke_gsddmm(g, func)
v = utils.toindex(v, self._idtype_str) else:
eid = utils.toindex(eid, self._idtype_str) edata = core.invoke_edge_udf(g, eid, etype, func)
scheduler.schedule_apply_edges( self._set_e_repr(etid, eid, edata)
AdaptedHeteroGraph(self, stid, dtid, etid),
u, v, eid, func, inplace=inplace)
Runtime.run(prog)
def group_apply_edges(self, group_by, func, edges=ALL, etype=None, inplace=False): def send_and_recv(self,
"""Group the edges by nodes and apply the function of the grouped edges,
edges to update their features. The edges are of the same edge type message_func,
(hence having the same source and destination node type). reduce_func,
apply_node_func=None,
etype=None,
inplace=False):
"""Send messages along edges of the specified type, and let destinations
receive them.
Parameters Optionally, apply a function to update the node features after "receive".
----------
group_by : str
Specify how to group edges. Expected to be either ``'src'`` or ``'dst'``
func : callable
Apply function on the edge. The function should be an
:mod:`Edge UDF <dgl.udf>`. The input of `Edge UDF` should be
(bucket_size, degrees, *feature_shape), and return the dict
with values of the same shapes.
edges : optional
Edges on which to group and apply ``func``. See :func:`send` for valid
edge specification. Default is all the edges.
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph. (Default: None)
inplace: bool, optional
If True, update will be done in place, but autograd will break.
(Default: False)
Examples This is a convenient combination for performing
-------- :mod:`send <dgl.DGLHeteroGraph.send>` along the ``edges`` and
>>> g = dgl.graph(([0, 0, 1], [1, 2, 2]), 'user', 'follows') :mod:`recv <dgl.DGLHeteroGraph.recv>` for the destinations of the ``edges``.
>>> g.edata['feat'] = torch.randn((g.number_of_edges(), 1))
>>> def softmax_feat(edges):
>>> return {'norm_feat': th.softmax(edges.data['feat'], dim=1)}
>>> g.group_apply_edges(group_by='src', func=softmax_feat)
>>> g.edata['norm_feat']
tensor([[0.3796],
[0.6204],
[1.0000]])
See Also
--------
apply_edges
"""
if group_by not in ('src', 'dst'):
raise DGLError("Group_by should be either src or dst")
etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges):
u, v, eid = self.edges(etype=etype, form='all')
elif isinstance(edges, tuple):
u, v = edges
u = utils.prepare_tensor(self, u, 'edges[0]')
v = utils.prepare_tensor(self, v, 'edges[1]')
eid = self.edge_ids(u, v, etype=etype)
else:
eid = utils.prepare_tensor(self, edges, 'edges')
u, v = self.find_edges(eid, etype=etype)
with ir.prog() as prog:
u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v, self._idtype_str)
eid = utils.toindex(eid, self._idtype_str)
scheduler.schedule_group_apply_edge(
AdaptedHeteroGraph(self, stid, dtid, etid),
u, v, eid,
func, group_by,
inplace=inplace)
Runtime.run(prog)
def send(self, edges, message_func, etype=None):
"""Send messages along the given edges with the same edge type.
DEPRECATE: please use send_and_recv, update_all.
"""
raise DGLError('DGLGraph.send is deprecated. As a replacement, use DGLGraph.apply_edges\n'
' API to compute messages as edge data. Then use DGLGraph.send_and_recv\n'
' and set the message function as dgl.function.copy_e to conduct message\n'
' aggregation.')
def recv(self,
v,
reduce_func,
apply_node_func=None,
etype=None,
inplace=False):
r"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.
DEPRECATE: please use send_and_recv, update_all.
"""
raise DGLError('DGLGraph.recv is deprecated. As a replacement, use DGLGraph.apply_edges\n'
' API to compute messages as edge data. Then use DGLGraph.send_and_recv\n'
' and set the message function as dgl.function.copy_e to conduct message\n'
' aggregation.')
def multi_recv(self, v, reducer_dict, cross_reducer, apply_node_func=None, inplace=False):
r"""Receive messages from multiple edge types and perform aggregation.
DEPRECATE: please use multi_send_and_recv, multi_update_all.
"""
raise DGLError('DGLGraph.multi_recv is deprecated. As a replacement,\n'
' use DGLGraph.apply_edges API to compute messages as edge data.\n'
' Then use DGLGraph.multi_send_and_recv and set the message function\n'
' as dgl.function.copy_e to conduct message aggregation.')
def send_and_recv(self,
edges,
message_func,
reduce_func,
apply_node_func=None,
etype=None,
inplace=False):
"""Send messages along edges of the specified type, and let destinations
receive them.
Optionally, apply a function to update the node features after "receive".
This is a convenient combination for performing
:mod:`send <dgl.DGLHeteroGraph.send>` along the ``edges`` and
:mod:`recv <dgl.DGLHeteroGraph.recv>` for the destinations of the ``edges``.
**Only works if the graph has one edge type.** For multiple types, use **Only works if the graph has one edge type.** For multiple types, use
...@@ -3612,8 +3012,7 @@ class DGLHeteroGraph(object): ...@@ -3612,8 +3012,7 @@ class DGLHeteroGraph(object):
The edge type. Can be omitted if there is only one edge type The edge type. Can be omitted if there is only one edge type
in the graph. (Default: None) in the graph. (Default: None)
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. **DEPRECATED**. Must be False.
(Default: False)
Examples Examples
-------- --------
...@@ -3633,152 +3032,23 @@ class DGLHeteroGraph(object): ...@@ -3633,152 +3032,23 @@ class DGLHeteroGraph(object):
[0.], [0.],
[1.]]) [1.]])
""" """
if inplace:
raise DGLError('The `inplace` option is removed in v0.5.')
# edge type
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) _, dtid = self._graph.metagraph.find_edge(etid)
etype = self.canonical_etypes[etid]
if isinstance(edges, tuple): # edge IDs
u, v = edges eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
u = utils.prepare_tensor(self, u, 'edges[0]') if len(eid) == 0:
v = utils.prepare_tensor(self, v, 'edges[1]') # no computation
eid = self.edge_ids(u, v, etype=etype)
else:
eid = utils.prepare_tensor(self, edges, 'edges')
u, v = self.find_edges(eid, etype=etype)
if len(u) == 0:
# no edges to be triggered
return return
u, v = self.find_edges(eid, etype=etype)
with ir.prog() as prog: # call message passing onsubgraph
u = utils.toindex(u, self._idtype_str) ndata = core.message_passing(_create_compute_graph(self, u, v, eid),
v = utils.toindex(v, self._idtype_str) message_func, reduce_func, apply_node_func)
eid = utils.toindex(eid, self._idtype_str) dstnodes = F.unique(v)
scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid), self._set_n_repr(dtid, dstnodes, ndata)
(u, v, eid),
message_func, reduce_func, apply_node_func,
inplace=inplace)
Runtime.run(prog)
def multi_send_and_recv(self, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
r"""Send and receive messages along multiple edge types and perform aggregation.
Optionally, apply a function to update the node features after "receive".
This is a convenient combination for performing multiple
:mod:`send <dgl.DGLHeteroGraph.send>` along edges of different types and
:mod:`multi_recv <dgl.DGLHeteroGraph.multi_recv>` for the destinations of all edges.
Parameters
----------
etype_dict : dict
Mapping an edge type (str or tuple of str) to the type specific
configuration (4-tuples). Each 4-tuple represents
(edges, msg_func, reduce_func, apply_node_func):
* edges: See send() for valid edge specification.
Edges on which to pass messages.
* msg_func: callable
Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
* reduce_func: callable
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
* apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. (Default: None)
cross_reducer : str
Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
apply_node_func : callable
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. (Default: None)
inplace: bool, optional
If True, update will be done in place, but autograd will break.
(Default: False)
Examples
--------
>>> import dgl
>>> import dgl.function as fn
>>> import torch
Instantiate a heterograph.
>>> g1 = dgl.graph(([0], [1]), 'user', 'follows')
>>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2])
Trigger send and recv separately.
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.send(g['follows'].edges(), fn.copy_src('h', 'm'), etype='follows')
>>> g.send(g['attracts'].edges(), fn.copy_src('h', 'm'), etype='attracts')
>>> g.multi_recv(g.nodes('user'),
>>> {'follows': fn.sum('m', 'h'), 'attracts': fn.sum('m', 'h')}, "sum")
>>> g.nodes['user'].data['h']
tensor([[0.],
[2.]])
Trigger “send” and “receive” in one call.
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.multi_send_and_recv(
>>> {'follows': (g['follows'].edges(), fn.copy_src('h', 'm'), fn.sum('m', 'h')),
>>> 'attracts': (g['attracts'].edges(), fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
>>> "sum")
>>> g.nodes['user'].data['h']
tensor([[0.],
[2.]])
"""
# infer receive node type
ntype = infer_ntype_from_dict(self, etype_dict)
dtid = self.get_ntype_id_from_dst(ntype)
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = []
all_vs = []
merge_order = []
with ir.prog() as prog:
for etype, args in etype_dict.items():
etid = self.get_etype_id(etype)
stid, _ = self._graph.metagraph.find_edge(etid)
outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
args = pad_tuple(args, 4)
if args is None:
raise DGLError('Invalid per-type arguments. Should be '
'(edges, msg_func, reduce_func, [apply_node_func])')
edges, mfunc, rfunc, afunc = args
if isinstance(edges, tuple):
u, v = edges
u = utils.prepare_tensor(self, u, 'edges[0]')
v = utils.prepare_tensor(self, v, 'edges[1]')
eid = self.edge_ids(u, v, etype=etype)
else:
eid = utils.prepare_tensor(self, edges, 'edges')
u, v = self.find_edges(eid, etype=etype)
all_vs.append(v)
if len(u) == 0:
# no edges to be triggered
continue
u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v, self._idtype_str)
eid = utils.toindex(eid, self._idtype_str)
scheduler.schedule_snr(AdaptedHeteroGraph(self, stid, dtid, etid),
(u, v, eid),
mfunc, rfunc, afunc,
inplace=inplace, outframe=outframe)
all_out.append(outframe)
merge_order.append(etid) # use edge type id as merge order hint
Runtime.run(prog)
# merge by cross_reducer
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer, merge_order))
# apply
if apply_node_func is not None:
dstnodes = F.unique(F.cat(all_vs, 0))
self.apply_nodes(apply_node_func, dstnodes, ntype, inplace)
def pull(self, def pull(self,
v, v,
...@@ -3825,8 +3095,7 @@ class DGLHeteroGraph(object): ...@@ -3825,8 +3095,7 @@ class DGLHeteroGraph(object):
The edge type. Can be omitted if there is only one edge type The edge type. Can be omitted if there is only one edge type
in the graph. (Default: None) in the graph. (Default: None)
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. **DEPRECATED**. Must be False.
(Default: False)
Examples Examples
-------- --------
...@@ -3850,113 +3119,21 @@ class DGLHeteroGraph(object): ...@@ -3850,113 +3119,21 @@ class DGLHeteroGraph(object):
[1.], [1.],
[1.]]) [1.]])
""" """
# only one type of edges if inplace:
etid = self.get_etype_id(etype) raise DGLError('The `inplace` option is removed in v0.5.')
stid, dtid = self._graph.metagraph.find_edge(etid)
v = utils.prepare_tensor(self, v, 'v') v = utils.prepare_tensor(self, v, 'v')
if len(v) == 0: if len(v) == 0:
# no computation
return return
with ir.prog() as prog: etid = self.get_etype_id(etype)
v = utils.toindex(v, self._idtype_str) _, dtid = self._graph.metagraph.find_edge(etid)
scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid), etype = self.canonical_etypes[etid]
v, g = self if etype is None else self[etype]
message_func, reduce_func, apply_node_func, # call message passing on subgraph
inplace=inplace) src, dst, eid = g.in_edges(v, form='all')
Runtime.run(prog) ndata = core.message_passing(_create_compute_graph(self, src, dst, eid, v),
message_func, reduce_func, apply_node_func)
def multi_pull(self, v, etype_dict, cross_reducer, apply_node_func=None, inplace=False): self._set_n_repr(dtid, v, ndata)
r"""Pull and receive messages of the given nodes along multiple edge types
and perform aggregation.
This is equivalent to :mod:`multi_send_and_recv <dgl.DGLHeteroGraph.multi_send_and_recv>`
on the incoming edges of ``v`` with the specified types.
Parameters
----------
v : int, container or tensor
The node(s) to be updated.
etype_dict : dict
Mapping an edge type (str or tuple of str) to the type specific
configuration (3-tuples). Each 3-tuple represents
(msg_func, reduce_func, apply_node_func):
* msg_func: callable
Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
* reduce_func: callable
Reduce function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
* apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. (Default: None)
cross_reducer : str
Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
apply_node_func : callable
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. (Default: None)
inplace: bool, optional
If True, update will be done in place, but autograd will break.
(Default: False)
Examples
--------
>>> import dgl
>>> import dgl.function as fn
>>> import torch
Instantiate a heterograph.
>>> g1 = dgl.graph(([1, 1], [1, 0]), 'user', 'follows')
>>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2])
Pull.
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.multi_pull(1,
>>> {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')),
>>> 'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
>>> "sum")
>>> g.nodes['user'].data['h']
tensor([[0.],
[3.]])
"""
v = utils.prepare_tensor(self, v, 'v')
if len(v) == 0:
return
# infer receive node type
ntype = infer_ntype_from_dict(self, etype_dict)
dtid = self.get_ntype_id_from_dst(ntype)
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = []
merge_order = []
with ir.prog() as prog:
for etype, args in etype_dict.items():
etid = self.get_etype_id(etype)
stid, _ = self._graph.metagraph.find_edge(etid)
outframe = FrameRef(frame_like(self._node_frames[dtid]._frame))
args = pad_tuple(args, 3)
if args is None:
raise DGLError('Invalid per-type arguments. Should be '
'(msg_func, reduce_func, [apply_node_func])')
mfunc, rfunc, afunc = args
v = utils.toindex(v, self._idtype_str)
scheduler.schedule_pull(AdaptedHeteroGraph(self, stid, dtid, etid),
v,
mfunc, rfunc, afunc,
inplace=inplace, outframe=outframe)
all_out.append(outframe)
merge_order.append(etid) # use edge type id as merge order hint
Runtime.run(prog)
# merge by cross_reducer
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer, merge_order))
# apply
if apply_node_func is not None:
self.apply_nodes(apply_node_func, v, ntype, inplace)
def push(self, def push(self,
u, u,
...@@ -3994,8 +3171,7 @@ class DGLHeteroGraph(object): ...@@ -3994,8 +3171,7 @@ class DGLHeteroGraph(object):
The edge type. Can be omitted if there is only one edge type The edge type. Can be omitted if there is only one edge type
in the graph. (Default: None) in the graph. (Default: None)
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. **DEPRECATED**. Must be False.
(Default: False)
Examples Examples
-------- --------
...@@ -4017,20 +3193,10 @@ class DGLHeteroGraph(object): ...@@ -4017,20 +3193,10 @@ class DGLHeteroGraph(object):
[0.], [0.],
[0.]]) [0.]])
""" """
# only one type of edges if inplace:
etid = self.get_etype_id(etype) raise DGLError('The `inplace` option is removed in v0.5.')
stid, dtid = self._graph.metagraph.find_edge(etid) edges = self.out_edges(u, form='eid', etype=etype)
self.send_and_recv(edges, message_func, reduce_func, apply_node_func, etype=etype)
u = utils.prepare_tensor(self, u, 'u')
if len(u) == 0:
return
with ir.prog() as prog:
u = utils.toindex(u, self._idtype_str)
scheduler.schedule_push(AdaptedHeteroGraph(self, stid, dtid, etid),
u,
message_func, reduce_func, apply_node_func,
inplace=inplace)
Runtime.run(prog)
def update_all(self, def update_all(self,
message_func, message_func,
...@@ -4085,15 +3251,16 @@ class DGLHeteroGraph(object): ...@@ -4085,15 +3251,16 @@ class DGLHeteroGraph(object):
[0.], [0.],
[3.]]) [3.]])
""" """
# only one type of edges
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) etype = self.canonical_etypes[etid]
_, dtid = self._graph.metagraph.find_edge(etid)
g = self if etype is None else self[etype]
ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
self._set_n_repr(dtid, ALL, ndata)
with ir.prog() as prog: #################################################################
scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid), # Message passing on heterograph
message_func, reduce_func, #################################################################
apply_node_func)
Runtime.run(prog)
def multi_update_all(self, etype_dict, cross_reducer, apply_node_func=None): def multi_update_all(self, etype_dict, cross_reducer, apply_node_func=None):
r"""Send and receive messages along all edges. r"""Send and receive messages along all edges.
...@@ -4124,11 +3291,7 @@ class DGLHeteroGraph(object): ...@@ -4124,11 +3291,7 @@ class DGLHeteroGraph(object):
Apply function on the nodes. The function should be Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`. (Default: None) a :mod:`Node UDF <dgl.udf>`. (Default: None)
inplace: bool, optional inplace: bool, optional
If True, update will be done in place, but autograd will break. **DEPRECATED**. Must be False.
(Default: False)
etype_dict : dict of callable
``update_all`` arguments per edge type.
Examples Examples
-------- --------
...@@ -4154,33 +3317,29 @@ class DGLHeteroGraph(object): ...@@ -4154,33 +3317,29 @@ class DGLHeteroGraph(object):
tensor([[0.], tensor([[0.],
[4.]]) [4.]])
""" """
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = defaultdict(list) all_out = defaultdict(list)
merge_order = defaultdict(list) merge_order = defaultdict(list)
with ir.prog() as prog: for etype, args in etype_dict.items():
for etype, args in etype_dict.items(): etid = self.get_etype_id(etype)
etid = self.get_etype_id(etype) _, dtid = self._graph.metagraph.find_edge(etid)
stid, dtid = self._graph.metagraph.find_edge(etid) args = pad_tuple(args, 3)
outframe = FrameRef(frame_like(self._node_frames[dtid]._frame)) if args is None:
args = pad_tuple(args, 3) raise DGLError('Invalid arguments for edge type "{}". Should be '
if args is None: '(msg_func, reduce_func, [apply_node_func])'.format(etype))
raise DGLError('Invalid per-type arguments. Should be ' mfunc, rfunc, afunc = args
'(msg_func, reduce_func, [apply_node_func])') all_out[dtid].append(core.message_passing(self[etype], mfunc, rfunc, afunc))
mfunc, rfunc, afunc = args merge_order[dtid].append(etid) # use edge type id as merge order hint
scheduler.schedule_update_all(AdaptedHeteroGraph(self, stid, dtid, etid),
mfunc, rfunc, afunc,
outframe=outframe)
all_out[dtid].append(outframe)
merge_order[dtid].append(etid) # use edge type id as merge order hint
Runtime.run(prog)
for dtid, frames in all_out.items(): for dtid, frames in all_out.items():
# merge by cross_reducer # merge by cross_reducer
self._node_frames[dtid].update( self._node_frames[dtid].update(
merge_frames(frames, cross_reducer, merge_order[dtid])) reduce_dict_data(frames, cross_reducer, merge_order[dtid]))
# apply # apply
if apply_node_func is not None: if apply_node_func is not None:
self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid], inplace=False) self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid])
#################################################################
# Message propagation
#################################################################
def prop_nodes(self, def prop_nodes(self,
nodes_generator, nodes_generator,
...@@ -4393,12 +3552,6 @@ class DGLHeteroGraph(object): ...@@ -4393,12 +3552,6 @@ class DGLHeteroGraph(object):
e = utils.prepare_tensor(self, edges, 'edges') e = utils.prepare_tensor(self, edges, 'edges')
return F.boolean_mask(e, F.gather_row(mask, e)) return F.boolean_mask(e, F.gather_row(mask, e))
def readonly(self, readonly_state=True):
"""Deprecated: DGLGraph will always be mutable."""
dgl_warning('DGLGraph.is_readonly is deprecated in v0.5.\n'
'DGLGraph now always supports mutable operations like add_nodes'
' and add_edges.')
@property @property
def device(self): def device(self):
"""Get the device context of this graph. """Get the device context of this graph.
...@@ -4462,13 +3615,13 @@ class DGLHeteroGraph(object): ...@@ -4462,13 +3615,13 @@ class DGLHeteroGraph(object):
new_nframes = [] new_nframes = []
for nframe in self._node_frames: for nframe in self._node_frames:
new_feats = {k : F.copy_to(feat, device, **kwargs) for k, feat in nframe.items()} new_feats = {k : F.copy_to(feat, device, **kwargs) for k, feat in nframe.items()}
new_nframes.append(FrameRef(Frame(new_feats, num_rows=nframe.num_rows))) new_nframes.append(Frame(new_feats, num_rows=nframe.num_rows))
ret._node_frames = new_nframes ret._node_frames = new_nframes
new_eframes = [] new_eframes = []
for eframe in self._edge_frames: for eframe in self._edge_frames:
new_feats = {k : F.copy_to(feat, device, **kwargs) for k, feat in eframe.items()} new_feats = {k : F.copy_to(feat, device, **kwargs) for k, feat in eframe.items()}
new_eframes.append(FrameRef(Frame(new_feats, num_rows=eframe.num_rows))) new_eframes.append(Frame(new_feats, num_rows=eframe.num_rows))
ret._edge_frames = new_eframes ret._edge_frames = new_eframes
# 2. Copy misc info # 2. Copy misc info
...@@ -4871,6 +4024,112 @@ class DGLHeteroGraph(object): ...@@ -4871,6 +4024,112 @@ class DGLHeteroGraph(object):
""" """
return self.astype(F.int32) return self.astype(F.int32)
#################################################################
# DEPRECATED: from the old DGLGraph
#################################################################
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
    """**DEPRECATED**: use :func:`dgl.from_networkx` instead, which returns a
    new graph built from the given networkx graph.
    """
    hint = ('DGLGraph.from_networkx is deprecated. Please call the following\n\n'
            '\t dgl.from_networkx(nx_graph, node_attrs, edge_attrs)\n\n'
            ', which creates a new DGLGraph from the networkx graph.')
    raise DGLError(hint)
def from_scipy_sparse_matrix(self, spmat, multigraph=None):
    """**DEPRECATED**: use :func:`dgl.from_scipy` instead, which returns a
    new graph built from the given scipy sparse matrix.
    """
    hint = ('DGLGraph.from_scipy_sparse_matrix is deprecated. '
            'Please call the following\n\n'
            '\t dgl.from_scipy(spmat)\n\n'
            ', which creates a new DGLGraph from the scipy matrix.')
    raise DGLError(hint)
def register_apply_node_func(self, func):
    """**DEPRECATED**: pass ``func`` directly to :func:`apply_nodes` instead."""
    hint = ('DGLGraph.register_apply_node_func is deprecated.'
            ' Please directly call apply_nodes with func as the argument.')
    raise DGLError(hint)
def register_apply_edge_func(self, func):
    """**DEPRECATED**: pass ``func`` directly to :func:`apply_edges` instead."""
    hint = ('DGLGraph.register_apply_edge_func is deprecated.'
            ' Please directly call apply_edges with func as the argument.')
    raise DGLError(hint)
def register_message_func(self, func):
    """**DEPRECATED**: pass ``func`` directly to :func:`update_all` instead."""
    hint = ('DGLGraph.register_message_func is deprecated.'
            ' Please directly call update_all with func as the argument.')
    raise DGLError(hint)
def register_reduce_func(self, func):
    """**DEPRECATED**: pass ``func`` directly to :func:`update_all` instead."""
    hint = ('DGLGraph.register_reduce_func is deprecated.'
            ' Please directly call update_all with func as the argument.')
    raise DGLError(hint)
def group_apply_edges(self, group_by, func, edges=ALL, etype=None, inplace=False):
    """**DEPRECATED**: The API is removed in 0.5."""
    hint = 'DGLGraph.group_apply_edges is removed in 0.5.'
    raise DGLError(hint)
def send(self, edges, message_func, etype=None):
    """**DEPRECATED**: use :func:`send_and_recv` or :func:`update_all` instead."""
    hint = ('DGLGraph.send is deprecated. As a replacement, use DGLGraph.apply_edges\n'
            ' API to compute messages as edge data. Then use DGLGraph.send_and_recv\n'
            ' and set the message function as dgl.function.copy_e to conduct message\n'
            ' aggregation.')
    raise DGLError(hint)
def recv(self, v, reduce_func, apply_node_func=None, etype=None, inplace=False):
    r"""**DEPRECATED**: use :func:`send_and_recv` or :func:`update_all` instead."""
    hint = ('DGLGraph.recv is deprecated. As a replacement, use DGLGraph.apply_edges\n'
            ' API to compute messages as edge data. Then use DGLGraph.send_and_recv\n'
            ' and set the message function as dgl.function.copy_e to conduct message\n'
            ' aggregation.')
    raise DGLError(hint)
def multi_recv(self, v, reducer_dict, cross_reducer, apply_node_func=None, inplace=False):
    r"""**DEPRECATED**: use :func:`multi_send_and_recv` or
    :func:`multi_update_all` instead.
    """
    hint = ('DGLGraph.multi_recv is deprecated. As a replacement,\n'
            ' use DGLGraph.apply_edges API to compute messages as edge data.\n'
            ' Then use DGLGraph.multi_send_and_recv and set the message function\n'
            ' as dgl.function.copy_e to conduct message aggregation.')
    raise DGLError(hint)
def multi_send_and_recv(self, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
    r"""**DEPRECATED**: The API is removed in v0.5.

    Raises
    ------
    DGLError
        Always; directs callers to the v0.5 replacement
        (:func:`edge_subgraph` + :func:`multi_update_all`).
    """
    # Bug fix: the message previously named DGLGraph.multi_pull (copy-paste
    # from the sibling stub below); it must name this API.
    raise DGLError('DGLGraph.multi_send_and_recv is removed in v0.5. As a replacement,\n'
                   ' use DGLGraph.edge_subgraph to extract the subgraph first \n'
                   ' and then call DGLGraph.multi_update_all.')
def multi_pull(self, v, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
    r"""**DEPRECATED**: The API is removed in v0.5."""
    hint = ('DGLGraph.multi_pull is removed in v0.5. As a replacement,\n'
            ' use DGLGraph.edge_subgraph to extract the subgraph first \n'
            ' and then call DGLGraph.multi_update_all.')
    raise DGLError(hint)
def readonly(self, readonly_state=True):
    """**DEPRECATED**: a DGLGraph is always mutable; this call only warns."""
    notice = ('DGLGraph.readonly is deprecated in v0.5.\n'
              'DGLGraph now always supports mutable operations like add_nodes'
              ' and add_edges.')
    dgl_warning(notice)
############################################################ ############################################################
# Internal APIs # Internal APIs
...@@ -4956,35 +4215,6 @@ def find_src_dst_ntypes(ntypes, metagraph): ...@@ -4956,35 +4215,6 @@ def find_src_dst_ntypes(ntypes, metagraph):
dsttypes = {ntypes[tid] : tid for tid in dst} dsttypes = {ntypes[tid] : tid for tid in dst}
return srctypes, dsttypes return srctypes, dsttypes
def infer_ntype_from_dict(graph, etype_dict):
    """Infer the destination node type shared by all edge types in a dict.

    Every edge type key in ``etype_dict`` must resolve to the same
    destination node type; otherwise an error is raised.

    Parameters
    ----------
    graph : DGLHeteroGraph
        Graph
    etype_dict : dict
        Dictionary whose keys are edge types.

    Returns
    -------
    str
        The common destination node type (``None`` for an empty dict).
    """
    dst_types = {graph.to_canonical_etype(ety)[2] for ety in etype_dict}
    if len(dst_types) > 1:
        raise DGLError("Cannot infer destination node type from the dictionary. "
                       "A valid specification must make sure that all the edge "
                       "type keys share the same destination node type.")
    return dst_types.pop() if dst_types else None
def pad_tuple(tup, length, pad_val=None): def pad_tuple(tup, length, pad_val=None):
"""Pad the given tuple to the given length. """Pad the given tuple to the given length.
...@@ -5000,13 +4230,13 @@ def pad_tuple(tup, length, pad_val=None): ...@@ -5000,13 +4230,13 @@ def pad_tuple(tup, length, pad_val=None):
else: else:
return tup + (pad_val,) * (length - len(tup)) return tup + (pad_val,) * (length - len(tup))
def merge_frames(frames, reducer, order=None): def reduce_dict_data(frames, reducer, order=None):
"""Merge input frames into one. Resolve conflict fields using reducer. """Merge tensor dictionaries into one. Resolve conflict fields using reducer.
Parameters Parameters
---------- ----------
frames : list[FrameRef] frames : list[dict[str, Tensor]]
Input frames Input tensor dictionaries
reducer : str reducer : str
One of "sum", "max", "min", "mean", "stack" One of "sum", "max", "min", "mean", "stack"
order : list[Int], optional order : list[Int], optional
...@@ -5018,7 +4248,7 @@ def merge_frames(frames, reducer, order=None): ...@@ -5018,7 +4248,7 @@ def merge_frames(frames, reducer, order=None):
Returns Returns
------- -------
FrameRef dict[str, Tensor]
Merged frame Merged frame
""" """
if len(frames) == 1 and reducer != 'stack': if len(frames) == 1 and reducer != 'stack':
...@@ -5040,10 +4270,10 @@ def merge_frames(frames, reducer, order=None): ...@@ -5040,10 +4270,10 @@ def merge_frames(frames, reducer, order=None):
'"sum", "max", "min", "mean" or "stack".') '"sum", "max", "min", "mean" or "stack".')
def merger(flist): def merger(flist):
return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0] return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0]
ret = FrameRef(frame_like(frames[0]._frame))
keys = set() keys = set()
for frm in frames: for frm in frames:
keys.update(frm.keys()) keys.update(frm.keys())
ret = {}
for k in keys: for k in keys:
flist = [] flist = []
for frm in frames: for frm in frames:
...@@ -5059,14 +4289,14 @@ def combine_frames(frames, ids): ...@@ -5059,14 +4289,14 @@ def combine_frames(frames, ids):
Parameters Parameters
---------- ----------
frames : List[FrameRef] frames : List[Frame]
List of frames List of frames
ids : List[int] ids : List[int]
List of frame IDs List of frame IDs
Returns Returns
------- -------
FrameRef Frame
The resulting frame The resulting frame
""" """
# find common columns and check if their schemes match # find common columns and check if their schemes match
...@@ -5087,7 +4317,7 @@ def combine_frames(frames, ids): ...@@ -5087,7 +4317,7 @@ def combine_frames(frames, ids):
# concatenate the columns # concatenate the columns
to_cat = lambda key: [frames[i][key] for i in ids if frames[i].num_rows > 0] to_cat = lambda key: [frames[i][key] for i in ids if frames[i].num_rows > 0]
cols = {key: F.cat(to_cat(key), dim=0) for key in schemes} cols = {key: F.cat(to_cat(key), dim=0) for key in schemes}
return FrameRef(Frame(cols)) return Frame(cols)
def combine_names(names, ids=None): def combine_names(names, ids=None):
"""Combine the selected names into one new name. """Combine the selected names into one new name.
...@@ -5109,118 +4339,6 @@ def combine_names(names, ids=None): ...@@ -5109,118 +4339,6 @@ def combine_names(names, ids=None):
selected = sorted([names[i] for i in ids]) selected = sorted([names[i] for i in ids])
return '+'.join(selected) return '+'.join(selected)
class AdaptedHeteroGraph(GraphAdapter):
    """Adapter exposing the scheduler-facing interface of a DGLGraph.

    Parameters
    ----------
    graph : DGLHeteroGraph
        The wrapped graph.
    stid : int
        Source node type id.
    dtid : int
        Destination node type id.
    etid : int
        Edge type id.
    """
    def __init__(self, graph, stid, dtid, etid):
        self.graph = graph
        self.stid = stid
        self.dtid = dtid
        self.etid = etid

    def _to_index_triplet(self, src, dst, eid):
        """Wrap raw (src, dst, eid) tensors into Index objects of the graph dtype."""
        idtype = self.graph._graph.dtype
        return (utils.toindex(src, idtype),
                utils.toindex(dst, idtype),
                utils.toindex(eid, idtype))

    @property
    def gidx(self):
        """The underlying heterograph index."""
        return self.graph._graph

    def num_src(self):
        """Number of nodes of the source type."""
        return self.graph._graph.number_of_nodes(self.stid)

    def num_dst(self):
        """Number of nodes of the destination type."""
        return self.graph._graph.number_of_nodes(self.dtid)

    def num_edges(self):
        """Number of edges of this edge type."""
        return self.graph._graph.number_of_edges(self.etid)

    @property
    def srcframe(self):
        """Frame storing source node features."""
        return self.graph._node_frames[self.stid]

    @property
    def dstframe(self):
        """Frame storing destination node features."""
        return self.graph._node_frames[self.dtid]

    @property
    def edgeframe(self):
        """Frame storing edge features."""
        return self.graph._edge_frames[self.etid]

    @property
    def msgframe(self):
        """Frame storing messages for this edge type."""
        return self.graph._msg_frames[self.etid]

    @property
    def msgindicator(self):
        """Message indicator tensor."""
        return self.graph._get_msg_index(self.etid)

    @msgindicator.setter
    def msgindicator(self, val):
        """Set a new message indicator tensor."""
        self.graph._set_msg_index(self.etid, val)

    def in_edges(self, nodes):
        """Incoming edges of the given nodes, as Index triplets."""
        nodes = nodes.tousertensor(self.graph.device)
        return self._to_index_triplet(
            *self.graph._graph.in_edges(self.etid, nodes))

    def out_edges(self, nodes):
        """Outgoing edges of the given nodes, as Index triplets."""
        nodes = nodes.tousertensor(self.graph.device)
        return self._to_index_triplet(
            *self.graph._graph.out_edges(self.etid, nodes))

    def edges(self, form):
        """All edges of this edge type in the given form, as Index triplets."""
        return self._to_index_triplet(
            *self.graph._graph.edges(self.etid, form))

    def get_immutable_gidx(self, ctx):
        """Immutable unit-graph index of this edge type on the given context."""
        return self.graph._graph.get_unitgraph(self.etid, ctx)

    def bits_needed(self):
        """Number of integer bits needed to represent ids of this edge type."""
        return self.graph._graph.bits_needed(self.etid)

    @property
    def canonical_etype(self):
        """Canonical edge type triplet."""
        return self.graph.canonical_etypes[self.etid]
def check_same_dtype(graph_dtype, tensor):
    """Raise if ``tensor``'s dtype is inconsistent with the graph index dtype.

    Non-tensor inputs are ignored.
    """
    if not F.is_tensor(tensor):
        return
    tensor_dtype = F.reverse_data_type_dict[F.dtype(tensor)]
    if graph_dtype != tensor_dtype:
        raise utils.InconsistentDtypeException(
            "Expect the input tensor to be the same as the graph index dtype({}), but got {}"
            .format(graph_dtype, tensor_dtype))
def check_idtype_dict(graph_dtype, tensor_dict):
    """Check that every tensor value in the dict matches the graph index dtype.

    Parameters
    ----------
    graph_dtype : str
        The graph's index dtype name.
    tensor_dict : dict
        Mapping whose values are checked via :func:`check_same_dtype`.
    """
    # Iterate values directly instead of ``for _, v in tensor_dict.items()``:
    # the keys were unused.
    for value in tensor_dict.values():
        check_same_dtype(graph_dtype, value)
class DGLBlock(DGLHeteroGraph): class DGLBlock(DGLHeteroGraph):
"""Subclass that signifies the graph is a block created from """Subclass that signifies the graph is a block created from
:func:`dgl.to_block`. :func:`dgl.to_block`.
...@@ -5251,4 +4369,69 @@ class DGLBlock(DGLHeteroGraph): ...@@ -5251,4 +4369,69 @@ class DGLBlock(DGLHeteroGraph):
return ret.format( return ret.format(
srcnode=nsrcnode_dict, dstnode=ndstnode_dict, edge=nedge_dict, meta=meta) srcnode=nsrcnode_dict, dstnode=ndstnode_dict, edge=nedge_dict, meta=meta)
def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
    """Create a computation graph from the given edges.

    The compute graph is a uni-directional bipartite graph with only
    one edge type. Similar to subgraph extraction, it stores the original node IDs
    in the srcdata[NID] and dstdata[NID] and extracts features accordingly.
    Edges are not relabeled.

    This function is typically used during message passing to generate
    a graph that contains only the active set of edges.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    u : Tensor
        Src nodes.
    v : Tensor
        Dst nodes.
    eid : Tensor
        Edge IDs.
    recv_nodes : Tensor
        Nodes that receive messages. If None, it is equal to unique(v).
        Otherwise, it must be a superset of v and can contain nodes
        that have no incoming edges.

    Returns
    -------
    DGLGraph
        A computation graph.
    """
    if len(u) == 0:
        # The computation graph has no edge and will not trigger message
        # passing. However, because of the apply node phase, we still construct
        # an empty graph to continue.
        # All three are the same empty tensor; only the dst side carries nodes.
        unique_src = new_u = new_v = u
        # With no edges, the receivers cannot be inferred from v, so the
        # caller must supply them explicitly.
        assert recv_nodes is not None
        unique_dst, _ = utils.relabel(recv_nodes)
    else:
        # relabel u and v to starting from 0
        unique_src, src_map = utils.relabel(u)
        if recv_nodes is None:
            unique_dst, dst_map = utils.relabel(v)
        else:
            # recv_nodes is a superset of v (see docstring), so relabeling it
            # also covers every destination endpoint in v.
            unique_dst, dst_map = utils.relabel(recv_nodes)
        # Map the original endpoints through the relabeling tables.
        new_u = F.gather_row(src_map, u)
        new_v = F.gather_row(dst_map, v)
    # Only one edge type is expected here — presumably callers pass a
    # single-relation slice; TODO(review) confirm.
    srctype, etype, dsttype = graph.canonical_etypes[0]
    # create graph
    hgidx = heterograph_index.create_unitgraph_from_coo(
        2, len(unique_src), len(unique_dst), new_u, new_v, ['coo', 'csr', 'csc'])
    # create frame
    # Slice features for the participating nodes and remember their original
    # IDs under NID/EID so results can be scattered back later.
    srcframe = graph._node_frames[graph.get_ntype_id(srctype)].subframe(unique_src)
    srcframe[NID] = unique_src
    dstframe = graph._node_frames[graph.get_ntype_id(dsttype)].subframe(unique_dst)
    dstframe[NID] = unique_dst
    eframe = graph._edge_frames[0].subframe(eid)
    eframe[EID] = eid
    return DGLHeteroGraph(hgidx, ([srctype], [dsttype]), [etype],
                          node_frames=[srcframe, dstframe],
                          edge_frames=[eframe])
_init_api("dgl.heterograph") _init_api("dgl.heterograph")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment