Unverified Commit e2a28a6c authored by Kay Liu's avatar Kay Liu Committed by GitHub
Browse files

[Doc] Revised Two Issues in Message Passing Tutorial (#2983)



* [Doc] modify the dimension of weight to Numpy broadcasting rule

* [Doc] modify the user defined reduce function
Co-authored-by: default avatarzhjwy9343 <6593865@qq.com>
parent 849cbec6
......@@ -279,9 +279,9 @@ class Model(nn.Module):
self.conv2 = WeightedSAGEConv(h_feats, num_classes)
def forward(self, g, in_feat):
h = self.conv1(g, in_feat, torch.ones(g.num_edges()).to(g.device))
h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
h = F.relu(h)
h = self.conv2(g, h, torch.ones(g.num_edges()).to(g.device))
h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
return h
model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
......@@ -310,12 +310,12 @@ def u_mul_e_udf(edges):
######################################################################
# You can also write your own reduce function. For example, the following
# is equivalent to the builtin ``fn.sum('m', 'h')`` function that sums up
# is equivalent to the builtin ``fn.mean('m', 'h_N')`` function that averages
# the incoming messages:
#
def sum_udf(nodes):
return {'h': nodes.mailbox['m'].sum(1)}
def mean_udf(nodes):
return {'h_N': nodes.mailbox['m'].mean(1)}
######################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment