Unverified commit b9631912, authored by Minjie Wang, committed by GitHub
Browse files

[Frame] Refactor frame. (#85)

* refactor frame codes

* fix unit test

* fix gcn example

* minor doc/message changes

* raise errors for non-existent columns in FrameRef; sanity check when append

* fix unittest; change error msg

* Add warning for none initializer

* fix unittest

* use warnings package
parent 66261aee
...@@ -16,10 +16,10 @@ from dgl import DGLGraph ...@@ -16,10 +16,10 @@ from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
def gcn_msg(src, edge): def gcn_msg(src, edge):
return src return {'m' : src['h']}
def gcn_reduce(node, msgs): def gcn_reduce(node, msgs):
return torch.sum(msgs, 1) return {'h' : torch.sum(msgs['m'], 1)}
class NodeApplyModule(nn.Module): class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None): def __init__(self, in_feats, out_feats, activation=None):
...@@ -28,10 +28,10 @@ class NodeApplyModule(nn.Module): ...@@ -28,10 +28,10 @@ class NodeApplyModule(nn.Module):
self.activation = activation self.activation = activation
def forward(self, node): def forward(self, node):
h = self.linear(node) h = self.linear(node['h'])
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return h return {'h' : h}
class GCN(nn.Module): class GCN(nn.Module):
def __init__(self, def __init__(self,
...@@ -54,14 +54,14 @@ class GCN(nn.Module): ...@@ -54,14 +54,14 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes)) self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
    """Run all layers over the graph and return the final 'h' node feature.

    The input features are stored on the graph under the 'h' field, each
    layer is applied through ``update_all`` with the user-defined message
    and reduce functions, and the 'h' field is popped off the graph at the
    end.
    """
    self.g.set_n_repr({'h' : features})
    for layer in self.layers:
        # apply dropout
        if self.dropout:
            # The graph is only reachable as ``self.g`` inside this method;
            # a bare ``g`` is not defined here.  The function returns a dict
            # keyed by 'h', like every other node function in this file, so
            # the dropped-out values replace the 'h' field.
            self.g.apply_nodes(apply_node_func=
                    lambda node: {'h' : F.dropout(node['h'], p=self.dropout)})
        self.g.update_all(gcn_msg, gcn_reduce, layer)
    return self.g.pop_n_repr('h')
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
......
...@@ -23,10 +23,10 @@ class NodeApplyModule(nn.Module): ...@@ -23,10 +23,10 @@ class NodeApplyModule(nn.Module):
self.activation = activation self.activation = activation
def forward(self, node): def forward(self, node):
h = self.linear(node) h = self.linear(node['h'])
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return h return {'h' : h}
class GCN(nn.Module): class GCN(nn.Module):
def __init__(self, def __init__(self,
...@@ -49,14 +49,16 @@ class GCN(nn.Module): ...@@ -49,14 +49,16 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes)) self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
    """Run all layers over the graph and return the final 'h' node feature.

    The input features are stored on the graph under the 'h' field, each
    layer is applied through ``update_all`` using the builtin ``copy_src``
    message function and ``sum`` reducer, and the 'h' field is popped off
    the graph at the end.
    """
    self.g.set_n_repr({'h' : features})
    for layer in self.layers:
        # apply dropout
        if self.dropout:
            # The graph is only reachable as ``self.g`` inside this method;
            # a bare ``g`` is not defined here.  The function returns a dict
            # keyed by 'h', like the layer modules in this file, so the
            # dropped-out values replace the 'h' field.
            self.g.apply_nodes(apply_node_func=
                    lambda node: {'h' : F.dropout(node['h'], p=self.dropout)})
        self.g.update_all(fn.copy_src(src='h', out='m'),
                          fn.sum(msgs='m', out='h'),
                          layer)
    return self.g.pop_n_repr('h')
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
......
...@@ -93,23 +93,24 @@ def get_context(arr): ...@@ -93,23 +93,24 @@ def get_context(arr):
return TVMContext( return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index) TVMContext.STR2MASK[arr.device.type], arr.device.index)
def _typestr(arr_dtype): def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype in (th.float16, th.half): if arr_dtype in (th.float16, th.half):
return 'float16' return TVMType('float16')
elif arr_dtype in (th.float32, th.float): elif arr_dtype in (th.float32, th.float):
return 'float32' return TVMType('float32')
elif arr_dtype in (th.float64, th.double): elif arr_dtype in (th.float64, th.double):
return 'float64' return TVMType('float64')
elif arr_dtype in (th.int16, th.short): elif arr_dtype in (th.int16, th.short):
return 'int16' return TVMType('int16')
elif arr_dtype in (th.int32, th.int): elif arr_dtype in (th.int32, th.int):
return 'int32' return TVMType('int32')
elif arr_dtype in (th.int64, th.long): elif arr_dtype in (th.int64, th.long):
return 'int64' return TVMType('int64')
elif arr_dtype == th.int8: elif arr_dtype == th.int8:
return 'int8' return TVMType('int8')
elif arr_dtype == th.uint8: elif arr_dtype == th.uint8:
return 'uint8' return TVMType('uint8')
else: else:
raise RuntimeError('Unsupported data type:', arr_dtype) raise RuntimeError('Unsupported data type:', arr_dtype)
...@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data): ...@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data):
"""Return a tensor that shares the numpy data.""" """Return a tensor that shares the numpy data."""
return th.from_numpy(np_data) return th.from_numpy(np_data)
'''
data = arr_data
assert data.is_contiguous()
arr = TVMArray()
shape = c_array(tvm_shape_index_t, tuple(data.shape))
arr.data = ctypes.cast(data.data_ptr(), ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = TVMType(_typestr(data.dtype))
arr.ndim = len(shape)
arr.ctx = get_context(data)
return arr
'''
def nonzero_1d(arr): def nonzero_1d(arr):
"""Return a 1D tensor with nonzero element indices in a 1D vector""" """Return a 1D tensor with nonzero element indices in a 1D vector"""
assert arr.dim() == 1 assert arr.dim() == 1
......
"""Module for base types and utilities.""" """Module for base types and utilities."""
from __future__ import absolute_import
import warnings
from ._ffi.base import DGLError
# A special argument for selecting all nodes/edges. # A special argument for selecting all nodes/edges.
ALL = "__ALL__" ALL = "__ALL__"
...@@ -8,3 +13,5 @@ def is_all(arg): ...@@ -8,3 +13,5 @@ def is_all(arg):
__MSG__ = "__MSG__" __MSG__ = "__MSG__"
__REPR__ = "__REPR__" __REPR__ = "__REPR__"
dgl_warning = warnings.warn
"""Columnar storage for graph attributes.""" """Columnar storage for DGLGraph."""
from __future__ import absolute_import from __future__ import absolute_import
from collections import MutableMapping from collections import MutableMapping
...@@ -6,178 +6,598 @@ import numpy as np ...@@ -6,178 +6,598 @@ import numpy as np
from . import backend as F from . import backend as F
from .backend import Tensor from .backend import Tensor
from .base import DGLError, dgl_warning
from . import utils from . import utils
class Scheme(object):
    """The column scheme.

    A scheme records the shape and data type of the features stored in a
    column.

    Parameters
    ----------
    shape : tuple of int
        The feature shape.
    dtype : TVMType
        The feature data type.
    """
    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    def __repr__(self):
        return '{shape=%s, dtype=%s}' % (repr(self.shape), repr(self.dtype))

    def __eq__(self, other):
        """Return True if both schemes have equal shape and dtype.

        Returns ``NotImplemented`` for objects that are not schemes, so
        that comparing against an unrelated value (e.g. ``None``) yields
        False instead of raising AttributeError.
        """
        if not isinstance(other, Scheme):
            return NotImplemented
        return self.shape == other.shape and self.dtype == other.dtype

    def __ne__(self, other):
        """Return the negation of ``__eq__``."""
        result = self.__eq__(other)
        # Propagate NotImplemented untouched; negating it would silently
        # turn "cannot compare" into "equal".
        if result is NotImplemented:
            return result
        return not result

    @staticmethod
    def infer_scheme(tensor):
        """Infer the scheme of the given tensor.

        The first dimension of the tensor is dropped; the remaining
        dimensions become the scheme's shape.
        """
        return Scheme(tuple(F.shape(tensor)[1:]), F.get_tvmtype(tensor))
class Column(object):
    """A column is a compact store of features of multiple nodes/edges.

    Currently, we use one dense tensor to batch all the feature tensors
    together (along the first dimension).

    Parameters
    ----------
    data : Tensor
        The initial data of the column.
    scheme : Scheme, optional
        The scheme of the column. Will be inferred if not provided.
    """
    def __init__(self, data, scheme=None):
        self.data = data
        self.scheme = scheme if scheme else Scheme.infer_scheme(data)

    def __len__(self):
        """The column length."""
        return F.shape(self.data)[0]

    def __getitem__(self, idx):
        """Return the feature data given the index.

        Parameters
        ----------
        idx : utils.Index
            The index.

        Returns
        -------
        Tensor
            The feature data
        """
        user_idx = idx.tousertensor(F.get_context(self.data))
        return F.gather_row(self.data, user_idx)

    def __setitem__(self, idx, feats):
        """Update the feature data given the index.

        The update is performed out-of-place so it can be used in autograd
        mode. For inplace write, please use ``update``.

        Parameters
        ----------
        idx : utils.Index
            The index.
        feats : Tensor
            The new features.
        """
        self.update(idx, feats, inplace=False)

    def update(self, idx, feats, inplace):
        """Update the feature data given the index.

        Parameters
        ----------
        idx : utils.Index
            The index.
        feats : Tensor
            The new features.
        inplace : bool
            If true, use inplace write.

        Raises
        ------
        DGLError
            If the scheme inferred from ``feats`` differs from the scheme
            of this column.
        """
        feat_scheme = Scheme.infer_scheme(feats)
        if feat_scheme != self.scheme:
            # Argument order follows the message: column scheme first,
            # then the scheme of the offending feature.
            raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
                           % (self.scheme, feat_scheme))
        user_idx = idx.tousertensor(F.get_context(self.data))
        if inplace:
            # TODO(minjie): do not use [] operator directly
            self.data[user_idx] = feats
        else:
            self.data = F.scatter_row(self.data, user_idx, feats)

    @staticmethod
    def create(data):
        """Create a new column using the given data.

        If ``data`` is already a column, a new column object wrapping the
        same underlying data is returned.
        """
        if isinstance(data, Column):
            return Column(data.data)
        else:
            return Column(data)
class Frame(MutableMapping): class Frame(MutableMapping):
"""The columnar storage for node/edge features.
The frame is a dictionary from feature fields to feature columns.
All columns should have the same number of rows (i.e. the same first dimension).
Parameters
----------
data : dict-like, optional
The frame data in dictionary. If the provided data is another frame,
this frame will NOT share columns with the given frame. So any out-place
update on one will not reflect to the other. The inplace update will
be seen by both. This follows the semantic of python's container.
"""
def __init__(self, data=None): def __init__(self, data=None):
if data is None: if data is None:
self._columns = dict() self._columns = dict()
self._num_rows = 0 self._num_rows = 0
else: else:
self._columns = dict(data) # Note that we always create a new column for the given data.
self._num_rows = F.shape(list(data.values())[0])[0] # This avoids two frames accidentally sharing the same column.
for k, v in data.items(): self._columns = {k : Column.create(v) for k, v in data.items()}
assert F.shape(v)[0] == self._num_rows if len(self._columns) != 0:
self._num_rows = len(next(iter(self._columns.values())))
else:
self._num_rows = 0
# sanity check
for name, col in self._columns.items():
if len(col) != self._num_rows:
raise DGLError('Expected all columns to have same # rows (%d), '
'got %d on %r.' % (self._num_rows, len(col), name))
# Initializer for empty values. Initializer is a callable.
# If is none, then a warning will be raised
# in the first call and zero initializer will be used later.
self._initializer = None
def set_initializer(self, initializer):
"""Set the initializer for empty values.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._initializer = initializer
@property
def initializer(self):
"""Return the initializer of this frame."""
return self._initializer
@property @property
def schemes(self): def schemes(self):
return set(self._columns.keys()) """Return a dictionary of column name to column schemes."""
return {k : col.scheme for k, col in self._columns.items()}
@property @property
def num_columns(self): def num_columns(self):
"""Return the number of columns in this frame."""
return len(self._columns) return len(self._columns)
@property @property
def num_rows(self): def num_rows(self):
"""Return the number of rows in this frame."""
return self._num_rows return self._num_rows
def __contains__(self, key): def __contains__(self, name):
return key in self._columns """Return true if the given column name exists."""
return name in self._columns
def __getitem__(self, key):
# get column def __getitem__(self, name):
return self._columns[key] """Return the column of the given name.
def __setitem__(self, key, val): Parameters
# set column ----------
self.add_column(key, val) name : str
The column name.
def __delitem__(self, key):
# delete column Returns
del self._columns[key] -------
Column
The column.
"""
return self._columns[name]
def __setitem__(self, name, data):
"""Update the whole column.
Parameters
----------
name : str
The column name.
col : Column or data convertible to Column
The column data.
"""
self.update_column(name, data)
def __delitem__(self, name):
"""Delete the whole column.
Parameters
----------
name : str
The column name.
"""
del self._columns[name]
if len(self._columns) == 0: if len(self._columns) == 0:
self._num_rows = 0 self._num_rows = 0
def add_column(self, name, scheme, ctx):
    """Add a new column to the frame.

    The new column is filled by the frame's initializer. If no initializer
    has been set, a warning is issued and a zero initializer is installed
    and used from then on.

    Parameters
    ----------
    name : str
        The column name.
    scheme : Scheme
        The column scheme.
    ctx : TVMContext
        The column context.

    Raises
    ------
    DGLError
        If the frame has no rows, so the length of the new column cannot
        be inferred.
    """
    if name in self:
        dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name)
        return
    if self.num_rows == 0:
        # The message has a placeholder for the column name; fill it in.
        raise DGLError('Cannot add column "%s" using column schemes because'
                       ' number of rows is unknown. Make sure there is at least'
                       ' one column in the frame so number of rows can be inferred.'
                       % name)
    if self.initializer is None:
        dgl_warning('Initializer is not set. Use zero initializer instead.'
                    ' To suppress this warning, use `set_initializer` to'
                    ' explicitly specify which initializer to use.')
        # TODO(minjie): handle data type
        self.set_initializer(lambda shape, dtype : F.zeros(shape))
    # TODO(minjie): directly init data on the target device.
    init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype)
    init_data = F.to_context(init_data, ctx)
    self._columns[name] = Column(init_data, scheme)
def update_column(self, name, data):
"""Add or replace the column with the given name and data.
Parameters
----------
name : str
The column name.
data : Column or data convertible to Column
The column data.
"""
col = Column.create(data)
if self.num_columns == 0: if self.num_columns == 0:
self._num_rows = F.shape(col)[0] self._num_rows = len(col)
else: elif len(col) != self._num_rows:
assert F.shape(col)[0] == self._num_rows raise DGLError('Expected data to have %d rows, got %d.' %
(self._num_rows, len(col)))
self._columns[name] = col self._columns[name] = col
def append(self, other):
    """Append another frame's data into this frame.

    If the current frame is empty, it will just use the columns of the
    given frame. Otherwise, the given data should contain all the
    column keys of this frame.

    Parameters
    ----------
    other : Frame or dict-like
        The frame data to be appended.

    Raises
    ------
    DGLError
        If a column in ``other`` has a different scheme from the column
        of the same name in this frame.
    """
    if not isinstance(other, Frame):
        other = Frame(other)
    if len(self._columns) == 0:
        for key, col in other.items():
            self._columns[key] = col
        self._num_rows = other.num_rows
    else:
        for key, col in other.items():
            sch = self._columns[key].scheme
            other_sch = col.scheme
            if sch != other_sch:
                # Report both schemes using the locals defined above.
                raise DGLError("Cannot append column of scheme %s to column of scheme %s."
                               % (other_sch, sch))
            self._columns[key].data = F.pack(
                    [self._columns[key].data, col.data])
        self._num_rows += other.num_rows
def clear(self): def clear(self):
"""Clear this frame. Remove all the columns."""
self._columns = {} self._columns = {}
self._num_rows = 0 self._num_rows = 0
def __iter__(self): def __iter__(self):
"""Return an iterator of columns."""
return iter(self._columns) return iter(self._columns)
def __len__(self): def __len__(self):
"""Return the number of columns."""
return self.num_columns return self.num_columns
def keys(self):
"""Return the keys."""
return self._columns.keys()
class FrameRef(MutableMapping): class FrameRef(MutableMapping):
"""Frame reference """Reference object to a frame on a subset of rows.
Parameters Parameters
---------- ----------
frame : dgl.frame.Frame frame : Frame, optional
The underlying frame. The underlying frame. If not given, the reference will point to a
index : iterable of int new empty frame.
The rows that are referenced in the underlying frame. index : iterable of int, optional
The rows that are referenced in the underlying frame. If not given,
the whole frame is referenced. The index should be distinct (no
duplication is allowed).
""" """
def __init__(self, frame=None, index=None): def __init__(self, frame=None, index=None):
self._frame = frame if frame is not None else Frame() self._frame = frame if frame is not None else Frame()
if index is None: if index is None:
self._index_data = slice(0, self._frame.num_rows) self._index_data = slice(0, self._frame.num_rows)
else: else:
# check no duplication # TODO(minjie): check no duplication
assert len(index) == len(np.unique(index))
self._index_data = index self._index_data = index
self._index = None self._index = None
@property @property
def schemes(self): def schemes(self):
"""Return the frame schemes.
Returns
-------
dict of str to Scheme
The frame schemes.
"""
return self._frame.schemes return self._frame.schemes
@property @property
def num_columns(self): def num_columns(self):
"""Return the number of columns in the referred frame."""
return self._frame.num_columns return self._frame.num_columns
@property @property
def num_rows(self): def num_rows(self):
"""Return the number of rows referred."""
if isinstance(self._index_data, slice): if isinstance(self._index_data, slice):
# NOTE: we are assuming that the index is a slice ONLY IF
# index=None during construction.
# As such, start is always 0, and step is always 1.
return self._index_data.stop return self._index_data.stop
else: else:
return len(self._index_data) return len(self._index_data)
def __contains__(self, key): def set_initializer(self, initializer):
return key in self._frame """Set the initializer for empty values.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._frame.set_initializer(initializer)
def index(self):
"""Return the index object.
Returns
-------
utils.Index
The index.
"""
if self._index is None:
if self.is_contiguous():
self._index = utils.toindex(
F.arange(self._index_data.stop, dtype=F.int64))
else:
self._index = utils.toindex(self._index_data)
return self._index
def __contains__(self, name):
"""Return whether the column name exists."""
return name in self._frame
def __iter__(self):
"""Return the iterator of the columns."""
return iter(self._frame)
def __len__(self):
"""Return the number of columns."""
return self.num_columns
def keys(self):
"""Return the keys."""
return self._frame.keys()
def __getitem__(self, key): def __getitem__(self, key):
"""Get data from the frame.
If the provided key is string, the corresponding column data will be returned.
If the provided key is an index, the corresponding rows will be selected. The
returned rows are saved in a lazy dictionary so only the real selection happens
when the explicit column name is provided.
Examples (using pytorch)
------------------------
>>> # create a frame of two columns and five rows
>>> f = Frame({'c1' : torch.zeros([5, 2]), 'c2' : torch.ones([5, 2])})
>>> fr = FrameRef(f)
>>> # select the row 1 and 2, the returned `rows` is a lazy dictionary.
>>> rows = fr[Index([1, 2])]
>>> rows['c1'] # only select rows for 'c1' column; 'c2' column is not sliced.
Parameters
----------
key : str or utils.Index
The key.
Returns
-------
Tensor or lazy dict or tensors
Depends on whether it is a column selection or row selection.
"""
if isinstance(key, str): if isinstance(key, str):
return self.get_column(key) return self.select_column(key)
else: else:
return self.select_rows(key) return self.select_rows(key)
def select_rows(self, query): def select_column(self, name):
rowids = self._getrowid(query) """Return the column of the given name.
def _lazy_select(key):
idx = rowids.tousertensor(F.get_context(self._frame[key])) If only part of the rows are referenced, the fetching the whole column will
return F.gather_row(self._frame[key], idx) also slice out the referenced rows.
return utils.LazyDict(_lazy_select, keys=self.schemes)
Parameters
----------
name : str
The column name.
def get_column(self, name): Returns
-------
Tensor
The column data.
"""
col = self._frame[name] col = self._frame[name]
if self.is_span_whole_column(): if self.is_span_whole_column():
return col return col.data
else: else:
idx = self.index().tousertensor(F.get_context(col)) return col[self.index()]
return F.gather_row(col, idx)
def select_rows(self, query):
"""Return the rows given the query.
Parameters
----------
query : utils.Index
The rows to be selected.
Returns
-------
utils.LazyDict
The lazy dictionary from str to the selected data.
"""
rowids = self._getrowid(query)
return utils.LazyDict(lambda key: self._frame[key][rowids], keys=self.keys())
def __setitem__(self, key, val): def __setitem__(self, key, val):
"""Update the data in the frame.
If the provided key is string, the corresponding column data will be updated.
The provided value should be one tensor that have the same scheme and length
as the column.
If the provided key is an index, the corresponding rows will be updated. The
value provided should be a dictionary of string to the data of each column.
All updates are performed out-placely to be work with autograd. For inplace
update, use ``update_column`` or ``update_rows``.
Parameters
----------
key : str or utils.Index
The key.
val : Tensor or dict of tensors
The value.
"""
if isinstance(key, str): if isinstance(key, str):
self.add_column(key, val) self.update_column(key, val, inplace=False)
else: else:
self.update_rows(key, val) self.update_rows(key, val, inplace=False)
def add_column(self, name, col, inplace=False): def update_column(self, name, data, inplace):
shp = F.shape(col) """Update the column.
If this frameref spans the whole column of the underlying frame, this is
equivalent to update the column of the frame.
If this frameref only points to part of the rows, then update the column
here will correspond to update part of the column in the frame. Raise error
if the given column name does not exist.
Parameters
----------
name : str
The column name.
data : Tensor
The update data.
inplace : bool
True if the update is performed inplacely.
"""
if self.is_span_whole_column(): if self.is_span_whole_column():
col = Column.create(data)
if self.num_columns == 0: if self.num_columns == 0:
self._index_data = slice(0, shp[0]) # the frame is empty
self._index_data = slice(0, len(col))
self._clear_cache() self._clear_cache()
assert shp[0] == self.num_rows
self._frame[name] = col self._frame[name] = col
else: else:
colctx = F.get_context(col) if name not in self._frame:
if name in self._frame: feat_shape = F.shape(data)[1:]
fcol = self._frame[name] feat_dtype = F.get_tvmtype(data)
else: ctx = F.get_context(data)
fcol = F.zeros((self._frame.num_rows,) + shp[1:]) self._frame.add_column(name, Scheme(feat_shape, feat_dtype), ctx)
fcol = F.to_context(fcol, colctx) #raise DGLError('Cannot update column. Column "%s" does not exist.'
idx = self.index().tousertensor(colctx) # ' Did you forget to init the column using `set_n_repr`'
if inplace: # ' or `set_e_repr`?' % name)
self._frame[name] = fcol fcol = self._frame[name]
self._frame[name][idx] = col fcol.update(self.index(), data, inplace)
else:
newfcol = F.scatter_row(fcol, idx, col) def update_rows(self, query, data, inplace):
self._frame[name] = newfcol """Update the rows.
def update_rows(self, query, other, inplace=False): If the provided data has new column, it will be added to the frame.
See Also
--------
``update_column``
Parameters
----------
query : utils.Index
The rows to be updated.
data : dict-like
The row data.
inplace : bool
True if the update is performed inplacely.
"""
rowids = self._getrowid(query) rowids = self._getrowid(query)
for key, col in other.items(): for key, col in data.items():
if key not in self: if key not in self:
# add new column # add new column
tmpref = FrameRef(self._frame, rowids) tmpref = FrameRef(self._frame, rowids)
tmpref.add_column(key, col, inplace) tmpref.update_column(key, col, inplace)
idx = rowids.tousertensor(F.get_context(self._frame[key])) #raise DGLError('Cannot update rows. Column "%s" does not exist.'
if inplace: # ' Did you forget to init the column using `set_n_repr`'
self._frame[key][idx] = col # ' or `set_e_repr`?' % key)
else: else:
self._frame[key] = F.scatter_row(self._frame[key], idx, col) self._frame[key].update(rowids, col, inplace)
def __delitem__(self, key): def __delitem__(self, key):
"""Delete data in the frame.
If the provided key is a string, the corresponding column will be deleted.
If the provided key is an index object, the corresponding rows will be deleted.
Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not relect on the other. By contrast, deleting columns is real.
Parameters
----------
key : str or utils.Index
The key.
"""
if isinstance(key, str): if isinstance(key, str):
del self._frame[key] del self._frame[key]
if len(self._frame) == 0: if len(self._frame) == 0:
...@@ -186,7 +606,18 @@ class FrameRef(MutableMapping): ...@@ -186,7 +606,18 @@ class FrameRef(MutableMapping):
self.delete_rows(key) self.delete_rows(key)
def delete_rows(self, query): def delete_rows(self, query):
query = F.asnumpy(query) """Delete rows.
Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not relect on the other. By contrast, deleting columns is real.
Parameters
----------
query : utils.Index
The rows to be deleted.
"""
query = query.tolist()
if isinstance(self._index_data, slice): if isinstance(self._index_data, slice):
self._index_data = list(range(self._index_data.start, self._index_data.stop)) self._index_data = list(range(self._index_data.start, self._index_data.stop))
arr = np.array(self._index_data, dtype=np.int32) arr = np.array(self._index_data, dtype=np.int32)
...@@ -194,6 +625,13 @@ class FrameRef(MutableMapping): ...@@ -194,6 +625,13 @@ class FrameRef(MutableMapping):
self._clear_cache() self._clear_cache()
def append(self, other): def append(self, other):
"""Append another frame into this one.
Parameters
----------
other : dict of str to tensor
The data to be appended.
"""
span_whole = self.is_span_whole_column() span_whole = self.is_span_whole_column()
contiguous = self.is_contiguous() contiguous = self.is_contiguous()
old_nrows = self._frame.num_rows old_nrows = self._frame.num_rows
...@@ -208,24 +646,23 @@ class FrameRef(MutableMapping): ...@@ -208,24 +646,23 @@ class FrameRef(MutableMapping):
self._clear_cache() self._clear_cache()
def clear(self): def clear(self):
"""Clear the frame."""
self._frame.clear() self._frame.clear()
self._index_data = slice(0, 0) self._index_data = slice(0, 0)
self._clear_cache() self._clear_cache()
def __iter__(self):
return iter(self._frame)
def __len__(self):
return self.num_columns
def is_contiguous(self): def is_contiguous(self):
# NOTE: this check could have false negative """Return whether this refers to a contiguous range of rows."""
# NOTE: this check could have false negatives and false positives
# (step other than 1)
return isinstance(self._index_data, slice) return isinstance(self._index_data, slice)
def is_span_whole_column(self): def is_span_whole_column(self):
"""Return whether this refers to all the rows."""
return self.is_contiguous() and self.num_rows == self._frame.num_rows return self.is_contiguous() and self.num_rows == self._frame.num_rows
def _getrowid(self, query): def _getrowid(self, query):
"""Internal function to convert from the local row ids to the row ids of the frame."""
if self.is_contiguous(): if self.is_contiguous():
# shortcut for identical mapping # shortcut for identical mapping
return query return query
...@@ -233,16 +670,8 @@ class FrameRef(MutableMapping): ...@@ -233,16 +670,8 @@ class FrameRef(MutableMapping):
idxtensor = self.index().tousertensor() idxtensor = self.index().tousertensor()
return utils.toindex(F.gather_row(idxtensor, query.tousertensor())) return utils.toindex(F.gather_row(idxtensor, query.tousertensor()))
def index(self):
if self._index is None:
if self.is_contiguous():
self._index = utils.toindex(
F.arange(self._index_data.stop, dtype=F.int64))
else:
self._index = utils.toindex(self._index_data)
return self._index
def _clear_cache(self): def _clear_cache(self):
"""Internal function to clear the cached object."""
self._index_tensor = None self._index_tensor = None
def merge_frames(frames, indices, max_index, reduce_func): def merge_frames(frames, indices, max_index, reduce_func):
...@@ -267,6 +696,8 @@ def merge_frames(frames, indices, max_index, reduce_func): ...@@ -267,6 +696,8 @@ def merge_frames(frames, indices, max_index, reduce_func):
merged : FrameRef merged : FrameRef
The merged frame. The merged frame.
""" """
# TODO(minjie)
assert False, 'Buggy code, disabled for now.'
assert reduce_func == 'sum' assert reduce_func == 'sum'
assert len(frames) > 0 assert len(frames) > 0
schemes = frames[0].schemes schemes = frames[0].schemes
......
...@@ -504,25 +504,49 @@ class DGLGraph(object): ...@@ -504,25 +504,49 @@ class DGLGraph(object):
self._msg_graph.add_nodes(self._graph.number_of_nodes()) self._msg_graph.add_nodes(self._graph.number_of_nodes())
def node_attr_schemes(self): def node_attr_schemes(self):
"""Return the node attribute schemes. """Return the node feature schemes.
Returns Returns
------- -------
iterable dict of str to schemes
The set of attribute names The schemes of node feature columns.
""" """
return self._node_frame.schemes return self._node_frame.schemes
def edge_attr_schemes(self): def edge_attr_schemes(self):
"""Return the edge attribute schemes. """Return the edge feature schemes.
Returns Returns
------- -------
iterable dict of str to schemes
The set of attribute names The schemes of edge feature columns.
""" """
return self._edge_frame.schemes return self._edge_frame.schemes
def set_n_initializer(self, initializer):
"""Set the initializer for empty node features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._node_frame.set_initializer(initializer)
def set_e_initializer(self, initializer):
"""Set the initializer for empty edge features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._edge_frame.set_initializer(initializer)
def set_n_repr(self, hu, u=ALL, inplace=False): def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation. """Set node(s) representation.
...@@ -534,12 +558,17 @@ class DGLGraph(object): ...@@ -534,12 +558,17 @@ class DGLGraph(object):
Dictionary type is also supported for `hu`. In this case, each item Dictionary type is also supported for `hu`. In this case, each item
will be treated as separate attribute of the nodes. will be treated as separate attribute of the nodes.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters Parameters
---------- ----------
hu : tensor or dict of tensor hu : tensor or dict of tensor
Node representation. Node representation.
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
if is_all(u): if is_all(u):
...@@ -607,7 +636,7 @@ class DGLGraph(object): ...@@ -607,7 +636,7 @@ class DGLGraph(object):
""" """
return self._node_frame.pop(key) return self._node_frame.pop(key)
def set_e_repr(self, h_uv, u=ALL, v=ALL): def set_e_repr(self, h_uv, u=ALL, v=ALL, inplace=False):
"""Set edge(s) representation. """Set edge(s) representation.
To set multiple edge representations at once, pass `u` and `v` with tensors or To set multiple edge representations at once, pass `u` and `v` with tensors or
...@@ -618,6 +647,9 @@ class DGLGraph(object): ...@@ -618,6 +647,9 @@ class DGLGraph(object):
Dictionary type is also supported for `h_uv`. In this case, each item Dictionary type is also supported for `h_uv`. In this case, each item
will be treated as separate attribute of the edges. will be treated as separate attribute of the edges.
All updates are done out-of-place so that they work with autograd, unless the
inplace flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor h_uv : tensor or dict of tensor
...@@ -626,28 +658,35 @@ class DGLGraph(object): ...@@ -626,28 +658,35 @@ class DGLGraph(object):
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
inplace : bool
True if the update is done in place.
""" """
# sanity check # sanity check
u_is_all = is_all(u) u_is_all = is_all(u)
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if u_is_all: if u_is_all:
self.set_e_repr_by_id(h_uv, eid=ALL) self.set_e_repr_by_id(h_uv, eid=ALL, inplace=inplace)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(h_uv, eid=eid) self.set_e_repr_by_id(h_uv, eid=eid, inplace=inplace)
def set_e_repr_by_id(self, h_uv, eid=ALL): def set_e_repr_by_id(self, h_uv, eid=ALL, inplace=False):
"""Set edge(s) representation by edge id. """Set edge(s) representation by edge id.
All updates are done out-of-place so that they work with autograd, unless the
inplace flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor h_uv : tensor or dict of tensor
Edge representation. Edge representation.
eid : int, container or tensor eid : int, container or tensor
The edge id(s). The edge id(s).
inplace : bool
True if the update is done in place.
""" """
# sanity check # sanity check
if is_all(eid): if is_all(eid):
...@@ -662,16 +701,18 @@ class DGLGraph(object): ...@@ -662,16 +701,18 @@ class DGLGraph(object):
assert F.shape(h_uv)[0] == num_edges assert F.shape(h_uv)[0] == num_edges
# set # set
if is_all(eid): if is_all(eid):
# update column
if utils.is_dict_like(h_uv): if utils.is_dict_like(h_uv):
for key, val in h_uv.items(): for key, val in h_uv.items():
self._edge_frame[key] = val self._edge_frame[key] = val
else: else:
self._edge_frame[__REPR__] = h_uv self._edge_frame[__REPR__] = h_uv
else: else:
# update row
if utils.is_dict_like(h_uv): if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv self._edge_frame.update_rows(eid, h_uv, inplace=inplace)
else: else:
self._edge_frame[eid] = {__REPR__ : h_uv} self._edge_frame.update_rows(eid, {__REPR__ : h_uv}, inplace=inplace)
def get_e_repr(self, u=ALL, v=ALL): def get_e_repr(self, u=ALL, v=ALL):
"""Get edge(s) representation. """Get edge(s) representation.
...@@ -793,12 +834,12 @@ class DGLGraph(object): ...@@ -793,12 +834,12 @@ class DGLGraph(object):
""" """
self._apply_edge_func = apply_edge_func self._apply_edge_func = apply_edge_func
def apply_nodes(self, v, apply_node_func="default"): def apply_nodes(self, v=ALL, apply_node_func="default"):
"""Apply the function on node representations. """Apply the function on node representations.
Parameters Parameters
---------- ----------
v : int, iterable of int, tensor v : int, iterable of int, tensor, optional
The node id(s). The node id(s).
apply_node_func : callable apply_node_func : callable
The apply node function. The apply node function.
...@@ -952,8 +993,8 @@ class DGLGraph(object): ...@@ -952,8 +993,8 @@ class DGLGraph(object):
self._msg_frame.update_rows( self._msg_frame.update_rows(
msg_target_rows, msg_target_rows,
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor()) {k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
for k in msgs} for k in msgs},
) inplace=False)
if len(msg_append_rows) > 0: if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv) new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u) new_u = utils.toindex(new_u)
...@@ -961,14 +1002,13 @@ class DGLGraph(object): ...@@ -961,14 +1002,13 @@ class DGLGraph(object):
self._msg_graph.add_edges(new_u, new_v) self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append( self._msg_frame.append(
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor()) {k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
for k in msgs} for k in msgs})
)
else: else:
if len(msg_target_rows) > 0: if len(msg_target_rows) > 0:
self._msg_frame.update_rows( self._msg_frame.update_rows(
msg_target_rows, msg_target_rows,
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())} {__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())},
) inplace=False)
if len(msg_append_rows) > 0: if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv) new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u) new_u = utils.toindex(new_u)
......
...@@ -20,22 +20,26 @@ def reduce_func(node, msgs): ...@@ -20,22 +20,26 @@ def reduce_func(node, msgs):
reduce_msg_shapes.add(tuple(msgs.shape)) reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3 assert len(msgs.shape) == 3
assert msgs.shape[2] == D assert msgs.shape[2] == D
return {'m' : th.sum(msgs, 1)} return {'accum' : th.sum(msgs, 1)}
def apply_node_func(node):
    """Fold the reduced messages into the node feature.

    Parameters
    ----------
    node : dict
        Node features; must provide 'h' and the 'accum' field written by
        the reduce function.

    Returns
    -------
    dict
        ``{'h': node['h'] + node['accum']}``.
    """
    return {'h' : node['h'] + node['accum']}
def generate_graph(grad=False):
    """Build the 10-node test graph with a random node feature column 'h'.

    Node 0 is the source and node 9 is the sink: for every i in 1..8 there is
    an edge 0 -> i and an edge i -> 9, plus one back edge 9 -> 0, which gives
    17 edges in total.

    Parameters
    ----------
    grad : bool
        Passed as ``requires_grad`` of the node feature tensor.

    Returns
    -------
    DGLGraph
        The graph, with 'h' set to a random (10, D) tensor and a node
        initializer that returns zeros of the requested shape.
    """
    g = DGLGraph()
    g.add_nodes(10)  # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    # 17 edges
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    ncol = Variable(th.randn(10, D), requires_grad=grad)
    g.set_n_repr({'h' : ncol})
    # Columns created later for only some nodes are filled with zeros; the
    # requested dtype is ignored by this initializer.
    g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
    return g
def test_batch_setter_getter(): def test_batch_setter_getter():
...@@ -46,8 +50,9 @@ def test_batch_setter_getter(): ...@@ -46,8 +50,9 @@ def test_batch_setter_getter():
g.set_n_repr({'h' : th.zeros((10, D))}) g.set_n_repr({'h' : th.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10 assert _pfc(g.get_n_repr()['h']) == [0.] * 10
# pop nodes # pop nodes
old_len = len(g.get_n_repr())
assert _pfc(g.pop_n_repr('h')) == [0.] * 10 assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == 0 assert len(g.get_n_repr()) == old_len - 1
g.set_n_repr({'h' : th.zeros((10, D))}) g.set_n_repr({'h' : th.zeros((10, D))})
# set partial nodes # set partial nodes
u = th.tensor([1, 3, 5]) u = th.tensor([1, 3, 5])
...@@ -81,8 +86,9 @@ def test_batch_setter_getter(): ...@@ -81,8 +86,9 @@ def test_batch_setter_getter():
g.set_e_repr({'l' : th.zeros((17, D))}) g.set_e_repr({'l' : th.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17 assert _pfc(g.get_e_repr()['l']) == [0.] * 17
# pop edges # pop edges
old_len = len(g.get_e_repr())
assert _pfc(g.pop_e_repr('l')) == [0.] * 17 assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == 0 assert len(g.get_e_repr()) == old_len - 1
g.set_e_repr({'l' : th.zeros((17, D))}) g.set_e_repr({'l' : th.zeros((17, D))})
# set partial edges (many-many) # set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9]) u = th.tensor([0, 0, 2, 5, 9])
......
...@@ -30,8 +30,10 @@ def generate_graph(grad=False): ...@@ -30,8 +30,10 @@ def generate_graph(grad=False):
g.add_edge(i, 9) g.add_edge(i, 9)
# add a back flow from 9 to 0 # add a back flow from 9 to 0
g.add_edge(9, 0) g.add_edge(9, 0)
col = Variable(th.randn(10, D), requires_grad=grad) ncol = Variable(th.randn(10, D), requires_grad=grad)
g.set_n_repr(col) ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr(ncol)
g.set_e_repr(ecol)
return g return g
def test_batch_setter_getter(): def test_batch_setter_getter():
......
...@@ -2,14 +2,11 @@ import torch as th ...@@ -2,14 +2,11 @@ import torch as th
from torch.autograd import Variable from torch.autograd import Variable
import numpy as np import numpy as np
from dgl.frame import Frame, FrameRef from dgl.frame import Frame, FrameRef
from dgl.utils import Index from dgl.utils import Index, toindex
N = 10 N = 10
D = 5 D = 5
def check_eq(a, b):
return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())
def check_fail(fn): def check_fail(fn):
try: try:
fn() fn()
...@@ -27,12 +24,13 @@ def test_create(): ...@@ -27,12 +24,13 @@ def test_create():
data = create_test_data() data = create_test_data()
f1 = Frame() f1 = Frame()
for k, v in data.items(): for k, v in data.items():
f1.add_column(k, v) f1.update_column(k, v)
assert f1.schemes == set(data.keys()) print(f1.schemes)
assert f1.keys() == set(data.keys())
assert f1.num_columns == 3 assert f1.num_columns == 3
assert f1.num_rows == N assert f1.num_rows == N
f2 = Frame(data) f2 = Frame(data)
assert f2.schemes == set(data.keys()) assert f2.keys() == set(data.keys())
assert f2.num_columns == 3 assert f2.num_columns == 3
assert f2.num_rows == N assert f2.num_rows == N
f1.clear() f1.clear()
...@@ -45,9 +43,9 @@ def test_column1(): ...@@ -45,9 +43,9 @@ def test_column1():
f = Frame(data) f = Frame(data)
assert f.num_rows == N assert f.num_rows == N
assert len(f) == 3 assert len(f) == 3
assert check_eq(f['a1'], data['a1']) assert th.allclose(f['a1'].data, data['a1'].data)
f['a1'] = data['a2'] f['a1'] = data['a2']
assert check_eq(f['a2'], data['a2']) assert th.allclose(f['a2'].data, data['a2'].data)
# add a different length column should fail # add a different length column should fail
def failed_add_col(): def failed_add_col():
f['a4'] = th.zeros([N+1, D]) f['a4'] = th.zeros([N+1, D])
...@@ -70,16 +68,15 @@ def test_column2(): ...@@ -70,16 +68,15 @@ def test_column2():
f = FrameRef(data, [3, 4, 5, 6, 7]) f = FrameRef(data, [3, 4, 5, 6, 7])
assert f.num_rows == 5 assert f.num_rows == 5
assert len(f) == 3 assert len(f) == 3
assert check_eq(f['a1'], data['a1'][3:8]) assert th.allclose(f['a1'], data['a1'].data[3:8])
# set column should reflect on the referenced data # set column should reflect on the referenced data
f['a1'] = th.zeros([5, D]) f['a1'] = th.zeros([5, D])
assert check_eq(data['a1'][3:8], th.zeros([5, D])) assert th.allclose(data['a1'].data[3:8], th.zeros([5, D]))
# add new column should be padded with zero # add new partial column should fail with error initializer
f['a4'] = th.ones([5, D]) f.set_initializer(lambda shape, dtype : assert_(False))
assert len(data) == 4 def failed_add_col():
assert check_eq(data['a4'][0:3], th.zeros([3, D])) f['a4'] = th.ones([5, D])
assert check_eq(data['a4'][3:8], th.ones([5, D])) assert check_fail(failed_add_col)
assert check_eq(data['a4'][8:10], th.zeros([2, D]))
def test_append1(): def test_append1():
# test append API on Frame # test append API on Frame
...@@ -91,9 +88,14 @@ def test_append1(): ...@@ -91,9 +88,14 @@ def test_append1():
f1.append(f2) f1.append(f2)
assert f1.num_rows == 2 * N assert f1.num_rows == 2 * N
c1 = f1['a1'] c1 = f1['a1']
assert c1.shape == (2 * N, D) assert c1.data.shape == (2 * N, D)
truth = th.cat([data['a1'], data['a1']]) truth = th.cat([data['a1'], data['a1']])
assert check_eq(truth, c1) assert th.allclose(truth, c1.data)
# append dict of different length columns should fail
f3 = {'a1' : th.zeros((3, D)), 'a2' : th.zeros((3, D)), 'a3' : th.zeros((2, D))}
def failed_append():
f1.append(f3)
assert check_fail(failed_append)
def test_append2(): def test_append2():
# test append on FrameRef # test append on FrameRef
...@@ -113,7 +115,7 @@ def test_append2(): ...@@ -113,7 +115,7 @@ def test_append2():
assert not f.is_span_whole_column() assert not f.is_span_whole_column()
assert f.num_rows == 3 * N assert f.num_rows == 3 * N
new_idx = list(range(N)) + list(range(2*N, 4*N)) new_idx = list(range(N)) + list(range(2*N, 4*N))
assert check_eq(f.index().tousertensor(), th.tensor(new_idx)) assert th.all(f.index().tousertensor() == th.tensor(new_idx, dtype=th.int64))
assert data.num_rows == 4 * N assert data.num_rows == 4 * N
def test_row1(): def test_row1():
...@@ -127,13 +129,13 @@ def test_row1(): ...@@ -127,13 +129,13 @@ def test_row1():
rows = f[rowid] rows = f[rowid]
for k, v in rows.items(): for k, v in rows.items():
assert v.shape == (len(rowid), D) assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid]) assert th.allclose(v, data[k][rowid])
# test duplicate keys # test duplicate keys
rowid = Index(th.tensor([8, 2, 2, 1])) rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid] rows = f[rowid]
for k, v in rows.items(): for k, v in rows.items():
assert v.shape == (len(rowid), D) assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid]) assert th.allclose(v, data[k][rowid])
# setter # setter
rowid = Index(th.tensor([0, 2, 4])) rowid = Index(th.tensor([0, 2, 4]))
...@@ -143,12 +145,14 @@ def test_row1(): ...@@ -143,12 +145,14 @@ def test_row1():
} }
f[rowid] = vals f[rowid] = vals
for k, v in f[rowid].items(): for k, v in f[rowid].items():
assert check_eq(v, th.zeros((len(rowid), D))) assert th.allclose(v, th.zeros((len(rowid), D)))
# setting rows with new column should automatically add a new column # setting rows with new column should raise error with error initializer
vals['a4'] = th.ones((len(rowid), D)) f.set_initializer(lambda shape, dtype : assert_(False))
f[rowid] = vals def failed_update_rows():
assert len(f) == 4 vals['a4'] = th.ones((len(rowid), D))
f[rowid] = vals
assert check_fail(failed_update_rows)
def test_row2(): def test_row2():
# test row getter/setter autograd compatibility # test row getter/setter autograd compatibility
...@@ -161,13 +165,13 @@ def test_row2(): ...@@ -161,13 +165,13 @@ def test_row2():
rowid = Index(th.tensor([0, 2])) rowid = Index(th.tensor([0, 2]))
rows = f[rowid] rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D))) rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.])) assert th.allclose(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
c1.grad.data.zero_() c1.grad.data.zero_()
# test duplicate keys # test duplicate keys
rowid = Index(th.tensor([8, 2, 2, 1])) rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid] rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D))) rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.])) assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
c1.grad.data.zero_() c1.grad.data.zero_()
# setter # setter
...@@ -180,8 +184,8 @@ def test_row2(): ...@@ -180,8 +184,8 @@ def test_row2():
f[rowid] = vals f[rowid] = vals
c11 = f['a1'] c11 = f['a1']
c11.backward(th.ones((N, D))) c11.backward(th.ones((N, D)))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.])) assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
assert check_eq(vals['a1'].grad, th.ones((len(rowid), D))) assert th.allclose(vals['a1'].grad, th.ones((len(rowid), D)))
assert vals['a2'].grad is None assert vals['a2'].grad is None
def test_row3(): def test_row3():
...@@ -201,8 +205,9 @@ def test_row3(): ...@@ -201,8 +205,9 @@ def test_row3():
newidx = list(range(N)) newidx = list(range(N))
newidx.pop(2) newidx.pop(2)
newidx.pop(2) newidx.pop(2)
newidx = toindex(newidx)
for k, v in f.items(): for k, v in f.items():
assert check_eq(v, data[k][th.tensor(newidx)]) assert th.allclose(v, data[k][newidx])
def test_sharing(): def test_sharing():
data = Frame(create_test_data()) data = Frame(create_test_data())
...@@ -210,10 +215,10 @@ def test_sharing(): ...@@ -210,10 +215,10 @@ def test_sharing():
f2 = FrameRef(data, index=[2, 3, 4, 5, 6]) f2 = FrameRef(data, index=[2, 3, 4, 5, 6])
# test read # test read
for k, v in f1.items(): for k, v in f1.items():
assert check_eq(data[k][0:4], v) assert th.allclose(data[k].data[0:4], v)
for k, v in f2.items(): for k, v in f2.items():
assert check_eq(data[k][2:7], v) assert th.allclose(data[k].data[2:7], v)
f2_a1 = f2['a1'] f2_a1 = f2['a1'].data
# test write # test write
# update own ref should not be seen by the other. # update own ref should not be seen by the other.
f1[Index(th.tensor([0, 1]))] = { f1[Index(th.tensor([0, 1]))] = {
...@@ -221,7 +226,7 @@ def test_sharing(): ...@@ -221,7 +226,7 @@ def test_sharing():
'a2' : th.zeros([2, D]), 'a2' : th.zeros([2, D]),
'a3' : th.zeros([2, D]), 'a3' : th.zeros([2, D]),
} }
assert check_eq(f2['a1'], f2_a1) assert th.allclose(f2['a1'], f2_a1)
# update shared space should be seen by the other. # update shared space should be seen by the other.
f1[Index(th.tensor([2, 3]))] = { f1[Index(th.tensor([2, 3]))] = {
'a1' : th.ones([2, D]), 'a1' : th.ones([2, D]),
...@@ -229,7 +234,7 @@ def test_sharing(): ...@@ -229,7 +234,7 @@ def test_sharing():
'a3' : th.ones([2, D]), 'a3' : th.ones([2, D]),
} }
f2_a1[0:2] = th.ones([2, D]) f2_a1[0:2] = th.ones([2, D])
assert check_eq(f2['a1'], f2_a1) assert th.allclose(f2['a1'], f2_a1)
if __name__ == '__main__': if __name__ == '__main__':
test_create() test_create()
......
...@@ -123,6 +123,7 @@ def test_update_all_multi_fn(): ...@@ -123,6 +123,7 @@ def test_update_all_multi_fn():
return {'v2': th.sum(msgs['m2'], 1)} return {'v2': th.sum(msgs['m2'], 1)}
g = generate_graph() g = generate_graph()
g.set_n_repr({'v1' : th.zeros((10,)), 'v2' : th.zeros((10,))})
fld = 'f2' fld = 'f2'
# update all, mix of builtin and UDF # update all, mix of builtin and UDF
g.update_all([fn.copy_src(src=fld, out='m1'), message_func], g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
...@@ -173,6 +174,8 @@ def test_send_and_recv_multi_fn(): ...@@ -173,6 +174,8 @@ def test_send_and_recv_multi_fn():
return {'v2' : th.sum(msgs['m2'], 1)} return {'v2' : th.sum(msgs['m2'], 1)}
g = generate_graph() g = generate_graph()
g.set_n_repr({'v1' : th.zeros((10, D)), 'v2' : th.zeros((10, D)),
'v3' : th.zeros((10, D))})
fld = 'f2' fld = 'f2'
# send and recv, mix of builtin and UDF # send and recv, mix of builtin and UDF
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment