"src/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "9f68d14dc0f977489cfb22c202eec6d45870ff0c"
Unverified Commit ec2e24be authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Slice dstnode features in NN modules (#1838)

* slice dstdata from srcdata within nn module

* a bunch of fixes

* add comment

* fix gcmc layer

* repr for blocks

* fix

* fix context

* fix

* do not copy internal columns

* docstring
parent 0e896a92
...@@ -26,10 +26,6 @@ def get_graph(network_data, vocab): ...@@ -26,10 +26,6 @@ def get_graph(network_data, vocab):
keys describing the edge types, values representing edges keys describing the edge types, values representing edges
vocab: a dict vocab: a dict
mapping node IDs to node indices mapping node IDs to node indices
<<<<<<< HEAD
=======
>>>>>>> c334b40e1f8a30bd5619814f34a469b18774fba7
Output Output
------ ------
DGLHeteroGraph DGLHeteroGraph
......
...@@ -69,6 +69,8 @@ class GCMCGraphConv(nn.Module): ...@@ -69,6 +69,8 @@ class GCMCGraphConv(nn.Module):
The output feature The output feature
""" """
with graph.local_scope(): with graph.local_scope():
if isinstance(feat, tuple):
feat, _ = feat # dst feature not used
cj = graph.srcdata['cj'] cj = graph.srcdata['cj']
ci = graph.dstdata['ci'] ci = graph.dstdata['ci']
if self.device is not None: if self.device is not None:
......
...@@ -41,14 +41,7 @@ class SAGE(nn.Module): ...@@ -41,14 +41,7 @@ class SAGE(nn.Module):
def forward(self, blocks, x): def forward(self, blocks, x):
h = x h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)): for l, (layer, block) in enumerate(zip(self.layers, blocks)):
# We need to first copy the representation of nodes on the RHS from the h = layer(block, h)
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
h = self.dropout(h) h = self.dropout(h)
...@@ -85,8 +78,7 @@ class SAGE(nn.Module): ...@@ -85,8 +78,7 @@ class SAGE(nn.Module):
block = blocks[0] block = blocks[0]
h = x[input_nodes].to(device) h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()] h = layer(block, h)
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
h = self.dropout(h) h = self.dropout(h)
......
...@@ -41,14 +41,7 @@ class SAGE(nn.Module): ...@@ -41,14 +41,7 @@ class SAGE(nn.Module):
def forward(self, blocks, x): def forward(self, blocks, x):
h = x h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)): for l, (layer, block) in enumerate(zip(self.layers, blocks)):
# We need to first copy the representation of nodes on the RHS from the h = layer(block, h)
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
h = self.dropout(h) h = self.dropout(h)
...@@ -86,8 +79,7 @@ class SAGE(nn.Module): ...@@ -86,8 +79,7 @@ class SAGE(nn.Module):
block = blocks[0] block = blocks[0]
h = x[input_nodes].to(device) h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()] h = layer(block, h)
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
h = self.dropout(h) h = self.dropout(h)
......
...@@ -121,14 +121,7 @@ class SAGE(nn.Module): ...@@ -121,14 +121,7 @@ class SAGE(nn.Module):
def forward(self, blocks, x): def forward(self, blocks, x):
h = x h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)): for l, (layer, block) in enumerate(zip(self.layers, blocks)):
# We need to first copy the representation of nodes on the RHS from the h = layer(block, h)
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
h = self.dropout(h) h = self.dropout(h)
...@@ -159,8 +152,7 @@ class SAGE(nn.Module): ...@@ -159,8 +152,7 @@ class SAGE(nn.Module):
input_nodes = block.srcdata[dgl.NID] input_nodes = block.srcdata[dgl.NID]
h = x[input_nodes].to(device) h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()] h = layer(block, h)
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
h = self.dropout(h) h = self.dropout(h)
......
...@@ -5,7 +5,7 @@ import torch as th ...@@ -5,7 +5,7 @@ import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl import dgl
import dgl.nn.pytorch as dglnn import dgl.nn as dglnn
import tqdm import tqdm
class RelGraphConvLayer(nn.Module): class RelGraphConvLayer(nn.Module):
...@@ -101,15 +101,9 @@ class RelGraphConvLayer(nn.Module): ...@@ -101,15 +101,9 @@ class RelGraphConvLayer(nn.Module):
for i, w in enumerate(th.split(weight, 1, dim=0))} for i, w in enumerate(th.split(weight, 1, dim=0))}
else: else:
wdict = {} wdict = {}
inputs_src, inputs_dst = dglnn.expand_as_pair(inputs, g)
hs = self.conv(g, inputs, mod_kwargs=wdict) hs = self.conv(g, inputs, mod_kwargs=wdict)
if isinstance(inputs, tuple):
# minibatch training
inputs_dst = inputs[1]
else:
# full graph training
inputs_dst = inputs
def _apply(ntype, h): def _apply(ntype, h):
if self.self_loop: if self.self_loop:
h = h + th.matmul(inputs_dst[ntype], self.loop_weight) h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
...@@ -211,8 +205,7 @@ class EntityClassify(nn.Module): ...@@ -211,8 +205,7 @@ class EntityClassify(nn.Module):
else: else:
# minibatch training # minibatch training
for layer, block in zip(self.layers, blocks): for layer, block in zip(self.layers, blocks):
h_dst = {k: v[:block.number_of_dst_nodes(k)] for k, v in h.items()} h = layer(block, h)
h = layer(block, (h, h_dst))
return h return h
def inference(self, g, batch_size, device, num_workers, x=None): def inference(self, g, batch_size, device, num_workers, x=None):
...@@ -247,8 +240,7 @@ class EntityClassify(nn.Module): ...@@ -247,8 +240,7 @@ class EntityClassify(nn.Module):
block = blocks[0] block = blocks[0]
h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()} h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
h_dst = {k: v[:block.number_of_dst_nodes(k)] for k, v in h.items()} h = layer(block, h)
h = layer(block, (h, h_dst))
for k in h.keys(): for k in h.keys():
y[k][output_nodes[k]] = h[k].cpu() y[k][output_nodes[k]] = h[k].cpu()
......
...@@ -16,6 +16,12 @@ NID = '_ID' ...@@ -16,6 +16,12 @@ NID = '_ID'
ETYPE = '_TYPE' ETYPE = '_TYPE'
EID = '_ID' EID = '_ID'
_INTERNAL_COLUMNS = {NTYPE, NID, ETYPE, EID}
def is_internal_column(name):
"""Return true if the column name is reversed by DGL."""
return name in _INTERNAL_COLUMNS
def is_all(arg): def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges.""" """Return true if the argument is a special symbol for all nodes or edges."""
return isinstance(arg, str) and arg == ALL return isinstance(arg, str) and arg == ALL
......
...@@ -39,6 +39,9 @@ class DGLBaseGraph(object): ...@@ -39,6 +39,9 @@ class DGLBaseGraph(object):
graph : graph index, optional graph : graph index, optional
Data to initialize graph. Data to initialize graph.
""" """
is_block = False # for compatibility with DGLHeteroGraph
def __init__(self, graph): def __init__(self, graph):
self._graph = graph self._graph = graph
......
...@@ -191,6 +191,8 @@ class DGLHeteroGraph(object): ...@@ -191,6 +191,8 @@ class DGLHeteroGraph(object):
Otherwise, ``edge_frames[i]`` stores the edge features Otherwise, ``edge_frames[i]`` stores the edge features
of edge type i. (default: None) of edge type i. (default: None)
""" """
is_block = False
# pylint: disable=unused-argument # pylint: disable=unused-argument
def __init__(self, def __init__(self,
gidx, gidx,
...@@ -280,7 +282,11 @@ class DGLHeteroGraph(object): ...@@ -280,7 +282,11 @@ class DGLHeteroGraph(object):
self._msg_frames.append(frame) self._msg_frames.append(frame)
def __getstate__(self): def __getstate__(self):
return self._graph, self._ntypes, self._etypes, self._node_frames, self._edge_frames if self.is_block:
ntypes = (self.srctypes, self.dsttypes)
else:
ntypes = self._ntypes
return self._graph, ntypes, self._etypes, self._node_frames, self._edge_frames
def __setstate__(self, state): def __setstate__(self, state):
# Compatibility check # Compatibility check
...@@ -323,7 +329,7 @@ class DGLHeteroGraph(object): ...@@ -323,7 +329,7 @@ class DGLHeteroGraph(object):
for i in range(len(self.ntypes))} for i in range(len(self.ntypes))}
nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i) nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i)
for i in range(len(self.etypes))} for i in range(len(self.etypes))}
meta = str(self.metagraph.edges()) meta = str(self.metagraph.edges(keys=True))
return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta) return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
################################################################# #################################################################
...@@ -1095,7 +1101,7 @@ class DGLHeteroGraph(object): ...@@ -1095,7 +1101,7 @@ class DGLHeteroGraph(object):
new_etypes = [etype] new_etypes = [etype]
new_eframes = [self._edge_frames[etid]] new_eframes = [self._edge_frames[etid]]
return DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes) return self.__class__(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
else: else:
flat = self._graph.flatten_relations(etypes) flat = self._graph.flatten_relations(etypes)
new_g = flat.graph new_g = flat.graph
...@@ -1117,7 +1123,7 @@ class DGLHeteroGraph(object): ...@@ -1117,7 +1123,7 @@ class DGLHeteroGraph(object):
new_eframes = [combine_frames(self._edge_frames, etids)] new_eframes = [combine_frames(self._edge_frames, etids)]
# create new heterograph # create new heterograph
new_hg = DGLHeteroGraph(new_g, new_ntypes, new_etypes, new_nframes, new_eframes) new_hg = self.__class__(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
src = new_ntypes[0] src = new_ntypes[0]
dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
...@@ -2040,7 +2046,7 @@ class DGLHeteroGraph(object): ...@@ -2040,7 +2046,7 @@ class DGLHeteroGraph(object):
num_rows=len(induced_edges_of_etype))) num_rows=len(induced_edges_of_etype)))
for i, induced_edges_of_etype in enumerate(induced_edges)] for i, induced_edges_of_etype in enumerate(induced_edges)]
hsg = DGLHeteroGraph(sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames) hsg = self.__class__(sgi.graph, self._ntypes, self._etypes, node_frames, edge_frames)
hsg.is_subgraph = True hsg.is_subgraph = True
for ntype, induced_nid in zip(self.ntypes, induced_nodes): for ntype, induced_nid in zip(self.ntypes, induced_nodes):
hsg.nodes[ntype].data[NID] = induced_nid.tousertensor() hsg.nodes[ntype].data[NID] = induced_nid.tousertensor()
...@@ -2360,7 +2366,7 @@ class DGLHeteroGraph(object): ...@@ -2360,7 +2366,7 @@ class DGLHeteroGraph(object):
# num_nodes_per_type doesn't need to be int32 # num_nodes_per_type doesn't need to be int32
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_type, "int64")) metagraph, rel_graphs, utils.toindex(num_nodes_per_type, "int64"))
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, hg = self.__class__(hgidx, ntypes, induced_etypes,
node_frames, edge_frames) node_frames, edge_frames)
return hg return hg
...@@ -2437,7 +2443,7 @@ class DGLHeteroGraph(object): ...@@ -2437,7 +2443,7 @@ class DGLHeteroGraph(object):
# num_nodes_per_type should be int64 # num_nodes_per_type should be int64
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type, "int64")) metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type, "int64"))
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames) hg = self.__class__(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
return hg return hg
def adjacency_matrix(self, transpose=None, ctx=F.cpu(), scipy_fmt=None, etype=None): def adjacency_matrix(self, transpose=None, ctx=F.cpu(), scipy_fmt=None, etype=None):
...@@ -4324,7 +4330,7 @@ class DGLHeteroGraph(object): ...@@ -4324,7 +4330,7 @@ class DGLHeteroGraph(object):
# TODO(minjie): replace the following line with the commented one to enable GPU graph. # TODO(minjie): replace the following line with the commented one to enable GPU graph.
new_gidx = self._graph new_gidx = self._graph
#new_gidx = self._graph.copy_to(utils.to_dgl_context(ctx)) #new_gidx = self._graph.copy_to(utils.to_dgl_context(ctx))
return DGLHeteroGraph(new_gidx, self.ntypes, self.etypes, return self.__class__(new_gidx, self.ntypes, self.etypes,
new_nframes, new_eframes) new_nframes, new_eframes)
def local_var(self): def local_var(self):
...@@ -4647,7 +4653,7 @@ class DGLHeteroGraph(object): ...@@ -4647,7 +4653,7 @@ class DGLHeteroGraph(object):
restrict_format restrict_format
request_format request_format
""" """
return DGLHeteroGraph(self._graph.to_format(restrict_format), self.ntypes, self.etypes, return self.__class__(self._graph.to_format(restrict_format), self.ntypes, self.etypes,
self._node_frames, self._node_frames,
self._edge_frames) self._edge_frames)
...@@ -4672,7 +4678,7 @@ class DGLHeteroGraph(object): ...@@ -4672,7 +4678,7 @@ class DGLHeteroGraph(object):
int int
idtype idtype
""" """
return DGLHeteroGraph(self._graph.asbits(64), self.ntypes, self.etypes, return self.__class__(self._graph.asbits(64), self.ntypes, self.etypes,
self._node_frames, self._node_frames,
self._edge_frames) self._edge_frames)
...@@ -4697,7 +4703,7 @@ class DGLHeteroGraph(object): ...@@ -4697,7 +4703,7 @@ class DGLHeteroGraph(object):
long long
idtype idtype
""" """
return DGLHeteroGraph(self._graph.asbits(32), self.ntypes, self.etypes, return self.__class__(self._graph.asbits(32), self.ntypes, self.etypes,
self._node_frames, self._node_frames,
self._edge_frames) self._edge_frames)
...@@ -5040,4 +5046,34 @@ def check_idtype_dict(graph_dtype, tensor_dict): ...@@ -5040,4 +5046,34 @@ def check_idtype_dict(graph_dtype, tensor_dict):
for _, v in tensor_dict.items(): for _, v in tensor_dict.items():
check_same_dtype(graph_dtype, v) check_same_dtype(graph_dtype, v)
class DGLBlock(DGLHeteroGraph):
"""Subclass that signifies the graph is a block created from
:func:`dgl.to_block`.
"""
# (BarclayII) I'm making a subclass because I don't want to make another version of
# serialization that contains the is_block flag.
is_block = True
def __repr__(self):
if len(self.srctypes) == 1 and len(self.dsttypes) == 1 and len(self.etypes) == 1:
ret = 'Block(num_src_nodes={srcnode}, num_dst_nodes={dstnode}, num_edges={edge})'
return ret.format(
srcnode=self.number_of_src_nodes(),
dstnode=self.number_of_dst_nodes(),
edge=self.number_of_edges())
else:
ret = ('Block(num_src_nodes={srcnode},\n'
' num_dst_nodes={dstnode},\n'
' num_edges={edge},\n'
' metagraph={meta})')
nsrcnode_dict = {ntype : self.number_of_src_nodes(ntype)
for ntype in self.srctypes}
ndstnode_dict = {ntype : self.number_of_dst_nodes(ntype)
for ntype in self.dsttypes}
nedge_dict = {etype : self.number_of_edges(etype)
for etype in self.canonical_etypes}
meta = str(self.metagraph.edges(keys=True))
return ret.format(
srcnode=nsrcnode_dict, dstnode=ndstnode_dict, edge=nedge_dict, meta=meta)
_init_api("dgl.heterograph") _init_api("dgl.heterograph")
...@@ -3,6 +3,7 @@ import importlib ...@@ -3,6 +3,7 @@ import importlib
import sys import sys
import os import os
from ..backend import backend_name from ..backend import backend_name
from ..utils import expand_as_pair
def _load_backend(mod_name): def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__) mod = importlib.import_module('.%s' % mod_name, __name__)
......
...@@ -60,10 +60,10 @@ class AGNNConv(nn.Block): ...@@ -60,10 +60,10 @@ class AGNNConv(nn.Block):
should be the same as input shape. should be the same as input shape.
""" """
with graph.local_scope(): with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat) feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.srcdata['norm_h'] = normalize(feat_src, p=2, axis=-1) graph.srcdata['norm_h'] = normalize(feat_src, p=2, axis=-1)
if isinstance(feat, tuple): if isinstance(feat, tuple) or graph.is_block:
graph.dstdata['norm_h'] = normalize(feat_dst, p=2, axis=-1) graph.dstdata['norm_h'] = normalize(feat_dst, p=2, axis=-1)
# compute cosine distance # compute cosine distance
graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos')) graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
......
...@@ -71,7 +71,7 @@ class EdgeConv(nn.Block): ...@@ -71,7 +71,7 @@ class EdgeConv(nn.Block):
New node features. New node features.
""" """
with g.local_scope(): with g.local_scope():
h_src, h_dst = expand_as_pair(h) h_src, h_dst = expand_as_pair(h, g)
g.srcdata['x'] = h_src g.srcdata['x'] = h_src
g.dstdata['x'] = h_dst g.dstdata['x'] = h_dst
if not self.batch_norm: if not self.batch_norm:
......
...@@ -129,6 +129,8 @@ class GATConv(nn.Block): ...@@ -129,6 +129,8 @@ class GATConv(nn.Block):
h_src = h_dst = self.feat_drop(feat) h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).reshape( feat_src = feat_dst = self.fc(h_src).reshape(
-1, self._num_heads, self._out_feats) -1, self._num_heads, self._out_feats)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
......
...@@ -75,7 +75,7 @@ class GINConv(nn.Block): ...@@ -75,7 +75,7 @@ class GINConv(nn.Block):
as input dimensionality. as input dimensionality.
""" """
with graph.local_scope(): with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat) feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
rst = (1 + self.eps.data(feat_dst.context)) * feat_dst + graph.dstdata['neigh'] rst = (1 + self.eps.data(feat_dst.context)) * feat_dst + graph.dstdata['neigh']
......
...@@ -115,7 +115,7 @@ class GMMConv(nn.Block): ...@@ -115,7 +115,7 @@ class GMMConv(nn.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size. is the output feature size.
""" """
feat_src, feat_dst = expand_as_pair(feat) feat_src, feat_dst = expand_as_pair(feat, graph)
with graph.local_scope(): with graph.local_scope():
graph.srcdata['h'] = self.fc(feat_src).reshape( graph.srcdata['h'] = self.fc(feat_src).reshape(
-1, self._n_kernels, self._out_feats) -1, self._n_kernels, self._out_feats)
......
...@@ -128,7 +128,7 @@ class GraphConv(gluon.Block): ...@@ -128,7 +128,7 @@ class GraphConv(gluon.Block):
The output feature The output feature
""" """
with graph.local_scope(): with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat) feat_src, feat_dst = expand_as_pair(feat, graph)
if self._norm == 'both': if self._norm == 'both':
degs = graph.out_degrees().as_in_context(feat_src.context).astype('float32') degs = graph.out_degrees().as_in_context(feat_src.context).astype('float32')
......
...@@ -100,7 +100,7 @@ class NNConv(nn.Block): ...@@ -100,7 +100,7 @@ class NNConv(nn.Block):
is the output feature size. is the output feature size.
""" """
with graph.local_scope(): with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat) feat_src, feat_dst = expand_as_pair(feat, graph)
# (n, d_in, 1) # (n, d_in, 1)
graph.srcdata['h'] = feat_src.expand_dims(-1) graph.srcdata['h'] = feat_src.expand_dims(-1)
......
...@@ -104,6 +104,8 @@ class SAGEConv(nn.Block): ...@@ -104,6 +104,8 @@ class SAGEConv(nn.Block):
feat_dst = self.feat_drop(feat[1]) feat_dst = self.feat_drop(feat[1])
else: else:
feat_src = feat_dst = self.feat_drop(feat) feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_self = feat_dst h_self = feat_dst
......
...@@ -59,10 +59,10 @@ class AGNNConv(nn.Module): ...@@ -59,10 +59,10 @@ class AGNNConv(nn.Module):
should be the same as input shape. should be the same as input shape.
""" """
with graph.local_scope(): with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat) feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.srcdata['norm_h'] = F.normalize(feat_src, p=2, dim=-1) graph.srcdata['norm_h'] = F.normalize(feat_src, p=2, dim=-1)
if isinstance(feat, tuple): if isinstance(feat, tuple) or graph.is_block:
graph.dstdata['norm_h'] = F.normalize(feat_dst, p=2, dim=-1) graph.dstdata['norm_h'] = F.normalize(feat_dst, p=2, dim=-1)
# compute cosine distance # compute cosine distance
graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos')) graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
......
...@@ -70,6 +70,8 @@ class DotGatConv(nn.Module): ...@@ -70,6 +70,8 @@ class DotGatConv(nn.Module):
else: else:
h_src = feat h_src = feat
feat_src = feat_dst = self.fc(h_src) feat_src = feat_dst = self.fc(h_src)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
# Assign features to nodes # Assign features to nodes
graph.srcdata.update({'ft': feat_src}) graph.srcdata.update({'ft': feat_src})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment