Unverified Commit 168794cd authored by Lingfan Yu's avatar Lingfan Yu Committed by GitHub
Browse files

remove text using code line numbers (#217)

parent 7156c716
...@@ -125,6 +125,7 @@ base. This tutorial shows how to implement R-GCN with DGL. ...@@ -125,6 +125,7 @@ base. This tutorial shows how to implement R-GCN with DGL.
# Each relation type is associated with a different weight. Therefore, # Each relation type is associated with a different weight. Therefore,
# the full weight matrix has three dimensions: relation, input_feature, # the full weight matrix has three dimensions: relation, input_feature,
# output_feature. # output_feature.
#
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -149,10 +150,11 @@ class RGCNLayer(nn.Module): ...@@ -149,10 +150,11 @@ class RGCNLayer(nn.Module):
if self.num_bases <= 0 or self.num_bases > self.num_rels: if self.num_bases <= 0 or self.num_bases > self.num_rels:
self.num_bases = self.num_rels self.num_bases = self.num_rels
# add weights # weight bases in equation (3)
self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat, self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat,
self.out_feat)) self.out_feat))
if self.num_bases < self.num_rels: if self.num_bases < self.num_rels:
# linear combination coefficients in equation (3)
self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases)) self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))
# add bias # add bias
...@@ -171,7 +173,7 @@ class RGCNLayer(nn.Module): ...@@ -171,7 +173,7 @@ class RGCNLayer(nn.Module):
def forward(self, g): def forward(self, g):
if self.num_bases < self.num_rels: if self.num_bases < self.num_rels:
# generate all weights from basis (equation (3)) # generate all weights from bases (equation (3))
weight = self.weight.view(self.in_feat, self.num_bases, self.out_feat) weight = self.weight.view(self.in_feat, self.num_bases, self.out_feat)
weight = torch.matmul(self.w_comp, weight).view(self.num_rels, weight = torch.matmul(self.w_comp, weight).view(self.num_rels,
self.in_feat, self.out_feat) self.in_feat, self.out_feat)
...@@ -204,19 +206,6 @@ class RGCNLayer(nn.Module): ...@@ -204,19 +206,6 @@ class RGCNLayer(nn.Module):
############################################################################### ###############################################################################
# As mentioned above, R-GCN uses decomposition of reduce parameter size
# (equation (3)). So line 18-19 defines the weight bases (:math:`V_b^{(l)}`),
# and line 20-21 defines the linear combination coefficients
# (:math:`a_{rb}^{(l)}`). The forward function of R-GCN layer is similar to
# GCN, except that at the beginning of forward phase, weights for each
# relation type is generated from bases (line 41-44).
#
# The message function for R-GCN replicates weights onto edges and then
# generates messages (line 55-59). But for the first layer, since the node
# feature is the node id, the transformation from node feature to messages
# can be computed more efficiently by performing an embedding lookup (line
# 49-53).
#
# Define full R-GCN model # Define full R-GCN model
# ~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment