Commit 2ecd2b23 authored by Gan Quan, committed by Minjie Wang

[Frame] Support slice type indexing; optimize dgl.batch (#110)

* cherry picking optimization from jtnn

* unbatch by slicing frames

* reduce pack

* oops

* support frame read/write with slices

* reverting to unbatch by splitting; slicing is unfriendly to backward

* replacing lru cache with static object factory

* replacing Scheme object with namedtuple

* remove comment

* forgot the find edges interface

* subclassing namedtuple
parent 9827e481
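Two of the bullets above ("replacing lru cache with static object factory" and "subclassing namedtuple") describe the caching pattern applied to TVMType and TVMContext in the diff below: instead of memoizing the constructor with functools.lru_cache, each class interns its instances in a class-level dictionary inside __new__, and __init__ becomes a no-op so that repeated construction cannot reset the cached state. A minimal standalone sketch of the pattern, using a made-up class name rather than the actual DGL/TVM classes:

class CachedValue(object):
    """Value-like object interned through a class-level cache (a 'static object factory')."""
    _cache = {}

    def __new__(cls, key):
        # Reuse the existing instance when the same key is constructed again.
        if key in cls._cache:
            return cls._cache[key]
        inst = super(CachedValue, cls).__new__(cls)
        inst.key = key
        cls._cache[key] = inst
        return inst

    def __init__(self, key):
        # Initialization already happened in __new__; keep this a no-op so a
        # second CachedValue(key) call does not clobber the cached instance.
        pass

assert CachedValue('float32') is CachedValue('float32')

The Scheme change follows the same spirit: subclassing a namedtuple provides equality, hashing, and repr for free, which is why the hand-written __eq__, __ne__, and __repr__ methods are deleted in the diff below.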
@@ -27,7 +27,6 @@ except IMPORT_EXCEPT:
 from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
 from ._ctypes.ndarray import NDArrayBase as _NDArrayBase

 def context(dev_type, dev_id=0):
     """Construct a TVM context with given device type and id.
...
@@ -44,33 +44,44 @@ class TVMType(ctypes.Structure):
         2 : 'float',
         4 : 'handle'
     }
-    def __init__(self, type_str):
-        super(TVMType, self).__init__()
+    _cache = {}
+
+    def __new__(cls, type_str):
+        if type_str in cls._cache:
+            return cls._cache[type_str]
+        inst = super(TVMType, cls).__new__(TVMType)
         if isinstance(type_str, np.dtype):
             type_str = str(type_str)
         arr = type_str.split("x")
         head = arr[0]
-        self.lanes = int(arr[1]) if len(arr) > 1 else 1
+        inst.lanes = int(arr[1]) if len(arr) > 1 else 1
         bits = 32
         if head.startswith("int"):
-            self.type_code = 0
+            inst.type_code = 0
             head = head[3:]
         elif head.startswith("uint"):
-            self.type_code = 1
+            inst.type_code = 1
             head = head[4:]
         elif head.startswith("float"):
-            self.type_code = 2
+            inst.type_code = 2
             head = head[5:]
         elif head.startswith("handle"):
-            self.type_code = 4
+            inst.type_code = 4
             bits = 64
             head = ""
         else:
             raise ValueError("Donot know how to handle type %s" % type_str)
         bits = int(head) if head else bits
-        self.bits = bits
+        inst.bits = bits
+        cls._cache[type_str] = inst
+        return inst
+
+    def __init__(self, type_str):
+        pass

     def __repr__(self):
         x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
@@ -124,10 +135,22 @@ class TVMContext(ctypes.Structure):
         'opengl': 11,
         'ext_dev': 12,
     }
+    _cache = {}
+
+    def __new__(cls, device_type, device_id):
+        if (device_type, device_id) in cls._cache:
+            return cls._cache[(device_type, device_id)]
+        inst = super(TVMContext, cls).__new__(TVMContext)
+        inst.device_type = device_type
+        inst.device_id = device_id
+        cls._cache[(device_type, device_id)] = inst
+        return inst
+
     def __init__(self, device_type, device_id):
-        super(TVMContext, self).__init__()
-        self.device_type = device_type
-        self.device_id = device_id
+        pass

     @property
     def exist(self):
...
@@ -86,12 +86,25 @@ def to_context(arr, ctx):
     else:
         raise RuntimeError('Invalid context', ctx)

-def get_context(arr):
-    if arr.device.type == 'cpu':
+def _get_context(type, index):
+    if type == 'cpu':
         return TVMContext(TVMContext.STR2MASK['cpu'], 0)
     else:
-        return TVMContext(
-                TVMContext.STR2MASK[arr.device.type], arr.device.index)
+        return TVMContext(TVMContext.STR2MASK[type], index)
+
+def get_context(arr):
+    return _get_context(arr.device.type, arr.device.index)
+
+_tvmtypes = {
+        th.float16: TVMType('float16'),
+        th.float32: TVMType('float32'),
+        th.float64: TVMType('float64'),
+        th.int8: TVMType('int8'),
+        th.uint8: TVMType('uint8'),
+        th.int16: TVMType('int16'),
+        th.int32: TVMType('int32'),
+        th.int64: TVMType('int64'),
+        }

 def convert_to(src, dst):
     '''
@@ -101,24 +114,9 @@ def convert_to(src, dst):

 def get_tvmtype(arr):
     arr_dtype = arr.dtype
-    if arr_dtype in (th.float16, th.half):
-        return TVMType('float16')
-    elif arr_dtype in (th.float32, th.float):
-        return TVMType('float32')
-    elif arr_dtype in (th.float64, th.double):
-        return TVMType('float64')
-    elif arr_dtype in (th.int16, th.short):
-        return TVMType('int16')
-    elif arr_dtype in (th.int32, th.int):
-        return TVMType('int32')
-    elif arr_dtype in (th.int64, th.long):
-        return TVMType('int64')
-    elif arr_dtype == th.int8:
-        return TVMType('int8')
-    elif arr_dtype == th.uint8:
-        return TVMType('uint8')
-    else:
+    if arr_dtype not in _tvmtypes:
         raise RuntimeError('Unsupported data type:', arr_dtype)
+    return _tvmtypes[arr_dtype]

 def zerocopy_to_dlpack(arr):
     """Return a dlpack compatible array using zero copy."""
@@ -130,7 +128,6 @@ def zerocopy_from_dlpack(dlpack_arr):

 def zerocopy_to_numpy(arr):
     """Return a numpy array that shares the data."""
-    # TODO(minjie): zero copy
     return arr.numpy()

 def zerocopy_from_numpy(np_data):
...
@@ -4,7 +4,7 @@ from __future__ import absolute_import
 import numpy as np

 from .base import ALL, is_all
-from .frame import FrameRef
+from .frame import FrameRef, Frame
 from .graph import DGLGraph
 from . import graph_index as gi
 from . import backend as F
@@ -31,14 +31,14 @@ class BatchedDGLGraph(DGLGraph):
         batched_index = gi.disjoint_union([g._graph for g in graph_list])

         # create batched node and edge frames
         # NOTE: following code will materialize the columns of the input graphs.
-        batched_node_frame = FrameRef()
-        for gr in graph_list:
-            cols = {key : gr._node_frame[key] for key in node_attrs}
-            batched_node_frame.append(cols)
-        batched_edge_frame = FrameRef()
-        for gr in graph_list:
-            cols = {key : gr._edge_frame[key] for key in edge_attrs}
-            batched_edge_frame.append(cols)
+        cols = {key: F.pack([gr._node_frame[key] for gr in graph_list])
+                for key in node_attrs}
+        batched_node_frame = FrameRef(Frame(cols))
+
+        cols = {key: F.pack([gr._edge_frame[key] for gr in graph_list])
+                for key in edge_attrs}
+        batched_edge_frame = FrameRef(Frame(cols))

         super(BatchedDGLGraph, self).__init__(
             graph_data=batched_index,
             node_frame=batched_node_frame,
@@ -154,6 +154,14 @@ def unbatch(graph):
     ----------
     graph : BatchedDGLGraph
         The batched graph.

+    Notes
+    -----
+    Unbatching will partition each field tensor of the batched graph into
+    smaller partitions. This is usually wasteful.
+
+    For simpler tasks such as node/edge state aggregation by example,
+    try to use BatchedDGLGraph.readout().
     """
     assert isinstance(graph, BatchedDGLGraph)
     bsize = graph.batch_size
...
"""Columnar storage for DGLGraph.""" """Columnar storage for DGLGraph."""
from __future__ import absolute_import from __future__ import absolute_import
from collections import MutableMapping from collections import MutableMapping, namedtuple
import numpy as np import numpy as np
from . import backend as F from . import backend as F
...@@ -9,7 +9,8 @@ from .backend import Tensor ...@@ -9,7 +9,8 @@ from .backend import Tensor
from .base import DGLError, dgl_warning from .base import DGLError, dgl_warning
from . import utils from . import utils
class Scheme(object):
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
"""The column scheme. """The column scheme.
Parameters Parameters
@@ -19,22 +20,9 @@ class Scheme(object):
     dtype : TVMType
         The feature data type.
     """
-    def __init__(self, shape, dtype):
-        self.shape = shape
-        self.dtype = dtype
-
-    def __repr__(self):
-        return '{shape=%s, dtype=%s}' % (repr(self.shape), repr(self.dtype))
-
-    def __eq__(self, other):
-        return self.shape == other.shape and self.dtype == other.dtype
-
-    def __ne__(self, other):
-        return not self.__eq__(other)
-
-    @staticmethod
-    def infer_scheme(tensor):
-        """Infer the scheme of the given tensor."""
+    pass
+
+def infer_scheme(tensor):
     return Scheme(tuple(F.shape(tensor)[1:]), F.get_tvmtype(tensor))

 class Column(object):
@@ -52,7 +40,7 @@ class Column(object):
     """
     def __init__(self, data, scheme=None):
         self.data = data
-        self.scheme = scheme if scheme else Scheme.infer_scheme(data)
+        self.scheme = scheme if scheme else infer_scheme(data)

     def __len__(self):
         """The column length."""
@@ -67,7 +55,7 @@ class Column(object):
         Parameters
         ----------
-        idx : utils.Index
+        idx : slice or utils.Index
             The index.

         Returns
@@ -75,6 +63,9 @@ class Column(object):
         Tensor
             The feature data
         """
+        if isinstance(idx, slice):
+            return self.data[idx]
+        else:
             user_idx = idx.tousertensor(F.get_context(self.data))
             return F.gather_row(self.data, user_idx)
@@ -86,7 +77,7 @@ class Column(object):
         Parameters
         ----------
-        idx : utils.Index
+        idx : utils.Index or slice
             The index.
         feats : Tensor
             The new features.
@@ -98,23 +89,34 @@ class Column(object):
         Parameters
         ----------
-        idx : utils.Index
+        idx : utils.Index or slice
             The index.
         feats : Tensor
             The new features.
         inplace : bool
             If true, use inplace write.
         """
-        feat_scheme = Scheme.infer_scheme(feats)
+        feat_scheme = infer_scheme(feats)
         if feat_scheme != self.scheme:
             raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
                            % (feat_scheme, self.scheme))
-        user_idx = idx.tousertensor(F.get_context(self.data))
+
+        if isinstance(idx, utils.Index):
+            idx = idx.tousertensor(F.get_context(self.data))
+
         if inplace:
             # TODO(minjie): do not use [] operator directly
-            self.data[user_idx] = feats
+            self.data[idx] = feats
         else:
-            self.data = F.scatter_row(self.data, user_idx, feats)
+            if isinstance(idx, slice):
+                # for contiguous indices pack is usually faster than scatter row
+                self.data = F.pack([
+                    self.data[:idx.start],
+                    feats,
+                    self.data[idx.stop:],
+                    ])
+            else:
+                self.data = F.scatter_row(self.data, idx, feats)

     def extend(self, feats, feat_scheme=None):
         """Extend the feature data.
@@ -353,19 +355,23 @@ class FrameRef(MutableMapping):
     frame : Frame, optional
         The underlying frame. If not given, the reference will point to a
        new empty frame.
-    index : iterable of int, optional
+    index : iterable, slice, or int, optional
        The rows that are referenced in the underlying frame. If not given,
        the whole frame is referenced. The index should be distinct (no
        duplication is allowed).
+        Note that if a slice is given, the step must be None.
     """
     def __init__(self, frame=None, index=None):
         self._frame = frame if frame is not None else Frame()
         if index is None:
+            # _index_data can be either a slice or an iterable
             self._index_data = slice(0, self._frame.num_rows)
         else:
             # TODO(minjie): check no duplication
             self._index_data = index
         self._index = None
+        self._index_or_slice = None

     @property
     def schemes(self):
@@ -387,10 +393,8 @@ class FrameRef(MutableMapping):
     def num_rows(self):
         """Return the number of rows referred."""
         if isinstance(self._index_data, slice):
-            # NOTE: we are assuming that the index is a slice ONLY IF
-            # index=None during construction.
-            # As such, start is always 0, and step is always 1.
-            return self._index_data.stop
+            # NOTE: we always assume that slice.step is None
+            return self._index_data.stop - self._index_data.start
         else:
             return len(self._index_data)
@@ -417,11 +421,29 @@ class FrameRef(MutableMapping):
         if self._index is None:
             if self.is_contiguous():
                 self._index = utils.toindex(
-                        F.arange(self._index_data.stop, dtype=F.int64))
+                        F.arange(
+                            self._index_data.start,
+                            self._index_data.stop,
+                            dtype=F.int64))
             else:
                 self._index = utils.toindex(self._index_data)
         return self._index

+    def index_or_slice(self):
+        """Returns the index object or the slice
+
+        Returns
+        -------
+        utils.Index or slice
+            The index or slice
+        """
+        if self._index_or_slice is None:
+            if self.is_contiguous():
+                self._index_or_slice = self._index_data
+            else:
+                self._index_or_slice = utils.toindex(self._index_data)
+        return self._index_or_slice
+
     def __contains__(self, name):
         """Return whether the column name exists."""
         return name in self._frame
@@ -442,8 +464,8 @@ class FrameRef(MutableMapping):
         """Get data from the frame.

         If the provided key is string, the corresponding column data will be returned.
-        If the provided key is an index, the corresponding rows will be selected. The
-        returned rows are saved in a lazy dictionary so only the real selection happens
+        If the provided key is an index or a slice, the corresponding rows will be selected.
+        The returned rows are saved in a lazy dictionary so only the real selection happens
         when the explicit column name is provided.

         Examples (using pytorch)
@@ -457,7 +479,7 @@ class FrameRef(MutableMapping):
         Parameters
         ----------
-        key : str or utils.Index
+        key : str or utils.Index or slice
             The key.

         Returns
@@ -490,14 +512,14 @@ class FrameRef(MutableMapping):
         if self.is_span_whole_column():
             return col.data
         else:
-            return col[self.index()]
+            return col[self.index_or_slice()]

     def select_rows(self, query):
         """Return the rows given the query.

         Parameters
         ----------
-        query : utils.Index
+        query : utils.Index or slice
             The rows to be selected.

         Returns
@@ -505,8 +527,8 @@ class FrameRef(MutableMapping):
         utils.LazyDict
             The lazy dictionary from str to the selected data.
         """
-        rowids = self._getrowid(query)
-        return utils.LazyDict(lambda key: self._frame[key][rowids], keys=self.keys())
+        rows = self._getrows(query)
+        return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys())

     def __setitem__(self, key, val):
         """Update the data in the frame.
@@ -561,15 +583,10 @@ class FrameRef(MutableMapping):
             self._frame[name] = col
         else:
             if name not in self._frame:
-                feat_shape = F.shape(data)[1:]
-                feat_dtype = F.get_tvmtype(data)
                 ctx = F.get_context(data)
-                self._frame.add_column(name, Scheme(feat_shape, feat_dtype), ctx)
-                #raise DGLError('Cannot update column. Column "%s" does not exist.'
-                #        ' Did you forget to init the column using `set_n_repr`'
-                #        ' or `set_e_repr`?' % name)
+                self._frame.add_column(name, infer_scheme(data), ctx)
             fcol = self._frame[name]
-            fcol.update(self.index(), data, inplace)
+            fcol.update(self.index_or_slice(), data, inplace)

     def add_rows(self, num_rows):
         """Add blank rows.
@@ -606,30 +623,28 @@ class FrameRef(MutableMapping):
         Parameters
         ----------
-        query : utils.Index
+        query : utils.Index or slice
             The rows to be updated.
         data : dict-like
             The row data.
         inplace : bool
             True if the update is performed inplacely.
         """
-        rowids = self._getrowid(query)
+        rows = self._getrows(query)
         for key, col in data.items():
             if key not in self:
                 # add new column
-                tmpref = FrameRef(self._frame, rowids)
+                tmpref = FrameRef(self._frame, rows)
                 tmpref.update_column(key, col, inplace)
-                #raise DGLError('Cannot update rows. Column "%s" does not exist.'
-                #        ' Did you forget to init the column using `set_n_repr`'
-                #        ' or `set_e_repr`?' % key)
             else:
-                self._frame[key].update(rowids, col, inplace)
+                self._frame[key].update(rows, col, inplace)

     def __delitem__(self, key):
         """Delete data in the frame.

         If the provided key is a string, the corresponding column will be deleted.
-        If the provided key is an index object, the corresponding rows will be deleted.
+        If the provided key is an index object or a slice, the corresponding rows will
+        be deleted.

         Please note that "deleted" rows are not really deleted, but simply removed
         in the reference. As a result, if two FrameRefs point to the same Frame, deleting
@@ -656,14 +671,17 @@ class FrameRef(MutableMapping):
         Parameters
         ----------
-        query : utils.Index
+        query : utils.Index or slice
             The rows to be deleted.
         """
+        if isinstance(query, slice):
+            query = range(query.start, query.stop)
+        else:
             query = query.tolist()
         if isinstance(self._index_data, slice):
-            self._index_data = list(range(self._index_data.start, self._index_data.stop))
-        arr = np.array(self._index_data, dtype=np.int32)
-        self._index_data = list(np.delete(arr, query))
+            self._index_data = range(self._index_data.start, self._index_data.stop)
+        self._index_data = list(np.delete(self._index_data, query))
         self._clear_cache()

     def append(self, other):
@@ -682,8 +700,11 @@ class FrameRef(MutableMapping):
         if span_whole:
             self._index_data = slice(0, self._frame.num_rows)
         elif contiguous:
+            if self._index_data.stop == old_nrows:
+                new_idx = slice(self._index_data.start, self._frame.num_rows)
+            else:
                 new_idx = list(range(self._index_data.start, self._index_data.stop))
-            new_idx += list(range(old_nrows, self._frame.num_rows))
+                new_idx.extend(range(old_nrows, self._frame.num_rows))
             self._index_data = new_idx
         self._clear_cache()
@@ -695,26 +716,35 @@ class FrameRef(MutableMapping):
     def is_contiguous(self):
         """Return whether this refers to a contiguous range of rows."""
-        # NOTE: this check could have false negatives and false positives
-        # (step other than 1)
+        # NOTE: this check could have false negatives
+        # NOTE: we always assume that slice.step is None
         return isinstance(self._index_data, slice)

     def is_span_whole_column(self):
         """Return whether this refers to all the rows."""
         return self.is_contiguous() and self.num_rows == self._frame.num_rows

-    def _getrowid(self, query):
+    def _getrows(self, query):
         """Internal function to convert from the local row ids to the row ids of the frame."""
         if self.is_contiguous():
+            start = self._index_data.start
+            if start == 0:
                 # shortcut for identical mapping
                 return query
+            elif isinstance(query, slice):
+                return slice(query.start + start, query.stop + start)
+            else:
+                query = query.tousertensor()
+                return utils.toindex(query + start)
         else:
             idxtensor = self.index().tousertensor()
-            return utils.toindex(F.gather_row(idxtensor, query.tousertensor()))
+            query = query.tousertensor()
+            return utils.toindex(F.gather_row(idxtensor, query))

     def _clear_cache(self):
         """Internal function to clear the cached object."""
-        self._index_tensor = None
+        self._index = None
+        self._index_or_slice = None

 def merge_frames(frames, indices, max_index, reduce_func):
     """Merge a list of frames.
...
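The _getrows change above is what keeps contiguous FrameRefs cheap: when the reference is a slice with a non-zero start, a local slice query is mapped to frame rows by adding a constant offset instead of gathering an index tensor. A rough sketch of that mapping with made-up numbers, not the actual classes:

# Hypothetical illustration of the contiguous branch of FrameRef._getrows.
start = 3                            # FrameRef refers to frame rows slice(3, 8)
local = slice(0, 2)                  # caller asks for the first two referenced rows
global_rows = slice(local.start + start, local.stop + start)
assert global_rows == slice(3, 5)    # rows 3 and 4 of the underlying Frame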
@@ -323,6 +323,23 @@ class DGLGraph(object):
         else:
             return eid.tousertensor()

+    def find_edges(self, eid):
+        """Given the edge ids, return their source and destination node ids.
+
+        Parameters
+        ----------
+        eid : list, tensor
+            The edge ids.
+
+        Returns
+        -------
+        tensor, tensor
+            The source and destination node IDs.
+        """
+        eid = utils.toindex(eid)
+        src, dst, _ = self._graph.find_edges(eid)
+        return src.tousertensor(), dst.tousertensor()
+
     def in_edges(self, v):
         """Return the in edges of the node(s).
...
@@ -36,7 +36,6 @@ def cpu(dev_id=0):
     """
     return TVMContext(1, dev_id)

 def gpu(dev_id=0):
     """Construct a CPU device
...
@@ -236,6 +236,38 @@ def test_sharing():
     f2_a1[0:2] = th.ones([2, D])
     assert th.allclose(f2['a1'], f2_a1)

+def test_slicing():
+    data = Frame(create_test_data(grad=True))
+    f1 = FrameRef(data, index=slice(1, 5))
+    f2 = FrameRef(data, index=slice(3, 8))
+
+    # test read
+    for k, v in f1.items():
+        assert th.allclose(data[k].data[1:5], v)
+    f2_a1 = f2['a1'].data
+
+    # test write
+    f1[Index(th.tensor([0, 1]))] = {
+            'a1': th.zeros([2, D]),
+            'a2': th.zeros([2, D]),
+            'a3': th.zeros([2, D]),
+            }
+    assert th.allclose(f2['a1'], f2_a1)
+
+    f1[Index(th.tensor([2, 3]))] = {
+            'a1': th.ones([2, D]),
+            'a2': th.ones([2, D]),
+            'a3': th.ones([2, D]),
+            }
+    f2_a1[0:2] = 1
+    assert th.allclose(f2['a1'], f2_a1)
+
+    f1[2:4] = {
+            'a1': th.zeros([2, D]),
+            'a2': th.zeros([2, D]),
+            'a3': th.zeros([2, D]),
+            }
+    f2_a1[0:2] = 0
+    assert th.allclose(f2['a1'], f2_a1)
+
 if __name__ == '__main__':
     test_create()
     test_column1()
@@ -246,3 +278,4 @@ if __name__ == '__main__':
     test_row2()
     test_row3()
     test_sharing()
+    test_slicing()
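As a side note on the Column.update change above: for a slice index the new code rebuilds the column with F.pack (concatenation) instead of F.scatter_row. A rough standalone illustration of why the two branches agree, written directly in PyTorch with made-up tensors and assuming F.pack concatenates along the first dimension, as the batching code above suggests:

import torch as th

data = th.arange(10, dtype=th.float32).reshape(10, 1)
feats = th.zeros(3, 1)
sl = slice(4, 7)

# Slice branch: concatenate the untouched prefix, the new rows, and the suffix.
packed = th.cat([data[:sl.start], feats, data[sl.stop:]], dim=0)

# Index branch: scatter the new rows into a copy at explicit positions.
idx = th.tensor([4, 5, 6])
scattered = data.clone()
scattered[idx] = feats

assert th.equal(packed, scattered)

The commit log also records the trade-off: rebuilding the tensor this way speeds up reads and writes on contiguous rows, but per the "reverting to unbatch by splitting; slicing is unfriendly to backward" bullet, unbatch itself went back to splitting.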