"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "365a938884dfcd33b2c89b814d69a08acb97de0f"
Unverified Commit e3bac70b authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Spmv partial (#43)

* partial spmv impl and test

* some fix for update edge
parent ee241699
......@@ -23,6 +23,9 @@ sparse_tensor = th.sparse.FloatTensor
sum = th.sum
max = th.max
def astype(a, ty):
    """Return *a* cast to the torch dtype *ty* (backend `astype` shim)."""
    casted = a.type(ty)
    return casted
def asnumpy(a):
    """Copy tensor *a* to host memory and expose it as a numpy array."""
    host = a.cpu()
    return host.numpy()
......@@ -50,16 +53,14 @@ def broadcast_to(x, to_array):
return x + th.zeros_like(to_array)
nonzero = th.nonzero
def eq_scalar(x, val):
    """Elementwise equality of tensor *x* against the scalar *val*.

    The scalar is coerced to float before comparison, matching the
    backend's numeric convention.
    """
    scalar = float(val)
    return th.eq(x, scalar)
# Thin aliases mapping the generic dgl.backend API onto torch functions.
squeeze = th.squeeze
unsqueeze = th.unsqueeze
reshape = th.reshape
zeros = th.zeros
ones = th.ones
spmm = th.spmm
sort = th.sort
arange = th.arange
def to_context(x, ctx):
if ctx is None:
......
......@@ -436,24 +436,32 @@ class DGLGraph(DiGraph):
def _nonbatch_sendto(self, u, v, message_func):
    # Non-batched message passing: run the message UDF once per (u, v)
    # edge and stash the result on the edge dict under __MSG__.
    f_msg = _get_message_func(message_func)
    if is_all(u) and is_all(v):
        # ALL/ALL sentinel means "every edge in the graph".
        u, v = self.cached_graph.edges()
    for uu, vv in utils.edge_iter(u, v):
        # The UDF receives (source node repr, edge repr).
        ret = f_msg(_get_repr(self.nodes[uu]),
                    _get_repr(self.edges[uu, vv]))
        self.edges[uu, vv][__MSG__] = ret
def _batch_sendto(self, u, v, message_func):
f_msg = _get_message_func(message_func)
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
if len(u) != len(v) and len(u) == 1:
u = F.broadcast_to(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
self.msg_graph.add_edges(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr()
msgs = message_func(src_reprs, edge_reprs)
else:
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
if isinstance(msgs, dict):
self._msg_frame.append(msgs)
else:
......@@ -490,6 +498,8 @@ class DGLGraph(DiGraph):
self._nonbatch_update_edge(u, v, edge_func)
def _nonbatch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
for uu, vv in utils.edge_iter(u, v):
ret = edge_func(_get_repr(self.nodes[uu]),
_get_repr(self.nodes[vv]),
......@@ -497,19 +507,25 @@ class DGLGraph(DiGraph):
_set_repr(self.edges[uu, vv], ret)
def _batch_update_edge(self, u, v, edge_func):
    # NOTE(review): this span is VCS-diff residue — the pre-change and
    # post-change bodies were concatenated when the +/- markers were lost.
    # The first stanza (convert ids / get_edge_id / manual broadcast) is
    # the OLD implementation; the is_all/else stanza below is its
    # replacement. Keep only one when restoring the file.
    u = utils.convert_to_id_tensor(u)
    v = utils.convert_to_id_tensor(v)
    eid = self.cached_graph.get_edge_id(u, v)
    if len(u) != len(v) and len(u) == 1:
        u = F.broadcast_to(u, v)
    elif len(u) != len(v) and len(v) == 1:
        v = F.broadcast_to(v, u)
    # call the UDF
    src_reprs = self.get_n_repr(u)
    dst_reprs = self.get_n_repr(v)
    edge_reprs = self.get_e_repr_by_id(eid)
    new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
    self.set_e_repr_by_id(new_edge_reprs, eid)
    if is_all(u) and is_all(v):
        # ALL/ALL: update every edge using the full edge representation.
        u, v = self.cached_graph.edges()
        # call the UDF
        src_reprs = self.get_n_repr(u)
        dst_reprs = self.get_n_repr(v)
        edge_reprs = self.get_e_repr()
        new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
        self.set_e_repr(new_edge_reprs)
    else:
        u = utils.convert_to_id_tensor(u)
        v = utils.convert_to_id_tensor(v)
        # Normalize one-many / many-one specs to equal-length id tensors.
        u, v = utils.edge_broadcasting(u, v)
        eid = self.cached_graph.get_edge_id(u, v)
        # call the UDF
        src_reprs = self.get_n_repr(u)
        dst_reprs = self.get_n_repr(v)
        edge_reprs = self.get_e_repr_by_id(eid)
        new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
        self.set_e_repr_by_id(new_edge_reprs, eid)
def recv(self,
u,
......@@ -566,6 +582,8 @@ class DGLGraph(DiGraph):
def _nonbatch_recv(self, u, reduce_func, update_func):
f_reduce = _get_reduce_func(reduce_func)
f_update = update_func
if is_all(u):
u = list(range(0, self.number_of_nodes()))
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
......@@ -702,6 +720,8 @@ class DGLGraph(DiGraph):
message_func,
reduce_func,
update_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
self._nonbatch_sendto(u, v, message_func)
dst = set()
for uu, vv in utils.edge_iter(u, v):
......@@ -714,26 +734,39 @@ class DGLGraph(DiGraph):
message_func,
reduce_func,
update_func):
if message_func == 'from_src' and reduce_func == 'sum' \
and is_all(u) and is_all(v):
# TODO(minjie): SPMV is only supported for updating all nodes right now.
adjmat = self.cached_graph.adjmat(self.context)
if is_all(u) and is_all(v):
self.update_all(message_func, reduce_func, update_func, True)
elif message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): check the validity of edges u->v
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
# TODO(minjie): broadcasting is optional for many-one input.
u, v = utils.edge_broadcasting(u, v)
# relabel destination nodes.
new2old, old2new = utils.build_relabel_map(v)
# TODO(minjie): should not directly use []
new_v = old2new[v]
# create adj mat
idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
dat = F.ones((len(u),))
n = self.number_of_nodes()
m = len(new2old)
adjmat = F.sparse_tensor(idx, dat, [m, n])
adjmat = F.to_context(adjmat, self.context)
# TODO(minjie): use lazy dict for reduced_msgs
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
reduced_msgs[key] = F.spmm(adjmat, col)
node_repr = self.get_n_repr()
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
self.set_n_repr(update_func(node_repr, reduced_msgs))
node_repr = self.get_n_repr(new2old)
new_node_repr = update_func(node_repr, reduced_msgs)
self.set_n_repr(new_node_repr, new2old)
else:
if is_all(u) and is_all(v):
self._batch_sendto(u, v, message_func)
self._batch_recv(v, reduce_func, update_func)
else:
self._batch_sendto(u, v, message_func)
unique_v = F.unique(v)
self._batch_recv(unique_v, reduce_func, update_func)
self._batch_sendto(u, v, message_func)
unique_v = F.unique(v)
self._batch_recv(unique_v, reduce_func, update_func)
def update_to(self,
v,
......@@ -845,11 +878,24 @@ class DGLGraph(DiGraph):
assert reduce_func is not None
assert update_func is not None
if batchable:
self._batch_update_by_edge(ALL, ALL,
message_func, reduce_func, update_func)
if message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): use lazy dict for reduced_msgs
adjmat = self.cached_graph.adjmat(self.context)
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
reduced_msgs[key] = F.spmm(adjmat, col)
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
node_repr = self.get_n_repr()
self.set_n_repr(update_func(node_repr, reduced_msgs))
else:
self._batch_sendto(ALL, ALL, message_func)
self._batch_recv(ALL, reduce_func, update_func)
else:
u = [uu for uu, _ in self.edges]
v = [vv for _, vv in self.edges]
u, v = zip(*self.edges)
u = list(u)
v = list(v)
self._nonbatch_sendto(u, v, message_func)
self._nonbatch_recv(list(self.nodes()), reduce_func, update_func)
......
......@@ -6,17 +6,21 @@ import dgl.backend as F
from dgl.backend import Tensor, SparseTensor
def is_id_tensor(u):
    """Return whether *u* is a supported id tensor: a 1-D integer Tensor."""
    if not isinstance(u, Tensor):
        return False
    return F.isinteger(u) and len(F.shape(u)) == 1
def is_id_container(u):
    """Return whether *u* is a supported id container (a plain ``list``)."""
    supported = isinstance(u, list)
    return supported
def node_iter(n):
    """Yield each node id in *n* after normalizing it to an id container."""
    yield from convert_to_id_container(n)
def edge_iter(u, v):
"""Return an iterator that loops over the given edges."""
u = convert_to_id_container(u)
v = convert_to_id_container(v)
if len(u) == len(v):
......@@ -35,6 +39,7 @@ def edge_iter(u, v):
raise ValueError('Error edges:', u, v)
def convert_to_id_container(x):
"""Convert the input to id container."""
if is_id_container(x):
return x
elif is_id_tensor(x):
......@@ -47,6 +52,7 @@ def convert_to_id_container(x):
return None
def convert_to_id_tensor(x, ctx=None):
"""Convert the input to id tensor."""
if is_id_container(x):
ret = F.tensor(x, dtype=F.int64)
elif is_id_tensor(x):
......@@ -81,3 +87,38 @@ class LazyDict(Mapping):
def __len__(self):
    # Number of keys exposed by this lazy mapping.
    return len(self._keys)
def build_relabel_map(x):
    """Relabel the input ids to consecutive ids starting from zero.

    Parameters
    ----------
    x : int, tensor or container
        The input ids.

    Returns
    -------
    new_to_old : tensor
        Mapping from new id to the old id it replaced.
    old_to_new : tensor
        Vector of length MAX(x)+1 mapping old ids to new ids; translate
        with advanced indexing: ``new_id = old_to_new[old_id]``.
    """
    ids = convert_to_id_tensor(x)
    sorted_unique, _ = F.sort(F.unique(ids))
    # The old->new table must be indexable by the largest old id.
    table_len = int(F.max(sorted_unique)) + 1
    old_to_new = F.zeros(table_len, dtype=F.int64)
    # TODO(minjie): should not directly use []
    old_to_new[sorted_unique] = F.astype(F.arange(len(sorted_unique)), F.int64)
    return sorted_unique, old_to_new
def edge_broadcasting(u, v):
    """Normalize one-many / many-one edge specs to many-many.

    Parameters
    ----------
    u, v : sized id sequences or tensors
        Source and destination ids. If exactly one side has length 1 and
        the other is longer, the singleton side is broadcast to match.

    Returns
    -------
    (u, v)
        The (possibly broadcast) pair, guaranteed equal length.

    Raises
    ------
    ValueError
        If the lengths differ and neither side is a singleton.
    """
    if len(u) != len(v):
        if len(u) == 1:
            u = F.broadcast_to(u, v)
        elif len(v) == 1:
            v = F.broadcast_to(v, u)
        else:
            # Explicit error instead of `assert`: asserts are stripped
            # under `python -O`, which would silently let bad edge specs
            # through to downstream indexing.
            raise ValueError(
                'Cannot broadcast edges: len(u)=%d, len(v)=%d'
                % (len(u), len(v)))
    return u, v
......@@ -34,16 +34,22 @@ def generate_graph():
def test_spmv_specialize():
    # NOTE(review): VCS-diff residue — lines from both the old API
    # (register_* funcs + argument-less update_all) and the new API
    # (update_all(...) with explicit funcs) appear interleaved below
    # because the +/- diff markers were lost. Keep only one API style
    # when restoring this test.
    g = generate_graph()
    g.register_message_func('from_src', batchable=True)
    g.register_reduce_func('sum', batchable=True)
    g.register_update_func(update_func, batchable=True)
    # update all: specialized (spmv) path vs generic UDF path must agree
    v1 = g.get_n_repr()
    g.update_all()
    g.update_all('from_src', 'sum', update_func, batchable=True)
    v2 = g.get_n_repr()
    g.set_n_repr(v1)
    g.register_message_func(message_func, batchable=True)
    g.register_reduce_func(reduce_func, batchable=True)
    g.update_all()
    g.update_all(message_func, reduce_func, update_func, batchable=True)
    v3 = g.get_n_repr()
    check_eq(v2, v3)
    # partial update: same comparison on an explicit edge subset
    u = th.tensor([0, 0, 0, 3, 4, 9])
    v = th.tensor([1, 2, 3, 9, 9, 0])
    v1 = g.get_n_repr()
    g.update_by_edge(u, v, 'from_src', 'sum', update_func, batchable=True)
    v2 = g.get_n_repr()
    g.set_n_repr(v1)
    g.update_by_edge(u, v, message_func, reduce_func, update_func, batchable=True)
    v3 = g.get_n_repr()
    check_eq(v2, v3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment