Commit 2ecd2b23 authored by Gan Quan, committed by Minjie Wang

[Frame] Support slice type indexing; optimize dgl.batch (#110)

* cherry picking optimization from jtnn

* unbatch by slicing frames

* reduce pack

* oops

* support frame read/write with slices

* reverting to unbatch by splitting; slicing is unfriendly to backward

* replacing lru cache with static object factory

* replacing Scheme object with namedtuple

* remove comment

* forgot the find edges interface

* subclassing namedtuple
parent 9827e481
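For readers skimming the diff: the "static object factory" mentioned in the commit message is the pattern of caching instances in a class-level dict keyed by the constructor arguments and short-circuiting construction in __new__. A minimal self-contained sketch of the pattern (class and attribute names are illustrative, not from this commit):

class Cached(object):
    _cache = {}

    def __new__(cls, key):
        # Return the cached instance when one exists for this key.
        if key in cls._cache:
            return cls._cache[key]
        inst = super(Cached, cls).__new__(cls)
        inst.key = key
        cls._cache[key] = inst
        return inst

    def __init__(self, key):
        # Python still calls __init__ on cache hits, so all initialization
        # lives in __new__ and __init__ is deliberately a no-op.
        pass

assert Cached('x') is Cached('x')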
......@@ -27,7 +27,6 @@ except IMPORT_EXCEPT:
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
def context(dev_type, dev_id=0):
"""Construct a TVM context with given device type and id.
......
......@@ -44,33 +44,44 @@ class TVMType(ctypes.Structure):
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str):
super(TVMType, self).__init__()
_cache = {}
def __new__(cls, type_str):
if type_str in cls._cache:
return cls._cache[type_str]
inst = super(TVMType, cls).__new__(TVMType)
if isinstance(type_str, np.dtype):
type_str = str(type_str)
arr = type_str.split("x")
head = arr[0]
self.lanes = int(arr[1]) if len(arr) > 1 else 1
inst.lanes = int(arr[1]) if len(arr) > 1 else 1
bits = 32
if head.startswith("int"):
self.type_code = 0
inst.type_code = 0
head = head[3:]
elif head.startswith("uint"):
self.type_code = 1
inst.type_code = 1
head = head[4:]
elif head.startswith("float"):
self.type_code = 2
inst.type_code = 2
head = head[5:]
elif head.startswith("handle"):
self.type_code = 4
inst.type_code = 4
bits = 64
head = ""
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = int(head) if head else bits
self.bits = bits
inst.bits = bits
cls._cache[type_str] = inst
return inst
def __init__(self, type_str):
pass
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
......@@ -124,10 +135,22 @@ class TVMContext(ctypes.Structure):
'opengl': 11,
'ext_dev': 12,
}
_cache = {}
def __new__(cls, device_type, device_id):
if (device_type, device_id) in cls._cache:
return cls._cache[(device_type, device_id)]
inst = super(TVMContext, cls).__new__(TVMContext)
inst.device_type = device_type
inst.device_id = device_id
cls._cache[(device_type, device_id)] = inst
return inst
def __init__(self, device_type, device_id):
super(TVMContext, self).__init__()
self.device_type = device_type
self.device_id = device_id
pass
@property
def exist(self):
......
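The practical effect of the factory, sketched with hypothetical calls: repeated construction of the same context now returns the identical cached object, so repeated lookups are dict hits rather than fresh ctypes structures.

ctx1 = TVMContext(1, 0)   # device_type 1 == STR2MASK['cpu'], device id 0
ctx2 = TVMContext(1, 0)
assert ctx1 is ctx2       # same cached instance, not merely equal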
......@@ -86,12 +86,25 @@ def to_context(arr, ctx):
else:
raise RuntimeError('Invalid context', ctx)
def get_context(arr):
if arr.device.type == 'cpu':
def _get_context(type, index):
if type == 'cpu':
return TVMContext(TVMContext.STR2MASK['cpu'], 0)
else:
return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index)
return TVMContext(TVMContext.STR2MASK[type], index)
def get_context(arr):
return _get_context(arr.device.type, arr.device.index)
_tvmtypes = {
th.float16: TVMType('float16'),
th.float32: TVMType('float32'),
th.float64: TVMType('float64'),
th.int8: TVMType('int8'),
th.uint8: TVMType('uint8'),
th.int16: TVMType('int16'),
th.int32: TVMType('int32'),
th.int64: TVMType('int64'),
}
def convert_to(src, dst):
'''
......@@ -101,24 +114,9 @@ def convert_to(src, dst):
def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype in (th.float16, th.half):
return TVMType('float16')
elif arr_dtype in (th.float32, th.float):
return TVMType('float32')
elif arr_dtype in (th.float64, th.double):
return TVMType('float64')
elif arr_dtype in (th.int16, th.short):
return TVMType('int16')
elif arr_dtype in (th.int32, th.int):
return TVMType('int32')
elif arr_dtype in (th.int64, th.long):
return TVMType('int64')
elif arr_dtype == th.int8:
return TVMType('int8')
elif arr_dtype == th.uint8:
return TVMType('uint8')
else:
if arr_dtype not in _tvmtypes:
raise RuntimeError('Unsupported data type:', arr_dtype)
return _tvmtypes[arr_dtype]
def zerocopy_to_dlpack(arr):
"""Return a dlpack compatible array using zero copy."""
......@@ -130,7 +128,6 @@ def zerocopy_from_dlpack(dlpack_arr):
def zerocopy_to_numpy(arr):
"""Return a numpy array that shares the data."""
# TODO(minjie): zero copy
return arr.numpy()
def zerocopy_from_numpy(np_data):
......
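A note on the table-driven get_tvmtype above: the deleted if/elif chain also matched torch's alias dtypes (th.half, th.float, th.long, ...), but each alias is the very same object as its canonical dtype (th.half is th.float16), so the plain dict lookup covers them too. Hypothetical usage, assuming torch is imported as th:

x = th.zeros(4, dtype=th.long)              # th.long is th.int64, the same object
assert get_tvmtype(x) == TVMType('int64')   # single dict lookup, no branching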
......@@ -4,7 +4,7 @@ from __future__ import absolute_import
import numpy as np
from .base import ALL, is_all
from .frame import FrameRef
from .frame import FrameRef, Frame
from .graph import DGLGraph
from . import graph_index as gi
from . import backend as F
......@@ -31,14 +31,14 @@ class BatchedDGLGraph(DGLGraph):
batched_index = gi.disjoint_union([g._graph for g in graph_list])
# create batched node and edge frames
# NOTE: following code will materialize the columns of the input graphs.
batched_node_frame = FrameRef()
for gr in graph_list:
cols = {key : gr._node_frame[key] for key in node_attrs}
batched_node_frame.append(cols)
batched_edge_frame = FrameRef()
for gr in graph_list:
cols = {key : gr._edge_frame[key] for key in edge_attrs}
batched_edge_frame.append(cols)
cols = {key: F.pack([gr._node_frame[key] for gr in graph_list])
for key in node_attrs}
batched_node_frame = FrameRef(Frame(cols))
cols = {key: F.pack([gr._edge_frame[key] for gr in graph_list])
for key in edge_attrs}
batched_edge_frame = FrameRef(Frame(cols))
super(BatchedDGLGraph, self).__init__(
graph_data=batched_index,
node_frame=batched_node_frame,
......@@ -154,6 +154,14 @@ def unbatch(graph):
----------
graph : BatchedDGLGraph
The batched graph.
Notes
-----
Unbatching splits each field tensor of the batched graph into smaller
tensors, one per original graph, which is usually wasteful.
For simpler tasks such as aggregating node/edge states per example,
try to use BatchedDGLGraph.readout() instead.
"""
assert isinstance(graph, BatchedDGLGraph)
bsize = graph.batch_size
......
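The batching change above replaces per-graph frame appends with one concatenation per field. A standalone sketch of the difference, with torch.cat standing in for F.pack (shapes are illustrative):

import torch

feats = [torch.randn(3, 16), torch.randn(5, 16), torch.randn(2, 16)]

# Before: incremental appends, each potentially reallocating the column.
out = feats[0]
for f in feats[1:]:
    out = torch.cat([out, f], dim=0)

# After: a single concatenation over the whole list.
assert torch.equal(out, torch.cat(feats, dim=0))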
"""Columnar storage for DGLGraph."""
from __future__ import absolute_import
from collections import MutableMapping
from collections import MutableMapping, namedtuple
import numpy as np
from . import backend as F
......@@ -9,7 +9,8 @@ from .backend import Tensor
from .base import DGLError, dgl_warning
from . import utils
class Scheme(object):
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
"""The column scheme.
Parameters
......@@ -19,22 +20,9 @@ class Scheme(object):
dtype : TVMType
The feature data type.
"""
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
def __repr__(self):
return '{shape=%s, dtype=%s}' % (repr(self.shape), repr(self.dtype))
def __eq__(self, other):
return self.shape == other.shape and self.dtype == other.dtype
def __ne__(self, other):
return not self.__eq__(other)
pass
@staticmethod
def infer_scheme(tensor):
"""Infer the scheme of the given tensor."""
def infer_scheme(tensor):
return Scheme(tuple(F.shape(tensor)[1:]), F.get_tvmtype(tensor))
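# Why the namedtuple: it supplies value-based __eq__, __ne__, __repr__, and
# hashability for free, which is everything the deleted hand-written methods
# implemented. Hypothetical values (a string dtype for brevity; the real
# code stores a TVMType):
#
#   s1 = Scheme((16,), 'float32')
#   s2 = Scheme((16,), 'float32')
#   assert s1 == s2 and hash(s1) == hash(s2)   # usable as a dict key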
class Column(object):
......@@ -52,7 +40,7 @@ class Column(object):
"""
def __init__(self, data, scheme=None):
self.data = data
self.scheme = scheme if scheme else Scheme.infer_scheme(data)
self.scheme = scheme if scheme else infer_scheme(data)
def __len__(self):
"""The column length."""
......@@ -67,7 +55,7 @@ class Column(object):
Parameters
----------
idx : utils.Index
idx : slice or utils.Index
The index.
Returns
......@@ -75,6 +63,9 @@ class Column(object):
Tensor
The feature data
"""
if isinstance(idx, slice):
return self.data[idx]
else:
user_idx = idx.tousertensor(F.get_context(self.data))
return F.gather_row(self.data, user_idx)
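# Note on the slice fast path: with a torch-like backend, `self.data[idx]`
# for a slice is a zero-copy view, while gather_row materializes a new
# tensor. (This is also why unbatch reverted to splitting rather than
# slicing -- slicing is unfriendly to backward, per the commit message.)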
......@@ -86,7 +77,7 @@ class Column(object):
Parameters
----------
idx : utils.Index
idx : utils.Index or slice
The index.
feats : Tensor
The new features.
......@@ -98,23 +89,34 @@ class Column(object):
Parameters
----------
idx : utils.Index
idx : utils.Index or slice
The index.
feats : Tensor
The new features.
inplace : bool
If true, use inplace write.
"""
feat_scheme = Scheme.infer_scheme(feats)
feat_scheme = infer_scheme(feats)
if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme))
user_idx = idx.tousertensor(F.get_context(self.data))
if isinstance(idx, utils.Index):
idx = idx.tousertensor(F.get_context(self.data))
if inplace:
# TODO(minjie): do not use [] operator directly
self.data[user_idx] = feats
self.data[idx] = feats
else:
self.data = F.scatter_row(self.data, user_idx, feats)
if isinstance(idx, slice):
# for contiguous indices pack is usually faster than scatter row
self.data = F.pack([
self.data[:idx.start],
feats,
self.data[idx.stop:],
])
else:
self.data = F.scatter_row(self.data, idx, feats)
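# A standalone sketch of the slice branch above, assuming torch and that
# F.pack concatenates along dim 0 (like torch.cat); shapes are illustrative.
import torch
_data = torch.arange(10.).reshape(5, 2)
_feats = torch.zeros(2, 2)
_idx = slice(1, 3)
_data = torch.cat([_data[:_idx.start], _feats, _data[_idx.stop:]], dim=0)
assert torch.equal(_data[1:3], _feats)   # rows 1..2 replaced with one concat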
def extend(self, feats, feat_scheme=None):
"""Extend the feature data.
......@@ -353,19 +355,23 @@ class FrameRef(MutableMapping):
frame : Frame, optional
The underlying frame. If not given, the reference will point to a
new empty frame.
index : iterable of int, optional
index : iterable, slice, or int, optional
The rows that are referenced in the underlying frame. If not given,
the whole frame is referenced. The indices should be distinct (no
duplicates allowed).
Note that if a slice is given, the step must be None.
"""
def __init__(self, frame=None, index=None):
self._frame = frame if frame is not None else Frame()
if index is None:
# _index_data can be either a slice or an iterable
self._index_data = slice(0, self._frame.num_rows)
else:
# TODO(minjie): check no duplication
self._index_data = index
self._index = None
self._index_or_slice = None
@property
def schemes(self):
......@@ -387,10 +393,8 @@ class FrameRef(MutableMapping):
def num_rows(self):
"""Return the number of rows referred."""
if isinstance(self._index_data, slice):
# NOTE: we are assuming that the index is a slice ONLY IF
# index=None during construction.
# As such, start is always 0, and step is always 1.
return self._index_data.stop
# NOTE: we always assume that slice.step is None
return self._index_data.stop - self._index_data.start
else:
return len(self._index_data)
......@@ -417,11 +421,29 @@ class FrameRef(MutableMapping):
if self._index is None:
if self.is_contiguous():
self._index = utils.toindex(
F.arange(self._index_data.stop, dtype=F.int64))
F.arange(
self._index_data.start,
self._index_data.stop,
dtype=F.int64))
else:
self._index = utils.toindex(self._index_data)
return self._index
def index_or_slice(self):
"""Returns the index object or the slice
Returns
-------
utils.Index or slice
The index or slice
"""
if self._index_or_slice is None:
if self.is_contiguous():
self._index_or_slice = self._index_data
else:
self._index_or_slice = utils.toindex(self._index_data)
return self._index_or_slice
def __contains__(self, name):
"""Return whether the column name exists."""
return name in self._frame
......@@ -442,8 +464,8 @@ class FrameRef(MutableMapping):
"""Get data from the frame.
If the provided key is string, the corresponding column data will be returned.
If the provided key is an index, the corresponding rows will be selected. The
returned rows are saved in a lazy dictionary so only the real selection happens
If the provided key is an index or a slice, the corresponding rows will be selected.
The returned rows are saved in a lazy dictionary so only the real selection happens
when the explicit column name is provided.
Examples (using pytorch)
......@@ -457,7 +479,7 @@ class FrameRef(MutableMapping):
Parameters
----------
key : str or utils.Index
key : str or utils.Index or slice
The key.
Returns
......@@ -490,14 +512,14 @@ class FrameRef(MutableMapping):
if self.is_span_whole_column():
return col.data
else:
return col[self.index()]
return col[self.index_or_slice()]
def select_rows(self, query):
"""Return the rows given the query.
Parameters
----------
query : utils.Index
query : utils.Index or slice
The rows to be selected.
Returns
......@@ -505,8 +527,8 @@ class FrameRef(MutableMapping):
utils.LazyDict
The lazy dictionary from str to the selected data.
"""
rowids = self._getrowid(query)
return utils.LazyDict(lambda key: self._frame[key][rowids], keys=self.keys())
rows = self._getrows(query)
return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys())
def __setitem__(self, key, val):
"""Update the data in the frame.
......@@ -561,15 +583,10 @@ class FrameRef(MutableMapping):
self._frame[name] = col
else:
if name not in self._frame:
feat_shape = F.shape(data)[1:]
feat_dtype = F.get_tvmtype(data)
ctx = F.get_context(data)
self._frame.add_column(name, Scheme(feat_shape, feat_dtype), ctx)
#raise DGLError('Cannot update column. Column "%s" does not exist.'
# ' Did you forget to init the column using `set_n_repr`'
# ' or `set_e_repr`?' % name)
self._frame.add_column(name, infer_scheme(data), ctx)
fcol = self._frame[name]
fcol.update(self.index(), data, inplace)
fcol.update(self.index_or_slice(), data, inplace)
def add_rows(self, num_rows):
"""Add blank rows.
......@@ -606,30 +623,28 @@ class FrameRef(MutableMapping):
Parameters
----------
query : utils.Index
query : utils.Index or slice
The rows to be updated.
data : dict-like
The row data.
inplace : bool
True if the update is performed inplacely.
"""
rowids = self._getrowid(query)
rows = self._getrows(query)
for key, col in data.items():
if key not in self:
# add new column
tmpref = FrameRef(self._frame, rowids)
tmpref = FrameRef(self._frame, rows)
tmpref.update_column(key, col, inplace)
#raise DGLError('Cannot update rows. Column "%s" does not exist.'
# ' Did you forget to init the column using `set_n_repr`'
# ' or `set_e_repr`?' % key)
else:
self._frame[key].update(rowids, col, inplace)
self._frame[key].update(rows, col, inplace)
def __delitem__(self, key):
"""Delete data in the frame.
If the provided key is a string, the corresponding column will be deleted.
If the provided key is an index object, the corresponding rows will be deleted.
If the provided key is an index object or a slice, the corresponding rows will
be deleted.
Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting
......@@ -656,14 +671,17 @@ class FrameRef(MutableMapping):
Parameters
----------
query : utils.Index
query : utils.Index or slice
The rows to be deleted.
"""
if isinstance(query, slice):
query = range(query.start, query.stop)
else:
query = query.tolist()
if isinstance(self._index_data, slice):
self._index_data = list(range(self._index_data.start, self._index_data.stop))
arr = np.array(self._index_data, dtype=np.int32)
self._index_data = list(np.delete(arr, query))
self._index_data = range(self._index_data.start, self._index_data.stop)
self._index_data = list(np.delete(self._index_data, query))
self._clear_cache()
def append(self, other):
......@@ -682,8 +700,11 @@ class FrameRef(MutableMapping):
if span_whole:
self._index_data = slice(0, self._frame.num_rows)
elif contiguous:
if self._index_data.stop == old_nrows:
new_idx = slice(self._index_data.start, self._frame.num_rows)
else:
new_idx = list(range(self._index_data.start, self._index_data.stop))
new_idx += list(range(old_nrows, self._frame.num_rows))
new_idx.extend(range(old_nrows, self._frame.num_rows))
self._index_data = new_idx
self._clear_cache()
......@@ -695,26 +716,35 @@ class FrameRef(MutableMapping):
def is_contiguous(self):
"""Return whether this refers to a contiguous range of rows."""
# NOTE: this check could have false negatives and false positives
# (step other than 1)
# NOTE: this check could have false negatives
# NOTE: we always assume that slice.step is None
return isinstance(self._index_data, slice)
def is_span_whole_column(self):
"""Return whether this refers to all the rows."""
return self.is_contiguous() and self.num_rows == self._frame.num_rows
def _getrowid(self, query):
def _getrows(self, query):
"""Internal function to convert from the local row ids to the row ids of the frame."""
if self.is_contiguous():
start = self._index_data.start
if start == 0:
# shortcut for identical mapping
return query
elif isinstance(query, slice):
return slice(query.start + start, query.stop + start)
else:
query = query.tousertensor()
return utils.toindex(query + start)
else:
idxtensor = self.index().tousertensor()
return utils.toindex(F.gather_row(idxtensor, query.tousertensor()))
query = query.tousertensor()
return utils.toindex(F.gather_row(idxtensor, query))
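# A self-contained mimic of the contiguous branch above (illustrative
# values): local positions in a view starting at `_start` simply shift by
# the offset, for slice and tensor queries alike.
_start = 3                               # e.g. FrameRef over frame rows 3..8
_local = slice(0, 2)
assert slice(_local.start + _start, _local.stop + _start) == slice(3, 5)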
def _clear_cache(self):
"""Internal function to clear the cached object."""
self._index_tensor = None
self._index = None
self._index_or_slice = None
def merge_frames(frames, indices, max_index, reduce_func):
"""Merge a list of frames.
......
......@@ -323,6 +323,23 @@ class DGLGraph(object):
else:
return eid.tousertensor()
def find_edges(self, eid):
"""Given the edge ids, return their source and destination node ids.
Parameters
----------
eid : list or tensor
The edge ids.
Returns
-------
tensor, tensor
The source and destination node IDs.
"""
eid = utils.toindex(eid)
src, dst, _ = self._graph.find_edges(eid)
return src.tousertensor(), dst.tousertensor()
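# Hypothetical usage (graph construction elided; the exact graph-building
# API at this revision may differ):
#
#   src, dst = g.find_edges([0, 2])
#   # src[i], dst[i] are the endpoints of the i-th queried edge id,
#   # returned as user tensors.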
def in_edges(self, v):
"""Return the in edges of the node(s).
......
......@@ -36,7 +36,6 @@ def cpu(dev_id=0):
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
......
......@@ -236,6 +236,38 @@ def test_sharing():
f2_a1[0:2] = th.ones([2, D])
assert th.allclose(f2['a1'], f2_a1)
def test_slicing():
data = Frame(create_test_data(grad=True))
f1 = FrameRef(data, index=slice(1, 5))
f2 = FrameRef(data, index=slice(3, 8))
# test read
for k, v in f1.items():
assert th.allclose(data[k].data[1:5], v)
f2_a1 = f2['a1'].data
# test write
f1[Index(th.tensor([0, 1]))] = {
'a1': th.zeros([2, D]),
'a2': th.zeros([2, D]),
'a3': th.zeros([2, D]),
}
assert th.allclose(f2['a1'], f2_a1)
f1[Index(th.tensor([2, 3]))] = {
'a1': th.ones([2, D]),
'a2': th.ones([2, D]),
'a3': th.ones([2, D]),
}
f2_a1[0:2] = 1
assert th.allclose(f2['a1'], f2_a1)
f1[2:4] = {
'a1': th.zeros([2, D]),
'a2': th.zeros([2, D]),
'a3': th.zeros([2, D]),
}
f2_a1[0:2] = 0
assert th.allclose(f2['a1'], f2_a1)
if __name__ == '__main__':
test_create()
test_column1()
......@@ -246,3 +278,4 @@ if __name__ == '__main__':
test_row2()
test_row3()
test_sharing()
test_slicing()