Unverified commit 6752bd45 authored by Sai Kandregula, committed by GitHub

[Example] Fix number of layers in GAT example (#3872)



* gat layers fix

* minor bug

Co-authored-by: decoherencer <decoherencer@users.noreply.github.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent f758db38
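
What the change does: before this commit, --num-layers counted only the hidden GAT layers while the constructor silently appended an output projection, so a setting of N actually built N+1 GATConv layers and expected a heads list of length N+1. After the fix, num_layers counts every GATConv layer, the heads list carries exactly one entry per layer, and the degenerate single-layer case gets an explicit else branch. A minimal sketch of the new invariant (illustrative only, not part of the commit; values mirror the updated defaults in the hunks below):

    num_layers, num_heads, num_out_heads = 2, 8, 1             # updated defaults
    heads = [num_heads] * (num_layers - 1) + [num_out_heads]   # -> [8, 1]
    assert len(heads) == num_layers   # one head count per GATConv layer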
@@ -31,20 +31,25 @@ class GAT(nn.Module):
         self.num_layers = num_layers
         self.gat_layers = nn.ModuleList()
         self.activation = activation
-        # input projection (no residual)
-        self.gat_layers.append(GATConv(
-            in_dim, num_hidden, heads[0],
-            feat_drop, attn_drop, negative_slope, False, self.activation))
-        # hidden layers
-        for l in range(1, num_layers):
-            # due to multi-head, the in_dim = num_hidden * num_heads
-            self.gat_layers.append(GATConv(
-                num_hidden * heads[l-1], num_hidden, heads[l],
-                feat_drop, attn_drop, negative_slope, residual, self.activation))
-        # output projection
-        self.gat_layers.append(GATConv(
-            num_hidden * heads[-2], num_classes, heads[-1],
-            feat_drop, attn_drop, negative_slope, residual, None))
+        if num_layers > 1:
+            # input projection (no residual)
+            self.gat_layers.append(GATConv(
+                in_dim, num_hidden, heads[0],
+                feat_drop, attn_drop, negative_slope, False, self.activation))
+            # hidden layers
+            for l in range(1, num_layers-1):
+                # due to multi-head, the in_dim = num_hidden * num_heads
+                self.gat_layers.append(GATConv(
+                    num_hidden * heads[l-1], num_hidden, heads[l],
+                    feat_drop, attn_drop, negative_slope, residual, self.activation))
+            # output projection
+            self.gat_layers.append(GATConv(
+                num_hidden * heads[-2], num_classes, heads[-1],
+                feat_drop, attn_drop, negative_slope, residual, None))
+        else:
+            self.gat_layers.append(GATConv(
+                in_dim, num_classes, heads[0],
+                feat_drop, attn_drop, negative_slope, residual, None))
 
     def forward(self, inputs):
         h = inputs
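To see what the reworked constructor builds, here is a stub-based sketch (illustrative, not part of the commit) that stands in an (in_dim, out_dim, num_heads) tuple for each GATConv, so the layer count and widths can be checked without DGL installed:

    def layer_shapes(num_layers, in_dim, num_hidden, num_classes, heads):
        shapes = []
        if num_layers > 1:
            # input projection
            shapes.append((in_dim, num_hidden, heads[0]))
            # hidden layers: multi-head outputs are concatenated, so layer l
            # consumes num_hidden * heads[l-1] features
            for l in range(1, num_layers - 1):
                shapes.append((num_hidden * heads[l - 1], num_hidden, heads[l]))
            # output projection
            shapes.append((num_hidden * heads[-2], num_classes, heads[-1]))
        else:
            # new else-branch: a single layer straight from features to classes
            shapes.append((in_dim, num_classes, heads[0]))
        return shapes

    # Feature/class sizes are Cora-like placeholders, not taken from the diff.
    print(layer_shapes(2, 1433, 8, 7, [8, 1]))   # [(1433, 8, 8), (64, 7, 1)]
    print(layer_shapes(1, 1433, 8, 7, [1]))      # [(1433, 7, 1)]

In both cases the list length equals num_layers, which is exactly what the heads arithmetic in the training scripts below now assumes.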
@@ -79,7 +79,7 @@ def main(args):
     g = dgl.add_self_loop(g)
     n_edges = g.number_of_edges()
     # create model
-    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
+    heads = ([args.num_heads] * (args.num_layers-1)) + [args.num_out_heads]
     model = GAT(g,
                 args.num_layers,
                 num_feats,
@@ -153,7 +153,7 @@ if __name__ == '__main__':
                         help="number of hidden attention heads")
     parser.add_argument("--num-out-heads", type=int, default=1,
                         help="number of output attention heads")
-    parser.add_argument("--num-layers", type=int, default=1,
+    parser.add_argument("--num-layers", type=int, default=2,
                         help="number of hidden layers")
     parser.add_argument("--num-hidden", type=int, default=8,
                         help="number of hidden units")
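The default bump from 1 to 2 is bookkeeping, not an architecture change: the old default built one hidden layer plus the implicit output projection, i.e. the same two GATConv layers the new default names directly. A quick check with this file's head defaults (num_heads=8, num_out_heads=1):

    old_heads = [8] * 1 + [1]        # old formula, old default num_layers=1
    new_heads = [8] * (2 - 1) + [1]  # new formula, new default num_layers=2
    assert old_heads == new_heads    # identical two-layer stack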
@@ -57,7 +57,7 @@ def main(args):
     n_classes = train_dataset.num_labels
     num_feats = g.ndata['feat'].shape[1]
     g = g.int().to(device)
-    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
+    heads = ([args.num_heads] * (args.num_layers-1)) + [args.num_out_heads]
     # define the model
     model = GAT(g,
                 args.num_layers,
@@ -129,7 +129,7 @@ if __name__ == '__main__':
                         help="number of hidden attention heads")
     parser.add_argument("--num-out-heads", type=int, default=6,
                         help="number of output attention heads")
-    parser.add_argument("--num-layers", type=int, default=2,
+    parser.add_argument("--num-layers", type=int, default=3,
                         help="number of hidden layers")
     parser.add_argument("--num-hidden", type=int, default=256,
                         help="number of hidden units")
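The same equivalence holds for the PPI script's bump from 2 to 3. The --num-heads default lies outside this hunk, so the value below is a placeholder:

    num_heads = 4                            # placeholder, not from the diff
    old_heads = [num_heads] * 2 + [6]        # old formula, old default num_layers=2
    new_heads = [num_heads] * (3 - 1) + [6]  # new formula, new default num_layers=3
    assert old_heads == new_heads            # identical three-layer stack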