Unverified Commit 2f43cdb3 authored by WangYQ's avatar WangYQ Committed by GitHub
Browse files

deal with situation where num_layers equals 1 (#3066)

parent 9664cdff
......@@ -17,10 +17,13 @@ class SAGE(nn.Module):
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
if n_layers > 1:
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
else:
self.layers.append(dglnn.SAGEConv(in_feats, n_classes, 'mean'))
self.dropout = nn.Dropout(dropout)
self.activation = activation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment