Unverified Commit 61f007c4 authored by xiang song(charlie.song), committed by GitHub

[NN] Add low memory support for nn.RelGraphConv (#1631)



* Add low memory support for relgraphconv

* lint

* Add tf support

* Fix lint

* Fix mxnet

* Fix mxnet

* Fix

* minor fix

* Fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 04522a76
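
The diffs below touch the MXNet, PyTorch and TensorFlow variants of RelGraphConv plus their unit tests. For orientation, here is a minimal usage sketch of the new flag on the PyTorch backend, mirroring the unit tests further down; the graph construction and sizes are illustrative only and are not part of this change.

import numpy as np
import scipy.sparse as ssp
import torch as th
import dgl
from dgl.nn.pytorch import RelGraphConv

# Random graph with 3 relation types; the data here is purely illustrative.
g = dgl.DGLGraph(ssp.random(100, 100, density=0.1), readonly=True)
etype = th.tensor(np.random.randint(0, 3, g.number_of_edges()))
feat = th.randn(100, 10)

# Same layer configuration, with and without the low-memory path.
conv = RelGraphConv(10, 8, num_rels=3, regularizer="basis", num_bases=2)
conv_low = RelGraphConv(10, 8, num_rels=3, regularizer="basis", num_bases=2,
                        low_mem=True)

h = conv(g, feat, etype)          # baseline: batched matmul over all edges
h_low = conv_low(g, feat, etype)  # groups edges by relation type; lower peak memory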
@@ -56,6 +56,9 @@ class RelGraphConv(gluon.Block):
        Activation function. Default: None
    self_loop : bool, optional
        True to include self loop message. Default: False
    low_mem : bool, optional
        True to use the low-memory implementation. MXNet currently does not
        support this. Default: False.
    dropout : float, optional
        Dropout rate. Default: 0.0
    """
@@ -68,6 +71,7 @@ class RelGraphConv(gluon.Block):
                 bias=True,
                 activation=None,
                 self_loop=False,
                 low_mem=False,
                 dropout=0.0):
        super(RelGraphConv, self).__init__()
        self.in_feat = in_feat
@@ -81,6 +85,8 @@ class RelGraphConv(gluon.Block):
        self.activation = activation
        self.self_loop = self_loop
        assert low_mem is False, 'MXNet currently does not support low-memory implementation.'

        if regularizer == "basis":
            # add basis weights
            self.weight = self.params.get(
...
@@ -53,6 +53,10 @@ class RelGraphConv(nn.Module):
        Activation function. Default: None
    self_loop : bool, optional
        True to include self loop message. Default: False
    low_mem : bool, optional
        True to use a low-memory implementation of relation message passing.
        Default: False. This trades speed for memory consumption and slows down
        the forward/backward passes; turn it on if you run into out-of-memory
        errors during training or evaluation.
    dropout : float, optional
        Dropout rate. Default: 0.0
    """
@@ -65,6 +69,7 @@ class RelGraphConv(nn.Module):
                 bias=True,
                 activation=None,
                 self_loop=False,
                 low_mem=False,
                 dropout=0.0):
        super(RelGraphConv, self).__init__()
        self.in_feat = in_feat
@@ -77,6 +82,7 @@ class RelGraphConv(nn.Module):
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.low_mem = low_mem
        if regularizer == "basis":
            # add basis weights
@@ -133,7 +139,22 @@ class RelGraphConv(nn.Module):
        else:
            weight = self.weight

        # calculate msg @ W_r before putting msg onto the edges
        # if src is th.int64, it is expected to be an index select
        if edges.src['h'].dtype != th.int64 and self.low_mem:
            etypes = th.unique(edges.data['type'])
            msg = th.empty((edges.src['h'].shape[0], self.out_feat),
                           device=edges.src['h'].device)
            for etype in etypes:
                loc = edges.data['type'] == etype
                w = weight[etype]
                src = edges.src['h'][loc]
                sub_msg = th.matmul(src, w)
                msg[loc] = sub_msg
        else:
            # put W_r onto the edges, then do msg @ W_r
            msg = utils.bmm_maybe_select(edges.src['h'], weight, edges.data['type'])
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
@@ -142,10 +163,24 @@ class RelGraphConv(nn.Module):
        """Message function for block-diagonal-decomposition regularizer"""
        if edges.src['h'].dtype == th.int64 and len(edges.src['h'].shape) == 1:
            raise TypeError('Block decomposition does not allow integer ID feature.')

        # calculate msg @ W_r before putting msg onto the edges
        if self.low_mem:
            etypes = th.unique(edges.data['type'])
            msg = th.empty((edges.src['h'].shape[0], self.out_feat),
                           device=edges.src['h'].device)
            for etype in etypes:
                loc = edges.data['type'] == etype
                w = self.weight[etype].view(self.num_bases, self.submat_in, self.submat_out)
                src = edges.src['h'][loc].view(-1, self.num_bases, self.submat_in)
                sub_msg = th.einsum('abc,bcd->abd', src, w)
                sub_msg = sub_msg.reshape(-1, self.out_feat)
                msg[loc] = sub_msg
        else:
            weight = self.weight.index_select(0, edges.data['type']).view(
                -1, self.submat_in, self.submat_out)
            node = edges.src['h'].view(-1, 1, self.submat_in)
            msg = th.bmm(node, weight).view(-1, self.out_feat)
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
...
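
Both message functions above use the same trick: instead of materializing one weight matrix per edge and running a batched matmul, the low-memory path groups edges by relation type and applies each relation's weight once, so the (num_edges, in_feat, out_feat) weight tensor is never allocated. A standalone sketch of that equivalence in plain PyTorch; the helper names and shapes here are illustrative, not part of the diff.

import torch as th

def per_edge_bmm(src, weight, etypes):
    """Baseline: gather one weight matrix per edge, then a batched matmul."""
    w = weight[etypes]                       # (E, in_feat, out_feat) -- large!
    return th.bmm(src.unsqueeze(1), w).squeeze(1)

def grouped_matmul(src, weight, etypes, out_feat):
    """Low-memory path: one ordinary matmul per relation type."""
    msg = th.empty(src.shape[0], out_feat, device=src.device)
    for etype in th.unique(etypes):
        loc = etypes == etype                # edges carrying this relation
        msg[loc] = src[loc] @ weight[etype]
    return msg

E, in_feat, out_feat, R = 1000, 16, 8, 4
src = th.randn(E, in_feat)
weight = th.randn(R, in_feat, out_feat)
etypes = th.randint(0, R, (E,))

assert th.allclose(per_edge_bmm(src, weight, etypes),
                   grouped_matmul(src, weight, etypes, out_feat), atol=1e-5)

The grouped version does more, smaller matmuls (one per relation present in the batch), which is why the docstring warns that it trades speed for memory.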
@@ -53,6 +53,10 @@ class RelGraphConv(layers.Layer):
        Activation function. Default: None
    self_loop : bool, optional
        True to include self loop message. Default: False
    low_mem : bool, optional
        True to use a low-memory implementation of relation message passing.
        Default: False. This trades speed for memory consumption and slows down
        the forward/backward passes; turn it on if you run into out-of-memory
        errors during training or evaluation.
    dropout : float, optional
        Dropout rate. Default: 0.0
    """
@@ -66,6 +70,7 @@ class RelGraphConv(layers.Layer):
                 bias=True,
                 activation=None,
                 self_loop=False,
                 low_mem=False,
                 dropout=0.0):
        super(RelGraphConv, self).__init__()
        self.in_feat = in_feat
@@ -78,6 +83,7 @@ class RelGraphConv(layers.Layer):
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.low_mem = low_mem
        xinit = tf.keras.initializers.glorot_uniform()
        zeroinit = tf.keras.initializers.zeros()
@@ -134,8 +140,22 @@ class RelGraphConv(layers.Layer):
        else:
            weight = self.weight
        # calculate msg @ W_r before putting msg onto the edges
        # if src is tf.int64, it is expected to be an index select
        if edges.src['h'].dtype != tf.int64 and self.low_mem:
            etypes, _ = tf.unique(edges.data['type'])
            msg = tf.zeros([edges.src['h'].shape[0], self.out_feat])
            idx = tf.range(edges.src['h'].shape[0])
            for etype in etypes:
                loc = (edges.data['type'] == etype)
                w = weight[etype]
                src = tf.boolean_mask(edges.src['h'], loc)
                sub_msg = tf.matmul(src, w)
                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
        else:
            msg = utils.bmm_maybe_select(
                edges.src['h'], weight, edges.data['type'])
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
@@ -146,10 +166,28 @@ class RelGraphConv(layers.Layer):
                len(edges.src['h'].shape) == 1):
            raise TypeError(
                'Block decomposition does not allow integer ID feature.')

        # calculate msg @ W_r before putting msg onto the edges
        if self.low_mem:
            etypes, _ = tf.unique(edges.data['type'])
            msg = tf.zeros([edges.src['h'].shape[0], self.out_feat])
            idx = tf.range(edges.src['h'].shape[0])
            for etype in etypes:
                loc = (edges.data['type'] == etype)
                w = tf.reshape(self.weight[etype],
                               (self.num_bases, self.submat_in, self.submat_out))
                src = tf.reshape(tf.boolean_mask(edges.src['h'], loc),
                                 (-1, self.num_bases, self.submat_in))
                sub_msg = tf.einsum('abc,bcd->abd', src, w)
                sub_msg = tf.reshape(sub_msg, (-1, self.out_feat))
                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
        else:
            weight = tf.reshape(tf.gather(
                self.weight, edges.data['type']), (-1, self.submat_in, self.submat_out))
            node = tf.reshape(edges.src['h'], (-1, 1, self.submat_in))
            msg = tf.reshape(tf.matmul(node, weight), (-1, self.out_feat))
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
...
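
One TensorFlow-specific detail in the branches above: TF tensors are immutable, so the per-relation results cannot be written back with a masked assignment like msg[loc] = sub_msg. Instead, the selected rows are scattered into a zero buffer with tf.tensor_scatter_nd_update. A tiny illustration of that pattern; the values are made up.

import tensorflow as tf

# Equivalent of `msg[loc] = sub_msg` for an immutable tensor:
# scatter the rows selected by a boolean mask back into a zero buffer.
msg = tf.zeros([5, 3])
loc = tf.constant([True, False, True, False, True])
idx = tf.range(5)
sub_msg = tf.ones([3, 3])                       # rows computed for this relation
indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
print(msg.numpy())  # rows 0, 2 and 4 become ones; rows 1 and 3 stay zero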
@@ -373,38 +373,66 @@ def test_rgcn():
    O = 8
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
    rgc_bdd_low.weight = rgc_bdd.weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r)
    h_new_low = rgc_bdd_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
    rgc_bdd_low.weight = rgc_bdd.weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r, norm)
    h_new_low = rgc_bdd_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = th.randint(0, I, (100,)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)


def test_gat_conv():
    ctx = F.ctx()
...
@@ -272,38 +272,66 @@ def test_rgcn():
    O = 8
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
    rgc_bdd_low.weight = rgc_bdd.weight
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_bdd(g, h, r)
    h_new_low = rgc_bdd_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # with norm
    norm = tf.zeros((g.number_of_edges(), 1))
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
    rgc_bdd_low.weight = rgc_bdd.weight
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_bdd(g, h, r, norm)
    h_new_low = rgc_bdd_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = tf.constant(np.random.randint(0, I, (100,)))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)


def test_gat_conv():
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
...