Unverified commit ec2e24be authored by Quan (Andy) Gan, committed by GitHub

[Feature] Slice dstnode features in NN modules (#1838)

* slice dstdata from srcdata within nn module

* a bunch of fixes

* add comment

* fix gcmc layer

* repr for blocks

* fix

* fix context

* fix

* do not copy internal columns

* docstring
parent 0e896a92
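The change in one sketch (not part of the diff): DGL NN modules now slice destination-node features out of the source features whenever they are handed a block, so callers no longer pass an explicit (src, dst) pair. A minimal sketch assuming the 0.5-era minibatch API; the sample graph and feature sizes below are made up for illustration.

import torch
import dgl
import dgl.nn.pytorch as dglnn

# Toy graph: edges 0->4, 1->4, 2->5, 3->5; we want outputs for nodes 4 and 5.
g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([4, 4, 5, 5])))
block = dgl.to_block(g, dst_nodes=torch.tensor([4, 5]))
feat = torch.randn(block.number_of_src_nodes(), 16)
conv = dglnn.SAGEConv(16, 8, 'mean')

# Before this commit: the caller sliced dst features itself.
h_dst = feat[:block.number_of_dst_nodes()]
out_old = conv(block, (feat, h_dst))

# After this commit: the module slices internally when it sees a block.
out_new = conv(block, feat)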
......@@ -26,10 +26,6 @@ def get_graph(network_data, vocab):
keys describing the edge types, values representing edges
vocab: a dict
mapping node IDs to node indices
<<<<<<< HEAD
=======
>>>>>>> c334b40e1f8a30bd5619814f34a469b18774fba7
Output
------
DGLHeteroGraph
......
......@@ -69,6 +69,8 @@ class GCMCGraphConv(nn.Module):
The output feature
"""
with graph.local_scope():
if isinstance(feat, tuple):
feat, _ = feat # dst feature not used
cj = graph.srcdata['cj']
ci = graph.dstdata['ci']
if self.device is not None:
......
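GCMCGraphConv never reads destination features, so its fix is simply to drop the second element of the pair. The idiom in isolation (a hypothetical standalone function, not the real layer):

import torch

def src_only(feat):
    # Accept either a tensor or a (src, dst) pair; keep only the src part,
    # since this layer never uses destination features.
    if isinstance(feat, tuple):
        feat, _ = feat  # dst feature not used
    return feat

x = torch.ones(4, 8)
assert src_only(x) is x
assert src_only((x, x[:2])) is x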
......@@ -41,14 +41,7 @@ class SAGE(nn.Module):
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
# We need to first copy the representation of nodes on the RHS from the
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
h = layer(block, h)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
......@@ -85,8 +78,7 @@ class SAGE(nn.Module):
block = blocks[0]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()]
h = layer(block, (h, h_dst))
h = layer(block, h)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
......
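The deleted comments documented exactly the logic the layer now performs on the caller's behalf. A hedged sketch of the homogeneous-tensor case, reconstructed from this diff rather than copied from DGL source:

def split_src_dst(feat, graph):
    if isinstance(feat, tuple):            # caller already split the pair
        return feat
    if graph.is_block:                     # dst nodes lead the src ordering
        return feat, feat[:graph.number_of_dst_nodes()]
    return feat, feat                      # full-graph training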
......@@ -41,14 +41,7 @@ class SAGE(nn.Module):
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
# We need to first copy the representation of nodes on the RHS from the
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
h = layer(block, h)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
......@@ -86,8 +79,7 @@ class SAGE(nn.Module):
block = blocks[0]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()]
h = layer(block, (h, h_dst))
h = layer(block, h)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
......
......@@ -121,14 +121,7 @@ class SAGE(nn.Module):
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
# We need to first copy the representation of nodes on the RHS from the
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
h = layer(block, h)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
......@@ -159,8 +152,7 @@ class SAGE(nn.Module):
input_nodes = block.srcdata[dgl.NID]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()]
h = layer(block, (h, h_dst))
h = layer(block, h)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
......
......@@ -5,7 +5,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn.pytorch as dglnn
import dgl.nn as dglnn
import tqdm
class RelGraphConvLayer(nn.Module):
......@@ -101,15 +101,9 @@ class RelGraphConvLayer(nn.Module):
for i, w in enumerate(th.split(weight, 1, dim=0))}
else:
wdict = {}
inputs_src, inputs_dst = dglnn.expand_as_pair(inputs, g)
hs = self.conv(g, inputs, mod_kwargs=wdict)
if isinstance(inputs, tuple):
# minibatch training
inputs_dst = inputs[1]
else:
# full graph training
inputs_dst = inputs
def _apply(ntype, h):
if self.self_loop:
h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
......@@ -211,8 +205,7 @@ class EntityClassify(nn.Module):
else:
# minibatch training
for layer, block in zip(self.layers, blocks):
h_dst = {k: v[:block.number_of_dst_nodes(k)] for k, v in h.items()}
h = layer(block, (h, h_dst))
h = layer(block, h)
return h
def inference(self, g, batch_size, device, num_workers, x=None):
......@@ -247,8 +240,7 @@ class EntityClassify(nn.Module):
block = blocks[0]
h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
h_dst = {k: v[:block.number_of_dst_nodes(k)] for k, v in h.items()}
h = layer(block, (h, h_dst))
h = layer(block, h)
for k in h.keys():
y[k][output_nodes[k]] = h[k].cpu()
......
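The RelGraphConvLayer hunk above centralizes the same branching into dglnn.expand_as_pair(inputs, g), which must also cover heterogeneous inputs: a dict of per-node-type tensors. A hedged sketch extending the one after the first GraphSAGE hunk, mirroring the deleted per-type slicing:

def split_src_dst_hetero(inputs, g):
    if isinstance(inputs, tuple):
        return inputs
    if g.is_block:
        inputs_dst = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
        return inputs, inputs_dst
    return inputs, inputs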
......@@ -16,6 +16,12 @@ NID = '_ID'
ETYPE = '_TYPE'
EID = '_ID'
_INTERNAL_COLUMNS = {NTYPE, NID, ETYPE, EID}
def is_internal_column(name):
"""Return true if the column name is reversed by DGL."""
return name in _INTERNAL_COLUMNS
def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges."""
return isinstance(arg, str) and arg == ALL
......
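This helper backs the "do not copy internal columns" fix in the commit message: feature-copying code can now skip DGL's reserved _ID/_TYPE columns. A self-contained illustration (the data dict is made up):

_INTERNAL_COLUMNS = {'_TYPE', '_ID'}

def is_internal_column(name):
    return name in _INTERNAL_COLUMNS

# Skip reserved columns when copying user features between frames.
node_data = {'_ID': [0, 1, 2], 'feat': [0.1, 0.2, 0.3]}
user_data = {k: v for k, v in node_data.items() if not is_internal_column(k)}
assert list(user_data) == ['feat']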
......@@ -39,6 +39,9 @@ class DGLBaseGraph(object):
graph : graph index, optional
Data to initialize graph.
"""
is_block = False # for compatibility with DGLHeteroGraph
def __init__(self, graph):
self._graph = graph
......
......@@ -191,6 +191,8 @@ class DGLHeteroGraph(object):
Otherwise, ``edge_frames[i]`` stores the edge features
of edge type i. (default: None)
"""
is_block = False
# pylint: disable=unused-argument
def __init__(self,
gidx,
......@@ -280,7 +282,11 @@ class DGLHeteroGraph(object):
self._msg_frames.append(frame)
def __getstate__(self):
return self._graph, self._ntypes, self._etypes, self._node_frames, self._edge_frames
if self.is_block:
ntypes = (self.srctypes, self.dsttypes)
else:
ntypes = self._ntypes
return self._graph, ntypes, self._etypes, self._node_frames, self._edge_frames
def __setstate__(self, state):
# Compatibility check
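A hedged check of the round trip this __getstate__ change enables: a pickled block keeps its block-ness and its separate src/dst node types (the graph and assertions are illustrative, not from this repo's tests):

import pickle
import torch
import dgl

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([2, 2])))
block = dgl.to_block(g, dst_nodes=torch.tensor([2]))
restored = pickle.loads(pickle.dumps(block))
assert restored.is_block
assert restored.number_of_dst_nodes() == block.number_of_dst_nodes()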
......@@ -323,7 +329,7 @@ class DGLHeteroGraph(object):
for i in range(len(self.ntypes))}
nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i)
for i in range(len(self.etypes))}
meta = str(self.metagraph.edges())
meta = str(self.metagraph.edges(keys=True))
return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
#################################################################
......@@ -1095,7 +1101,7 @@ class DGLHeteroGraph(object):
new_etypes = [etype]
new_eframes = [self._edge_frames[etid]]
return DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
return self.__class__(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
else:
flat = self._graph.flatten_relations(etypes)
new_g = flat.graph
......@@ -1117,7 +1123,7 @@ class DGLHeteroGraph(object):
new_eframes = [combine_frames(self._edge_frames, etids)]
# create new heterograph
new_hg = DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
new_hg = self.__class__(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
src = new_ntypes[0]
dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
......@@ -2040,7 +2046,7 @@ class DGLHeteroGraph(object):
num_rows=len(induced_edges_of_etype)))
for i, induced_edges_of_etype in enumerate(induced_edges)]
hsg = DGLHeteroGraph(sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames)
hsg = self.__class__(sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames)
hsg.is_subgraph = True
for ntype, induced_nid in zip(self.ntypes, induced_nodes):
hsg.nodes[ntype].data[NID] = induced_nid.tousertensor()
......@@ -2360,7 +2366,7 @@ class DGLHeteroGraph(object):
# num_nodes_per_type doesn't need to be int32
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_type, "int64"))
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes,
hg = self.__class__(hgidx, ntypes, induced_etypes,
node_frames, edge_frames)
return hg
......@@ -2437,7 +2443,7 @@ class DGLHeteroGraph(object):
# num_nodes_per_type should be int64
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type, "int64"))
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
hg = self.__class__(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
return hg
def adjacency_matrix(self, transpose=None, ctx=F.cpu(), scipy_fmt=None, etype=None):
......@@ -4324,7 +4330,7 @@ class DGLHeteroGraph(object):
# TODO(minjie): replace the following line with the commented one to enable GPU graph.
new_gidx = self._graph
#new_gidx = self._graph.copy_to(utils.to_dgl_context(ctx))
return DGLHeteroGraph(new_gidx, self.ntypes, self.etypes,
return self.__class__(new_gidx, self.ntypes, self.etypes,
new_nframes, new_eframes)
def local_var(self):
......@@ -4647,7 +4653,7 @@ class DGLHeteroGraph(object):
restrict_format
request_format
"""
return DGLHeteroGraph(self._graph.to_format(restrict_format), self.ntypes, self.etypes,
return self.__class__(self._graph.to_format(restrict_format), self.ntypes, self.etypes,
self._node_frames,
self._edge_frames)
......@@ -4672,7 +4678,7 @@ class DGLHeteroGraph(object):
int
idtype
"""
return DGLHeteroGraph(self._graph.asbits(64), self.ntypes, self.etypes,
return self.__class__(self._graph.asbits(64), self.ntypes, self.etypes,
self._node_frames,
self._edge_frames)
......@@ -4697,7 +4703,7 @@ class DGLHeteroGraph(object):
long
idtype
"""
return DGLHeteroGraph(self._graph.asbits(32), self.ntypes, self.etypes,
return self.__class__(self._graph.asbits(32), self.ntypes, self.etypes,
self._node_frames,
self._edge_frames)
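Every DGLHeteroGraph(...) -> self.__class__(...) replacement above serves one purpose: a transformation applied to a DGLBlock should return a DGLBlock, not a plain DGLHeteroGraph. The pattern in miniature (illustrative plain-Python classes, not the real DGL types):

class Graph:
    def converted(self):
        return self.__class__()     # runtime type, never hard-coded Graph()

class Block(Graph):
    is_block = True

assert isinstance(Block().converted(), Block)            # subclass survives
assert not getattr(Graph().converted(), 'is_block', False)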
......@@ -5040,4 +5046,34 @@ def check_idtype_dict(graph_dtype, tensor_dict):
for _, v in tensor_dict.items():
check_same_dtype(graph_dtype, v)
class DGLBlock(DGLHeteroGraph):
"""Subclass that signifies the graph is a block created from
:func:`dgl.to_block`.
"""
# (BarclayII) I'm making a subclass because I don't want to make another version of
# serialization that contains the is_block flag.
is_block = True
def __repr__(self):
if len(self.srctypes) == 1 and len(self.dsttypes) == 1 and len(self.etypes) == 1:
ret = 'Block(num_src_nodes={srcnode}, num_dst_nodes={dstnode}, num_edges={edge})'
return ret.format(
srcnode=self.number_of_src_nodes(),
dstnode=self.number_of_dst_nodes(),
edge=self.number_of_edges())
else:
ret = ('Block(num_src_nodes={srcnode},\n'
' num_dst_nodes={dstnode},\n'
' num_edges={edge},\n'
' metagraph={meta})')
nsrcnode_dict = {ntype : self.number_of_src_nodes(ntype)
for ntype in self.srctypes}
ndstnode_dict = {ntype : self.number_of_dst_nodes(ntype)
for ntype in self.dsttypes}
nedge_dict = {etype : self.number_of_edges(etype)
for etype in self.canonical_etypes}
meta = str(self.metagraph.edges(keys=True))
return ret.format(
srcnode=nsrcnode_dict, dstnode=ndstnode_dict, edge=nedge_dict, meta=meta)
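What the new repr prints for a single-type block (the printed line below is reconstructed from the format string above, not captured output):

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([2, 2])))
print(dgl.to_block(g, dst_nodes=torch.tensor([2])))
# Block(num_src_nodes=3, num_dst_nodes=1, num_edges=2)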
_init_api("dgl.heterograph")
......@@ -3,6 +3,7 @@ import importlib
import sys
import os
from ..backend import backend_name
from ..utils import expand_as_pair
def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__)
......
......@@ -60,10 +60,10 @@ class AGNNConv(nn.Block):
should be the same as input shape.
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.srcdata['norm_h'] = normalize(feat_src, p=2, axis=-1)
if isinstance(feat, tuple):
if isinstance(feat, tuple) or graph.is_block:
graph.dstdata['norm_h'] = normalize(feat_dst, p=2, axis=-1)
# compute cosine distance
graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
......
......@@ -71,7 +71,7 @@ class EdgeConv(nn.Block):
New node features.
"""
with g.local_scope():
h_src, h_dst = expand_as_pair(h)
h_src, h_dst = expand_as_pair(h, g)
g.srcdata['x'] = h_src
g.dstdata['x'] = h_dst
if not self.batch_norm:
......
......@@ -129,6 +129,8 @@ class GATConv(nn.Block):
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).reshape(
-1, self._num_heads, self._out_feats)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
......
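The slice in this hunk, and the identical ones in SAGEConv and DotGatConv below, lean on an ordering guarantee of dgl.to_block: destination nodes come first among the block's source nodes, so feat_src[:num_dst] is exactly the dst features. A hedged check of that invariant on a toy graph:

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
block = dgl.to_block(g, dst_nodes=torch.tensor([1, 2, 0]))
num_dst = block.number_of_dst_nodes()
assert torch.equal(block.srcdata[dgl.NID][:num_dst], block.dstdata[dgl.NID])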
......@@ -75,7 +75,7 @@ class GINConv(nn.Block):
as input dimensionality.
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
rst = (1 + self.eps.data(feat_dst.context)) * feat_dst + graph.dstdata['neigh']
......
......@@ -115,7 +115,7 @@ class GMMConv(nn.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size.
"""
feat_src, feat_dst = expand_as_pair(feat)
feat_src, feat_dst = expand_as_pair(feat, graph)
with graph.local_scope():
graph.srcdata['h'] = self.fc(feat_src).reshape(
-1, self._n_kernels, self._out_feats)
......
......@@ -128,7 +128,7 @@ class GraphConv(gluon.Block):
The output feature
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
feat_src, feat_dst = expand_as_pair(feat, graph)
if self._norm == 'both':
degs = graph.out_degrees().as_in_context(feat_src.context).astype('float32')
......
......@@ -100,7 +100,7 @@ class NNConv(nn.Block):
is the output feature size.
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
feat_src, feat_dst = expand_as_pair(feat, graph)
# (n, d_in, 1)
graph.srcdata['h'] = feat_src.expand_dims(-1)
......
......@@ -104,6 +104,8 @@ class SAGEConv(nn.Block):
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_self = feat_dst
......
......@@ -59,10 +59,10 @@ class AGNNConv(nn.Module):
should be the same as input shape.
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.srcdata['norm_h'] = F.normalize(feat_src, p=2, dim=-1)
if isinstance(feat, tuple):
if isinstance(feat, tuple) or graph.is_block:
graph.dstdata['norm_h'] = F.normalize(feat_dst, p=2, dim=-1)
# compute cosine distance
graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
......
......@@ -70,6 +70,8 @@ class DotGatConv(nn.Module):
else:
h_src = feat
feat_src = feat_dst = self.fc(h_src)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
# Assign features to nodes
graph.srcdata.update({'ft': feat_src})
......