Unverified commit 524e656d, authored by Minjie Wang and committed by GitHub
Browse files

[Bugfix] enable SPMV for multi-dimension node feature (#178)

* Support multi-dimension node feature in SPMV

* fix as requested
parent b1eeb934
......@@ -32,15 +32,6 @@ class MessageFunction(BuiltinFunction):
raise NotImplementedError
def _is_spmv_supported_node_feat(g, field):
    """Tell whether a node feature's shape permits the SPMV optimization.

    The feature stored under ``field`` in ``g.get_n_repr()`` qualifies when
    ``F.shape`` reports exactly one or two dimensions.
    """
    rank = len(F.shape(g.get_n_repr()[field]))
    return rank in (1, 2)
def _is_spmv_supported_edge_feat(g, field):
"""Return whether the edge feature shape supports SPMV optimization.
......@@ -59,8 +50,7 @@ class SrcMulEdgeMessageFunction(MessageFunction):
self.out_field = out_field
def is_spmv_supported(self, g):
# Report whether this message function can use the SPMV optimization on graph `g`.
# NOTE(review): the two `return` statements below look like the removed and the
# added side of a diff hunk whose +/- markers were lost. As written only the
# first one can execute; the second (edge-feature check only) is presumably
# the current version -- confirm against the repository.
return _is_spmv_supported_node_feat(g, self.src_field) \
and _is_spmv_supported_edge_feat(g, self.edge_field)
return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, edges):
src_data = edges.src[self.src_field]
......@@ -87,7 +77,7 @@ class CopySrcMessageFunction(MessageFunction):
self.out_field = out_field
def is_spmv_supported(self, g):
# Report whether this message function can use the SPMV optimization on graph `g`.
# NOTE(review): the two `return` statements below look like the removed and the
# added side of a diff hunk whose +/- markers were lost. As written only the
# first one can execute; the unconditional `return True` is presumably the
# current version -- confirm against the repository.
return _is_spmv_supported_node_feat(g, self.src_field)
return True
def __call__(self, edges):
return {self.out_field : edges.src[self.src_field]}
......@@ -106,7 +96,7 @@ class CopyEdgeMessageFunction(MessageFunction):
self.out_field = out_field
def is_spmv_supported(self, g):
# Always reports SPMV as unsupported for this message function: the method
# returns False unconditionally and the edge-feature check stays commented out.
# NOTE(review): the two TODO lines below look like the old and the new wording
# of one comment from a diff whose +/- markers were lost -- confirm which is current.
# TODO: support this with g-spmv
# TODO: support this with e2v spmv
return False
# return _is_spmv_supported_edge_feat(g, self.edge_field)
......
from __future__ import absolute_import
from abc import abstractmethod
import functools
import operator
from ...base import DGLError
from ... import backend as F
......@@ -250,6 +252,16 @@ class SPMVExecutor(Executor):
B = F.unsqueeze(B, 1)
C = F.spmm(spA, B)
C = F.squeeze(C, 1)
elif F.ndim(B) > 2:
# Flatten the dim 1:~
B_shape = F.shape(B)
feat_shape = B_shape[1:]
tmp_B_shape = (B_shape[0],
functools.reduce(operator.mul, feat_shape, 1))
B = F.reshape(B, tmp_B_shape)
C = F.spmm(spA, B)
C_shape = (F.shape(C)[0],) + feat_shape
C = F.reshape(C, C_shape)
else:
C = F.spmm(spA, B)
self.ret.data = C
......@@ -301,6 +313,16 @@ class SPMVWithDataExecutor(Executor):
B = F.unsqueeze(B, 1)
C = F.spmm(spA, B)
C = F.squeeze(C, 1)
elif F.ndim(B) > 2:
# Flatten the dim 1:~
B_shape = F.shape(B)
feat_shape = B_shape[1:]
tmp_B_shape = (B_shape[0],
functools.reduce(operator.mul, feat_shape, 1))
B = F.reshape(B, tmp_B_shape)
C = F.spmm(spA, B)
C_shape = (F.shape(C)[0],) + feat_shape
C = F.reshape(C, C_shape)
else:
C = F.spmm(spA, B)
self.ret.data = C
......
import torch as th
import numpy as np
import scipy.sparse as sp
import dgl
import dgl.function as fn
import utils as U
......@@ -513,6 +515,60 @@ def test_pull_multi_fallback():
nodes = [0, 1, 2, 9]
_pull_nodes(nodes)
def test_spmv_3d_feat():
    """Compare builtin message/reduce functions with UDFs on 3-D node features.

    Node features have shape ``(n, 5, 5)``. Two edge-data layouts are
    exercised: one scalar per edge (shape ``(m,)``) and a full ``(m, 5, 5)``
    tensor per edge. For each layout the result of the builtin
    ``fn.src_mul_edge`` + ``fn.sum`` pair is taken as the reference and must
    match (a) a UDF message with the builtin reducer and (b) a UDF message
    with a UDF reducer.
    """
    def scalar_edge_message(edges):
        # Edge data is (m,); add two trailing axes so it broadcasts
        # against the (m, 5, 5) source features.
        weight = edges.data['h'].unsqueeze(1).unsqueeze(1)
        return {'sum': edges.src['h'] * weight}

    def tensor_edge_message(edges):
        # Edge data already has the same trailing shape as the node feature.
        return {'sum': edges.src['h'] * edges.data['h']}

    def sum_udf(nodes):
        return {'h': nodes.mailbox['sum'].sum(1)}

    num_nodes = 100
    density = 0.1
    adj = sp.random(num_nodes, num_nodes, density,
                    data_rvs=lambda count: np.ones(count))
    g = dgl.DGLGraph(adj)
    num_edges = g.number_of_edges()

    def run_update_all(node_feat, edge_feat, message_func, reduce_func):
        # Reset both feature fields before every run so each variant starts
        # from the same inputs, then return the updated node feature.
        g.ndata['h'] = node_feat
        g.edata['h'] = edge_feat
        g.update_all(message_func=message_func, reduce_func=reduce_func)
        return g.ndata['h']

    # test#1: v2v with adj data
    h = th.randn((num_nodes, 5, 5))
    e = th.randn((num_edges,))
    ans = run_update_all(h, e, fn.src_mul_edge('h', 'h', 'sum'),
                         fn.sum('sum', 'h'))  # 1
    out = run_update_all(h, e, scalar_edge_message, fn.sum('sum', 'h'))  # 2
    assert U.allclose(out, ans)
    out = run_update_all(h, e, scalar_edge_message, sum_udf)  # 3
    assert U.allclose(out, ans)

    # test#2: e2v
    h = th.randn((num_nodes, 5, 5))
    e = th.randn((num_edges, 5, 5))
    ans = run_update_all(h, e, fn.src_mul_edge('h', 'h', 'sum'),
                         fn.sum('sum', 'h'))  # 1
    out = run_update_all(h, e, tensor_edge_message, fn.sum('sum', 'h'))  # 2
    assert U.allclose(out, ans)
    out = run_update_all(h, e, tensor_edge_message, sum_udf)  # 3
    assert U.allclose(out, ans)
if __name__ == '__main__':
test_v2v_update_all()
test_v2v_snr()
......@@ -524,3 +580,4 @@ if __name__ == '__main__':
test_e2v_recv_multi_fn()
test_update_all_multi_fallback()
test_pull_multi_fallback()
test_spmv_3d_feat()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment