Unverified Commit 524e656d authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix] enable SPMV for multi-dimension node feature (#178)

* Support multi-dimension node feature in SPMV

* fix as requested
parent b1eeb934
...@@ -32,15 +32,6 @@ class MessageFunction(BuiltinFunction): ...@@ -32,15 +32,6 @@ class MessageFunction(BuiltinFunction):
raise NotImplementedError raise NotImplementedError
def _is_spmv_supported_node_feat(g, field):
"""Return whether the node feature shape supports SPMV optimization.
Only scalar and vector features are supported currently.
"""
feat = g.get_n_repr()[field]
shape = F.shape(feat)
return len(shape) == 1 or len(shape) == 2
def _is_spmv_supported_edge_feat(g, field): def _is_spmv_supported_edge_feat(g, field):
"""Return whether the edge feature shape supports SPMV optimization. """Return whether the edge feature shape supports SPMV optimization.
...@@ -59,8 +50,7 @@ class SrcMulEdgeMessageFunction(MessageFunction): ...@@ -59,8 +50,7 @@ class SrcMulEdgeMessageFunction(MessageFunction):
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
return _is_spmv_supported_node_feat(g, self.src_field) \ return _is_spmv_supported_edge_feat(g, self.edge_field)
and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, edges): def __call__(self, edges):
src_data = edges.src[self.src_field] src_data = edges.src[self.src_field]
...@@ -87,7 +77,7 @@ class CopySrcMessageFunction(MessageFunction): ...@@ -87,7 +77,7 @@ class CopySrcMessageFunction(MessageFunction):
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
return _is_spmv_supported_node_feat(g, self.src_field) return True
def __call__(self, edges): def __call__(self, edges):
return {self.out_field : edges.src[self.src_field]} return {self.out_field : edges.src[self.src_field]}
...@@ -106,7 +96,7 @@ class CopyEdgeMessageFunction(MessageFunction): ...@@ -106,7 +96,7 @@ class CopyEdgeMessageFunction(MessageFunction):
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
# TODO: support this with g-spmv # TODO: support this with e2v spmv
return False return False
# return _is_spmv_supported_edge_feat(g, self.edge_field) # return _is_spmv_supported_edge_feat(g, self.edge_field)
......
from __future__ import absolute_import from __future__ import absolute_import
from abc import abstractmethod from abc import abstractmethod
import functools
import operator
from ...base import DGLError from ...base import DGLError
from ... import backend as F from ... import backend as F
...@@ -250,6 +252,16 @@ class SPMVExecutor(Executor): ...@@ -250,6 +252,16 @@ class SPMVExecutor(Executor):
B = F.unsqueeze(B, 1) B = F.unsqueeze(B, 1)
C = F.spmm(spA, B) C = F.spmm(spA, B)
C = F.squeeze(C, 1) C = F.squeeze(C, 1)
elif F.ndim(B) > 2:
# Flatten the dim 1:~
B_shape = F.shape(B)
feat_shape = B_shape[1:]
tmp_B_shape = (B_shape[0],
functools.reduce(operator.mul, feat_shape, 1))
B = F.reshape(B, tmp_B_shape)
C = F.spmm(spA, B)
C_shape = (F.shape(C)[0],) + feat_shape
C = F.reshape(C, C_shape)
else: else:
C = F.spmm(spA, B) C = F.spmm(spA, B)
self.ret.data = C self.ret.data = C
...@@ -301,6 +313,16 @@ class SPMVWithDataExecutor(Executor): ...@@ -301,6 +313,16 @@ class SPMVWithDataExecutor(Executor):
B = F.unsqueeze(B, 1) B = F.unsqueeze(B, 1)
C = F.spmm(spA, B) C = F.spmm(spA, B)
C = F.squeeze(C, 1) C = F.squeeze(C, 1)
elif F.ndim(B) > 2:
# Flatten the dim 1:~
B_shape = F.shape(B)
feat_shape = B_shape[1:]
tmp_B_shape = (B_shape[0],
functools.reduce(operator.mul, feat_shape, 1))
B = F.reshape(B, tmp_B_shape)
C = F.spmm(spA, B)
C_shape = (F.shape(C)[0],) + feat_shape
C = F.reshape(C, C_shape)
else: else:
C = F.spmm(spA, B) C = F.spmm(spA, B)
self.ret.data = C self.ret.data = C
......
import torch as th import torch as th
import numpy as np
import scipy.sparse as sp
import dgl import dgl
import dgl.function as fn import dgl.function as fn
import utils as U import utils as U
...@@ -513,6 +515,60 @@ def test_pull_multi_fallback(): ...@@ -513,6 +515,60 @@ def test_pull_multi_fallback():
nodes = [0, 1, 2, 9] nodes = [0, 1, 2, 9]
_pull_nodes(nodes) _pull_nodes(nodes)
def test_spmv_3d_feat():
def src_mul_edge_udf(edges):
return {'sum': edges.src['h'] * edges.data['h'].unsqueeze(1).unsqueeze(1)}
def sum_udf(nodes):
return {'h': nodes.mailbox['sum'].sum(1)}
n = 100
p = 0.1
a = sp.random(n, n, p, data_rvs=lambda n: np.ones(n))
g = dgl.DGLGraph(a)
m = g.number_of_edges()
# test#1: v2v with adj data
h = th.randn((n, 5, 5))
e = th.randn((m,))
g.ndata['h'] = h
g.edata['h'] = e
g.update_all(message_func=fn.src_mul_edge('h', 'h', 'sum'), reduce_func=fn.sum('sum', 'h')) # 1
ans = g.ndata['h']
g.ndata['h'] = h
g.edata['h'] = e
g.update_all(message_func=src_mul_edge_udf, reduce_func=fn.sum('sum', 'h')) # 2
assert U.allclose(g.ndata['h'], ans)
g.ndata['h'] = h
g.edata['h'] = e
g.update_all(message_func=src_mul_edge_udf, reduce_func=sum_udf) # 3
assert U.allclose(g.ndata['h'], ans)
# test#2: e2v
def src_mul_edge_udf(edges):
return {'sum': edges.src['h'] * edges.data['h']}
h = th.randn((n, 5, 5))
e = th.randn((m, 5, 5))
g.ndata['h'] = h
g.edata['h'] = e
g.update_all(message_func=fn.src_mul_edge('h', 'h', 'sum'), reduce_func=fn.sum('sum', 'h')) # 1
ans = g.ndata['h']
g.ndata['h'] = h
g.edata['h'] = e
g.update_all(message_func=src_mul_edge_udf, reduce_func=fn.sum('sum', 'h')) # 2
assert U.allclose(g.ndata['h'], ans)
g.ndata['h'] = h
g.edata['h'] = e
g.update_all(message_func=src_mul_edge_udf, reduce_func=sum_udf) # 3
assert U.allclose(g.ndata['h'], ans)
if __name__ == '__main__': if __name__ == '__main__':
test_v2v_update_all() test_v2v_update_all()
test_v2v_snr() test_v2v_snr()
...@@ -524,3 +580,4 @@ if __name__ == '__main__': ...@@ -524,3 +580,4 @@ if __name__ == '__main__':
test_e2v_recv_multi_fn() test_e2v_recv_multi_fn()
test_update_all_multi_fallback() test_update_all_multi_fallback()
test_pull_multi_fallback() test_pull_multi_fallback()
test_spmv_3d_feat()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment