Unverified Commit e3bac70b authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Spmv partial (#43)

* partial spmv impl and test

* some fixes for update_edge
parent ee241699
...@@ -23,6 +23,9 @@ sparse_tensor = th.sparse.FloatTensor ...@@ -23,6 +23,9 @@ sparse_tensor = th.sparse.FloatTensor
sum = th.sum sum = th.sum
max = th.max max = th.max
def astype(a, ty):
    """Cast tensor *a* to dtype *ty* and return the casted tensor."""
    casted = a.type(ty)
    return casted
def asnumpy(a): def asnumpy(a):
return a.cpu().numpy() return a.cpu().numpy()
...@@ -50,16 +53,14 @@ def broadcast_to(x, to_array): ...@@ -50,16 +53,14 @@ def broadcast_to(x, to_array):
return x + th.zeros_like(to_array) return x + th.zeros_like(to_array)
nonzero = th.nonzero nonzero = th.nonzero
def eq_scalar(x, val):
    """Element-wise equality of tensor *x* against the scalar *val*.

    *val* is coerced to float before comparison, matching torch's
    scalar-comparison semantics.
    """
    scalar = float(val)
    return x == scalar
squeeze = th.squeeze squeeze = th.squeeze
unsqueeze = th.unsqueeze unsqueeze = th.unsqueeze
reshape = th.reshape reshape = th.reshape
zeros = th.zeros
ones = th.ones ones = th.ones
spmm = th.spmm spmm = th.spmm
sort = th.sort sort = th.sort
arange = th.arange
def to_context(x, ctx): def to_context(x, ctx):
if ctx is None: if ctx is None:
......
...@@ -436,24 +436,32 @@ class DGLGraph(DiGraph): ...@@ -436,24 +436,32 @@ class DGLGraph(DiGraph):
def _nonbatch_sendto(self, u, v, message_func): def _nonbatch_sendto(self, u, v, message_func):
f_msg = _get_message_func(message_func) f_msg = _get_message_func(message_func)
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
for uu, vv in utils.edge_iter(u, v): for uu, vv in utils.edge_iter(u, v):
ret = f_msg(_get_repr(self.nodes[uu]), ret = f_msg(_get_repr(self.nodes[uu]),
_get_repr(self.edges[uu, vv])) _get_repr(self.edges[uu, vv]))
self.edges[uu, vv][__MSG__] = ret self.edges[uu, vv][__MSG__] = ret
def _batch_sendto(self, u, v, message_func): def _batch_sendto(self, u, v, message_func):
f_msg = _get_message_func(message_func)
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
u, v = self.cached_graph.edges() u, v = self.cached_graph.edges()
u = utils.convert_to_id_tensor(u) self.msg_graph.add_edges(u, v)
v = utils.convert_to_id_tensor(v) # call UDF
eid = self.cached_graph.get_edge_id(u, v) src_reprs = self.get_n_repr(u)
self.msg_graph.add_edges(u, v) edge_reprs = self.get_e_repr()
if len(u) != len(v) and len(u) == 1: msgs = message_func(src_reprs, edge_reprs)
u = F.broadcast_to(u, v) else:
# call UDF u = utils.convert_to_id_tensor(u)
src_reprs = self.get_n_repr(u) v = utils.convert_to_id_tensor(v)
edge_reprs = self.get_e_repr_by_id(eid) u, v = utils.edge_broadcasting(u, v)
msgs = message_func(src_reprs, edge_reprs) eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
if isinstance(msgs, dict): if isinstance(msgs, dict):
self._msg_frame.append(msgs) self._msg_frame.append(msgs)
else: else:
...@@ -490,6 +498,8 @@ class DGLGraph(DiGraph): ...@@ -490,6 +498,8 @@ class DGLGraph(DiGraph):
self._nonbatch_update_edge(u, v, edge_func) self._nonbatch_update_edge(u, v, edge_func)
def _nonbatch_update_edge(self, u, v, edge_func): def _nonbatch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
for uu, vv in utils.edge_iter(u, v): for uu, vv in utils.edge_iter(u, v):
ret = edge_func(_get_repr(self.nodes[uu]), ret = edge_func(_get_repr(self.nodes[uu]),
_get_repr(self.nodes[vv]), _get_repr(self.nodes[vv]),
...@@ -497,19 +507,25 @@ class DGLGraph(DiGraph): ...@@ -497,19 +507,25 @@ class DGLGraph(DiGraph):
_set_repr(self.edges[uu, vv], ret) _set_repr(self.edges[uu, vv], ret)
def _batch_update_edge(self, u, v, edge_func): def _batch_update_edge(self, u, v, edge_func):
u = utils.convert_to_id_tensor(u) if is_all(u) and is_all(v):
v = utils.convert_to_id_tensor(v) u, v = self.cached_graph.edges()
eid = self.cached_graph.get_edge_id(u, v) # call the UDF
if len(u) != len(v) and len(u) == 1: src_reprs = self.get_n_repr(u)
u = F.broadcast_to(u, v) dst_reprs = self.get_n_repr(v)
elif len(u) != len(v) and len(v) == 1: edge_reprs = self.get_e_repr()
v = F.broadcast_to(v, u) new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
# call the UDF self.set_e_repr(new_edge_reprs)
src_reprs = self.get_n_repr(u) else:
dst_reprs = self.get_n_repr(v) u = utils.convert_to_id_tensor(u)
edge_reprs = self.get_e_repr_by_id(eid) v = utils.convert_to_id_tensor(v)
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs) u, v = utils.edge_broadcasting(u, v)
self.set_e_repr_by_id(new_edge_reprs, eid) eid = self.cached_graph.get_edge_id(u, v)
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
edge_reprs = self.get_e_repr_by_id(eid)
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr_by_id(new_edge_reprs, eid)
def recv(self, def recv(self,
u, u,
...@@ -566,6 +582,8 @@ class DGLGraph(DiGraph): ...@@ -566,6 +582,8 @@ class DGLGraph(DiGraph):
def _nonbatch_recv(self, u, reduce_func, update_func): def _nonbatch_recv(self, u, reduce_func, update_func):
f_reduce = _get_reduce_func(reduce_func) f_reduce = _get_reduce_func(reduce_func)
f_update = update_func f_update = update_func
if is_all(u):
u = list(range(0, self.number_of_nodes()))
for i, uu in enumerate(utils.node_iter(u)): for i, uu in enumerate(utils.node_iter(u)):
# reduce phase # reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__) msgs_batch = [self.edges[vv, uu].pop(__MSG__)
...@@ -702,6 +720,8 @@ class DGLGraph(DiGraph): ...@@ -702,6 +720,8 @@ class DGLGraph(DiGraph):
message_func, message_func,
reduce_func, reduce_func,
update_func): update_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
self._nonbatch_sendto(u, v, message_func) self._nonbatch_sendto(u, v, message_func)
dst = set() dst = set()
for uu, vv in utils.edge_iter(u, v): for uu, vv in utils.edge_iter(u, v):
...@@ -714,26 +734,39 @@ class DGLGraph(DiGraph): ...@@ -714,26 +734,39 @@ class DGLGraph(DiGraph):
message_func, message_func,
reduce_func, reduce_func,
update_func): update_func):
if message_func == 'from_src' and reduce_func == 'sum' \ if is_all(u) and is_all(v):
and is_all(u) and is_all(v): self.update_all(message_func, reduce_func, update_func, True)
# TODO(minjie): SPMV is only supported for updating all nodes right now. elif message_func == 'from_src' and reduce_func == 'sum':
adjmat = self.cached_graph.adjmat(self.context) # TODO(minjie): check the validity of edges u->v
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
# TODO(minjie): broadcasting is optional for many-one input.
u, v = utils.edge_broadcasting(u, v)
# relabel destination nodes.
new2old, old2new = utils.build_relabel_map(v)
# TODO(minjie): should not directly use []
new_v = old2new[v]
# create adj mat
idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
dat = F.ones((len(u),))
n = self.number_of_nodes()
m = len(new2old)
adjmat = F.sparse_tensor(idx, dat, [m, n])
adjmat = F.to_context(adjmat, self.context)
# TODO(minjie): use lazy dict for reduced_msgs
reduced_msgs = {} reduced_msgs = {}
for key in self._node_frame.schemes: for key in self._node_frame.schemes:
col = self._node_frame[key] col = self._node_frame[key]
reduced_msgs[key] = F.spmm(adjmat, col) reduced_msgs[key] = F.spmm(adjmat, col)
node_repr = self.get_n_repr()
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs: if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__] reduced_msgs = reduced_msgs[__REPR__]
self.set_n_repr(update_func(node_repr, reduced_msgs)) node_repr = self.get_n_repr(new2old)
new_node_repr = update_func(node_repr, reduced_msgs)
self.set_n_repr(new_node_repr, new2old)
else: else:
if is_all(u) and is_all(v): self._batch_sendto(u, v, message_func)
self._batch_sendto(u, v, message_func) unique_v = F.unique(v)
self._batch_recv(v, reduce_func, update_func) self._batch_recv(unique_v, reduce_func, update_func)
else:
self._batch_sendto(u, v, message_func)
unique_v = F.unique(v)
self._batch_recv(unique_v, reduce_func, update_func)
def update_to(self, def update_to(self,
v, v,
...@@ -845,11 +878,24 @@ class DGLGraph(DiGraph): ...@@ -845,11 +878,24 @@ class DGLGraph(DiGraph):
assert reduce_func is not None assert reduce_func is not None
assert update_func is not None assert update_func is not None
if batchable: if batchable:
self._batch_update_by_edge(ALL, ALL, if message_func == 'from_src' and reduce_func == 'sum':
message_func, reduce_func, update_func) # TODO(minjie): use lazy dict for reduced_msgs
adjmat = self.cached_graph.adjmat(self.context)
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
reduced_msgs[key] = F.spmm(adjmat, col)
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
node_repr = self.get_n_repr()
self.set_n_repr(update_func(node_repr, reduced_msgs))
else:
self._batch_sendto(ALL, ALL, message_func)
self._batch_recv(ALL, reduce_func, update_func)
else: else:
u = [uu for uu, _ in self.edges] u, v = zip(*self.edges)
v = [vv for _, vv in self.edges] u = list(u)
v = list(v)
self._nonbatch_sendto(u, v, message_func) self._nonbatch_sendto(u, v, message_func)
self._nonbatch_recv(list(self.nodes()), reduce_func, update_func) self._nonbatch_recv(list(self.nodes()), reduce_func, update_func)
......
...@@ -6,17 +6,21 @@ import dgl.backend as F ...@@ -6,17 +6,21 @@ import dgl.backend as F
from dgl.backend import Tensor, SparseTensor from dgl.backend import Tensor, SparseTensor
def is_id_tensor(u): def is_id_tensor(u):
"""Return whether the input is a supported id tensor."""
return isinstance(u, Tensor) and F.isinteger(u) and len(F.shape(u)) == 1 return isinstance(u, Tensor) and F.isinteger(u) and len(F.shape(u)) == 1
def is_id_container(u): def is_id_container(u):
"""Return whether the input is a supported id container."""
return isinstance(u, list) return isinstance(u, list)
def node_iter(n): def node_iter(n):
"""Return an iterator that loops over the given nodes."""
n = convert_to_id_container(n) n = convert_to_id_container(n)
for nn in n: for nn in n:
yield nn yield nn
def edge_iter(u, v): def edge_iter(u, v):
"""Return an iterator that loops over the given edges."""
u = convert_to_id_container(u) u = convert_to_id_container(u)
v = convert_to_id_container(v) v = convert_to_id_container(v)
if len(u) == len(v): if len(u) == len(v):
...@@ -35,6 +39,7 @@ def edge_iter(u, v): ...@@ -35,6 +39,7 @@ def edge_iter(u, v):
raise ValueError('Error edges:', u, v) raise ValueError('Error edges:', u, v)
def convert_to_id_container(x): def convert_to_id_container(x):
"""Convert the input to id container."""
if is_id_container(x): if is_id_container(x):
return x return x
elif is_id_tensor(x): elif is_id_tensor(x):
...@@ -47,6 +52,7 @@ def convert_to_id_container(x): ...@@ -47,6 +52,7 @@ def convert_to_id_container(x):
return None return None
def convert_to_id_tensor(x, ctx=None): def convert_to_id_tensor(x, ctx=None):
"""Convert the input to id tensor."""
if is_id_container(x): if is_id_container(x):
ret = F.tensor(x, dtype=F.int64) ret = F.tensor(x, dtype=F.int64)
elif is_id_tensor(x): elif is_id_tensor(x):
...@@ -81,3 +87,38 @@ class LazyDict(Mapping): ...@@ -81,3 +87,38 @@ class LazyDict(Mapping):
def __len__(self): def __len__(self):
return len(self._keys) return len(self._keys)
def build_relabel_map(x):
    """Relabel the input ids to continuous ids that start from zero.
    Parameters
    ----------
    x : int, tensor or container
        The input ids.
    Returns
    -------
    new_to_old : tensor
        The mapping from new id to old id (the sorted unique old ids).
    old_to_new : tensor
        The mapping from old id to new id. It is a vector of length
        MAX(x)+1, so it can be indexed by any old id. Entries for old
        ids not present in x are left as 0 — callers must only index
        it with ids that actually occur in x.
        One can use advanced indexing to convert an old id tensor to a
        new id tensor: new_id = old_to_new[old_id]
    """
    x = convert_to_id_tensor(x)
    # Sorted unique ids become the new id space: position i <-> new id i.
    unique_x, _ = F.sort(F.unique(x))
    map_len = int(F.max(unique_x)) + 1
    old_to_new = F.zeros(map_len, dtype=F.int64)
    # TODO(minjie): should not directly use []
    # Scatter new ids (0..len-1) into the slots of their old ids.
    old_to_new[unique_x] = F.astype(F.arange(len(unique_x)), F.int64)
    return unique_x, old_to_new
def edge_broadcasting(u, v):
    """Convert one-many and many-one edge specifications to many-many.

    If exactly one side has a single id, it is broadcast to match the
    other side's length; equal-length inputs pass through unchanged.
    """
    nu, nv = len(u), len(v)
    if nu == nv:
        # Already many-many (or both singletons) — nothing to do.
        pass
    elif nu == 1:
        # one-many: replicate the single source id.
        u = F.broadcast_to(u, v)
    elif nv == 1:
        # many-one: replicate the single destination id.
        v = F.broadcast_to(v, u)
    else:
        # Mismatched lengths with no singleton side is invalid input.
        raise AssertionError((nu, nv))
    return u, v
...@@ -34,16 +34,22 @@ def generate_graph(): ...@@ -34,16 +34,22 @@ def generate_graph():
def test_spmv_specialize(): def test_spmv_specialize():
g = generate_graph() g = generate_graph()
g.register_message_func('from_src', batchable=True) # update all
g.register_reduce_func('sum', batchable=True)
g.register_update_func(update_func, batchable=True)
v1 = g.get_n_repr() v1 = g.get_n_repr()
g.update_all() g.update_all('from_src', 'sum', update_func, batchable=True)
v2 = g.get_n_repr() v2 = g.get_n_repr()
g.set_n_repr(v1) g.set_n_repr(v1)
g.register_message_func(message_func, batchable=True) g.update_all(message_func, reduce_func, update_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True) v3 = g.get_n_repr()
g.update_all() check_eq(v2, v3)
# partial update
u = th.tensor([0, 0, 0, 3, 4, 9])
v = th.tensor([1, 2, 3, 9, 9, 0])
v1 = g.get_n_repr()
g.update_by_edge(u, v, 'from_src', 'sum', update_func, batchable=True)
v2 = g.get_n_repr()
g.set_n_repr(v1)
g.update_by_edge(u, v, message_func, reduce_func, update_func, batchable=True)
v3 = g.get_n_repr() v3 = g.get_n_repr()
check_eq(v2, v3) check_eq(v2, v3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment