Unverified Commit 7241a9c0 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Backend] backend interface (#109)

* backend interface

* small fix

* more comments to the data type dict

* WIP

* convert_to and narrow

* WIP

* pytorch and numpy backend; WIP on mxnet backend

* mxnet backend

* narrow

* Fix all usages

* fix for mx

* fix for mx

* fix mx

* fix mx

* fix mx

* fix mx

* fix mx

* fix mx

* fix mx

* revert jenkins

* add sparse_matrix api

* sparse matrix api

* some fixme

* Fix as requested
parent b420a5b5
...@@ -6,7 +6,6 @@ from functools import wraps ...@@ -6,7 +6,6 @@ from functools import wraps
import numpy as np import numpy as np
from . import backend as F from . import backend as F
from .backend import Tensor, SparseTensor
from . import ndarray as nd from . import ndarray as nd
class Index(object): class Index(object):
...@@ -19,7 +18,7 @@ class Index(object): ...@@ -19,7 +18,7 @@ class Index(object):
def _dispatch(self, data): def _dispatch(self, data):
"""Store data based on its type.""" """Store data based on its type."""
if isinstance(data, Tensor): if F.is_tensor(data):
if not (F.dtype(data) == F.int64): if not (F.dtype(data) == F.int64):
raise ValueError('Index data must be an int64 vector, but got: %s' % str(data)) raise ValueError('Index data must be an int64 vector, but got: %s' % str(data))
if len(F.shape(data)) > 1: if len(F.shape(data)) > 1:
...@@ -28,7 +27,7 @@ class Index(object): ...@@ -28,7 +27,7 @@ class Index(object):
# a tensor of one int # a tensor of one int
self._dispatch(int(data)) self._dispatch(int(data))
else: else:
self._user_tensor_data[F.get_context(data)] = data self._user_tensor_data[F.context(data)] = data
elif isinstance(data, nd.NDArray): elif isinstance(data, nd.NDArray):
if not (data.dtype == 'int64' and len(data.shape) == 1): if not (data.dtype == 'int64' and len(data.shape) == 1):
raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data)) raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
...@@ -41,7 +40,7 @@ class Index(object): ...@@ -41,7 +40,7 @@ class Index(object):
self._list_data = np.array(data).astype(np.int64) self._list_data = np.array(data).astype(np.int64)
except: except:
raise ValueError('Error index data: %s' % str(data)) raise ValueError('Error index data: %s' % str(data))
self._user_tensor_data[nd.cpu()] = F.zerocopy_from_numpy(self._list_data) self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self._list_data)
def tolist(self): def tolist(self):
"""Convert to a python-list compatible object.""" """Convert to a python-list compatible object."""
...@@ -56,15 +55,15 @@ class Index(object): ...@@ -56,15 +55,15 @@ class Index(object):
def tousertensor(self, ctx=None): def tousertensor(self, ctx=None):
"""Convert to user tensor (defined in `backend`).""" """Convert to user tensor (defined in `backend`)."""
if ctx is None: if ctx is None:
ctx = nd.cpu() ctx = F.cpu()
if len(self._user_tensor_data) == 0: if len(self._user_tensor_data) == 0:
# zero copy from dgl tensor # zero copy from dgl tensor
dl = self._dgl_tensor_data.to_dlpack() dl = self._dgl_tensor_data.to_dlpack()
self._user_tensor_data[nd.cpu()] = F.zerocopy_from_dlpack(dl) self._user_tensor_data[F.cpu()] = F.zerocopy_from_dlpack(dl)
if ctx not in self._user_tensor_data: if ctx not in self._user_tensor_data:
# copy from cpu to another device # copy from cpu to another device
data = next(iter(self._user_tensor_data.values())) data = next(iter(self._user_tensor_data.values()))
self._user_tensor_data[ctx] = F.to_context(data, ctx) self._user_tensor_data[ctx] = F.copy_to(data, ctx)
return self._user_tensor_data[ctx] return self._user_tensor_data[ctx]
def todgltensor(self): def todgltensor(self):
...@@ -147,9 +146,9 @@ def edge_broadcasting(u, v): ...@@ -147,9 +146,9 @@ def edge_broadcasting(u, v):
The dst id(s) after broadcasting The dst id(s) after broadcasting
""" """
if len(u) != len(v) and len(u) == 1: if len(u) != len(v) and len(u) == 1:
u = toindex(F.broadcast_to(u.tousertensor(), v.tousertensor())) u = toindex(F.full_1d(len(v), u[0]))
elif len(u) != len(v) and len(v) == 1: elif len(u) != len(v) and len(v) == 1:
v = toindex(F.broadcast_to(v.tousertensor(), u.tousertensor())) v = toindex(F.full_1d(len(u), v[0]))
else: else:
assert len(u) == len(v) assert len(u) == len(v)
return u, v return u, v
...@@ -240,11 +239,10 @@ def build_relabel_map(x): ...@@ -240,11 +239,10 @@ def build_relabel_map(x):
new id tensor: new_id = old_to_new[old_id] new id tensor: new_id = old_to_new[old_id]
""" """
x = x.tousertensor() x = x.tousertensor()
unique_x, _ = F.sort(F.unique(x)) unique_x, _ = F.sort_1d(F.unique(x))
map_len = int(F.max(unique_x)) + 1 map_len = int(F.max(unique_x, dim=0)) + 1
old_to_new = F.zeros(map_len, dtype=F.int64) old_to_new = F.zeros(map_len, dtype=F.int64)
# TODO(minjie): should not directly use [] F.scatter_row_inplace(old_to_new, unique_x, F.arange(0, len(unique_x)))
old_to_new[unique_x] = F.astype(F.arange(len(unique_x)), F.int64)
return unique_x, old_to_new return unique_x, old_to_new
def build_relabel_dict(x): def build_relabel_dict(x):
...@@ -323,17 +321,6 @@ def cached_member(func): ...@@ -323,17 +321,6 @@ def cached_member(func):
def is_dict_like(obj): def is_dict_like(obj):
return isinstance(obj, Mapping) return isinstance(obj, Mapping)
def pack2(a, b):
if a is None:
return b
elif b is None:
return a
else:
if isinstance(a, dict):
return {k: F.pack([a[k], b[k]]) for k in a}
else:
return F.pack([a, b])
def reorder(dict_like, index): def reorder(dict_like, index):
"""Reorder each column in the dict according to the index. """Reorder each column in the dict according to the index.
...@@ -346,7 +333,7 @@ def reorder(dict_like, index): ...@@ -346,7 +333,7 @@ def reorder(dict_like, index):
""" """
new_dict = {} new_dict = {}
for key, val in dict_like.items(): for key, val in dict_like.items():
idx_ctx = index.tousertensor(F.get_context(val)) idx_ctx = index.tousertensor(F.context(val))
new_dict[key] = F.gather_row(val, idx_ctx) new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict return new_dict
......
...@@ -45,7 +45,7 @@ class NodeView(object): ...@@ -45,7 +45,7 @@ class NodeView(object):
def __call__(self): def __call__(self):
"""Return the nodes.""" """Return the nodes."""
return F.arange(0, len(self), dtype=F.int64) return F.arange(0, len(self))
class NodeDataView(MutableMapping): class NodeDataView(MutableMapping):
__slot__ = ['_graph', '_nodes'] __slot__ = ['_graph', '_nodes']
......
...@@ -5,21 +5,8 @@ import numpy as np ...@@ -5,21 +5,8 @@ import numpy as np
import scipy as sp import scipy as sp
from dgl.graph import GraphIndex, create_graph_index from dgl.graph import GraphIndex, create_graph_index
from dgl.graph_index import map_to_subgraph_nid from dgl.graph_index import map_to_subgraph_nid
import dgl.backend as F
from dgl import utils from dgl import utils
def generate_graph():
g = create_graph_index()
g.add_nodes(10) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
ig = create_graph_index(g.to_networkx(), readonly=True)
return g, ig
def generate_rand_graph(n): def generate_rand_graph(n):
arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64) arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
g = create_graph_index(arr) g = create_graph_index(arr)
...@@ -27,9 +14,8 @@ def generate_rand_graph(n): ...@@ -27,9 +14,8 @@ def generate_rand_graph(n):
return g, ig return g, ig
def check_graph_equal(g1, g2): def check_graph_equal(g1, g2):
ctx = F.get_context(mx.nd.array([1])) adj1 = g1.adjacency_matrix().get(mx.cpu()) != 0
adj1 = g1.adjacency_matrix().get(ctx) != 0 adj2 = g2.adjacency_matrix().get(mx.cpu()) != 0
adj2 = g2.adjacency_matrix().get(ctx) != 0
assert mx.nd.sum(adj1 - adj2).asnumpy() == 0 assert mx.nd.sum(adj1 - adj2).asnumpy() == 0
def test_graph_gen(): def test_graph_gen():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment