Unverified commit f9805ef1, authored by Ereboas and committed by GitHub
Browse files

[Example] SEAL+NGNN for ogbl-citation2 (#4772)



* Use black for formatting

* limit line width to 80 characters.

* Use a backslash instead of directly concatenating

* file structure adjustment.

* file structure adjustment(2)

* codes for citation2

* format slight adjustment

* adjust format in models.py

* now it runs normally for all datasets.
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 4b1fb681
...@@ -33,6 +33,21 @@ python main.py --dataset ogbl-ppa --ngnn_type input --hidden_channels 48 --epoch ...@@ -33,6 +33,21 @@ python main.py --dataset ogbl-ppa --ngnn_type input --hidden_channels 48 --epoch
As training is very costly, we select the best model by evaluating on a subset of the validation edges and using a lower K for Hits@K. Then we run experiments on the full validation and test sets with the best model selected, and get the required metrics.
### ogbl-citation2
#### performance
| | Test MRR | Validation MRR | #Parameters |
|:------------:|:-------------------:|:-----------------:|:------------:|
| SEAL | 0.8767 ± 0.0032 | 0.8757 ± 0.0031 | 260,802 |
| SEAL + NGNN | 0.8891 ± 0.0022 | 0.8879 ± 0.0022 | 1,134,402 |
#### Reproduction of performance
```{.bash}
python main.py --dataset ogbl-citation2 --ngnn_type all --hidden_channels 256 --epochs 15 --lr 2e-05 --batch_size 64 --num_workers 24 --train_percent 8 --val_percent 4 --num_ngnn_layers 2 --use_feature --use_edge_weight --dynamic_train --dynamic_val --dynamic_test --runs 10
```
For all datasets, if you specify `--dynamic_train`, the enclosing subgraphs of the training links will be extracted on the fly instead of being preprocessed and saved to disk. The same applies to `--dynamic_val` and `--dynamic_test`. You can increase `--num_workers` to accelerate the dynamic subgraph extraction process.
You can also specify the `val_percent` and `eval_hits_K` arguments in the above command to adjust the proportion of the validation dataset to use and the K to use for Hits@K.
......
...@@ -39,11 +39,6 @@ class SEALOGBLDataset(Dataset): ...@@ -39,11 +39,6 @@ class SEALOGBLDataset(Dataset):
self.directed = directed self.directed = directed
self.dynamic = dynamic self.dynamic = dynamic
if not self.dynamic:
self.g_list, tensor_dict = self.load_cached()
self.labels = tensor_dict["y"]
return
if "weights" in self.graph.edata: if "weights" in self.graph.edata:
self.edge_weights = self.graph.edata["weights"] self.edge_weights = self.graph.edata["weights"]
else: else:
...@@ -53,11 +48,16 @@ class SEALOGBLDataset(Dataset): ...@@ -53,11 +48,16 @@ class SEALOGBLDataset(Dataset):
else: else:
self.node_features = None self.node_features = None
if not self.dynamic:
self.g_list, tensor_dict = self.load_cached()
self.labels = tensor_dict["y"]
return
pos_edge, neg_edge = get_pos_neg_edges( pos_edge, neg_edge = get_pos_neg_edges(
split, self.split_edge, self.graph, self.percent split, self.split_edge, self.graph, self.percent
) )
self.links = torch.cat([pos_edge, neg_edge], 0).tolist() # [Np + Nn, 2] self.links = torch.cat([pos_edge, neg_edge], 0) # [Np + Nn, 2]
self.labels = [1] * len(pos_edge) + [0] * len(neg_edge) self.labels = np.array([1] * len(pos_edge) + [0] * len(neg_edge))
def __len__(self): def __len__(self):
return len(self.labels) return len(self.labels)
...@@ -69,7 +69,7 @@ class SEALOGBLDataset(Dataset): ...@@ -69,7 +69,7 @@ class SEALOGBLDataset(Dataset):
w = None if "w" not in g.edata else g.eata["w"] w = None if "w" not in g.edata else g.eata["w"]
return g, g.ndata["z"], x, w, y return g, g.ndata["z"], x, w, y
src, dst = self.links[idx] src, dst = self.links[idx][0].item(), self.links[idx][1].item()
y = self.labels[idx] y = self.labels[idx]
subg = k_hop_subgraph( subg = k_hop_subgraph(
src, dst, 1, self.graph, self.ratio_per_hop, self.directed src, dst, 1, self.graph, self.ratio_per_hop, self.directed
...@@ -132,8 +132,8 @@ class SEALOGBLDataset(Dataset): ...@@ -132,8 +132,8 @@ class SEALOGBLDataset(Dataset):
pos_edge, neg_edge = get_pos_neg_edges( pos_edge, neg_edge = get_pos_neg_edges(
self.split, self.split_edge, self.graph, self.percent self.split, self.split_edge, self.graph, self.percent
) )
self.links = torch.cat([pos_edge, neg_edge], 0).tolist() # [Np + Nn, 2] self.links = torch.cat([pos_edge, neg_edge], 0) # [Np + Nn, 2]
self.labels = [1] * len(pos_edge) + [0] * len(neg_edge) self.labels = np.array([1] * len(pos_edge) + [0] * len(neg_edge))
g_list, labels = self.process() g_list, labels = self.process()
save_graphs(path, g_list, labels) save_graphs(path, g_list, labels)
...@@ -285,6 +285,9 @@ if __name__ == "__main__": ...@@ -285,6 +285,9 @@ if __name__ == "__main__":
help="You can set this value from 'none', 'input', 'hidden' or 'all' " \ help="You can set this value from 'none', 'input', 'hidden' or 'all' " \
"to apply NGNN to different GNN layers.", "to apply NGNN to different GNN layers.",
) )
parser.add_argument(
"--num_ngnn_layers", type=int, default=1, choices=[1, 2]
)
# Subgraph extraction settings # Subgraph extraction settings
parser.add_argument("--ratio_per_hop", type=float, default=1.0) parser.add_argument("--ratio_per_hop", type=float, default=1.0)
parser.add_argument( parser.add_argument(
...@@ -310,8 +313,8 @@ if __name__ == "__main__": ...@@ -310,8 +313,8 @@ if __name__ == "__main__":
parser.add_argument("--runs", type=int, default=10) parser.add_argument("--runs", type=int, default=10)
parser.add_argument("--train_percent", type=float, default=1) parser.add_argument("--train_percent", type=float, default=1)
parser.add_argument("--val_percent", type=float, default=1) parser.add_argument("--val_percent", type=float, default=1)
parser.add_argument("--final_val_percent", type=float, default=1) parser.add_argument("--final_val_percent", type=float, default=100)
parser.add_argument("--test_percent", type=float, default=1) parser.add_argument("--test_percent", type=float, default=100)
parser.add_argument("--no_test", action="store_true") parser.add_argument("--no_test", action="store_true")
parser.add_argument( parser.add_argument(
"--dynamic_train", "--dynamic_train",
...@@ -504,13 +507,12 @@ if __name__ == "__main__": ...@@ -504,13 +507,12 @@ if __name__ == "__main__":
) )
if 0 < args.sortpool_k <= 1: # Transform percentile to number. if 0 < args.sortpool_k <= 1: # Transform percentile to number.
if args.dynamic_train: if args.dataset.startswith("ogbl-citation"):
_sampled_indices = range(1000) _sampled_indices = list(range(1000)) + list(
#_sampled_indices = np.random.choice( range(len(train_dataset) - 1000, len(train_dataset))
# len(train_dataset), 1000, replace=False )
# )
else: else:
_sampled_indices = range(len(train_dataset)) _sampled_indices = list(range(1000))
_num_nodes = sorted( _num_nodes = sorted(
[train_dataset[i][0].num_nodes() for i in _sampled_indices] [train_dataset[i][0].num_nodes() for i in _sampled_indices]
) )
...@@ -535,6 +537,7 @@ if __name__ == "__main__": ...@@ -535,6 +537,7 @@ if __name__ == "__main__":
else 0, else 0,
dropout=args.dropout, dropout=args.dropout,
ngnn_type=args.ngnn_type, ngnn_type=args.ngnn_type,
num_ngnn_layers=args.num_ngnn_layers,
).to(device) ).to(device)
parameters = list(model.parameters()) parameters = list(model.parameters())
optimizer = torch.optim.Adam(params=parameters, lr=args.lr) optimizer = torch.optim.Adam(params=parameters, lr=args.lr)
......
...@@ -7,11 +7,14 @@ from torch.nn import Conv1d, Embedding, Linear, MaxPool1d, ModuleList ...@@ -7,11 +7,14 @@ from torch.nn import Conv1d, Embedding, Linear, MaxPool1d, ModuleList
class NGNN_GCNConv(torch.nn.Module): class NGNN_GCNConv(torch.nn.Module):
def __init__(self, input_channels, hidden_channels, output_channels): def __init__(
self, input_channels, hidden_channels, output_channels, num_layers
):
super(NGNN_GCNConv, self).__init__() super(NGNN_GCNConv, self).__init__()
self.conv = GraphConv(input_channels, hidden_channels) self.conv = GraphConv(input_channels, hidden_channels)
self.fc = Linear(hidden_channels, hidden_channels) self.fc = Linear(hidden_channels, hidden_channels)
self.fc2 = Linear(hidden_channels, output_channels) self.fc2 = Linear(hidden_channels, output_channels)
self.num_layers = num_layers
def reset_parameters(self): def reset_parameters(self):
self.conv.reset_parameters() self.conv.reset_parameters()
...@@ -24,8 +27,9 @@ class NGNN_GCNConv(torch.nn.Module): ...@@ -24,8 +27,9 @@ class NGNN_GCNConv(torch.nn.Module):
def forward(self, g, x, edge_weight=None): def forward(self, g, x, edge_weight=None):
x = self.conv(g, x, edge_weight) x = self.conv(g, x, edge_weight)
# x = F.relu(x) if self.num_layers == 2:
# x = self.fc(x) x = F.relu(x)
x = self.fc(x)
x = F.relu(x) x = F.relu(x)
x = self.fc2(x) x = self.fc2(x)
return x return x
...@@ -44,6 +48,7 @@ class DGCNN(torch.nn.Module): ...@@ -44,6 +48,7 @@ class DGCNN(torch.nn.Module):
NGNN=NGNN_GCNConv, NGNN=NGNN_GCNConv,
dropout=0.0, dropout=0.0,
ngnn_type="all", ngnn_type="all",
num_ngnn_layers=1,
): ):
super(DGCNN, self).__init__() super(DGCNN, self).__init__()
...@@ -59,9 +64,15 @@ class DGCNN(torch.nn.Module): ...@@ -59,9 +64,15 @@ class DGCNN(torch.nn.Module):
self.convs = ModuleList() self.convs = ModuleList()
initial_channels = hidden_channels + self.feature_dim initial_channels = hidden_channels + self.feature_dim
self.num_ngnn_layers = num_ngnn_layers
if ngnn_type in ["input", "all"]: if ngnn_type in ["input", "all"]:
self.convs.append( self.convs.append(
NGNN(initial_channels, hidden_channels, hidden_channels) NGNN(
initial_channels,
hidden_channels,
hidden_channels,
self.num_ngnn_layers,
)
) )
else: else:
self.convs.append(GNN(initial_channels, hidden_channels)) self.convs.append(GNN(initial_channels, hidden_channels))
...@@ -69,14 +80,21 @@ class DGCNN(torch.nn.Module): ...@@ -69,14 +80,21 @@ class DGCNN(torch.nn.Module):
if ngnn_type in ["hidden", "all"]: if ngnn_type in ["hidden", "all"]:
for _ in range(0, num_layers - 1): for _ in range(0, num_layers - 1):
self.convs.append( self.convs.append(
NGNN(hidden_channels, hidden_channels, hidden_channels) NGNN(
hidden_channels,
hidden_channels,
hidden_channels,
self.num_ngnn_layers,
)
) )
else: else:
for _ in range(0, num_layers - 1): for _ in range(0, num_layers - 1):
self.convs.append(GNN(hidden_channels, hidden_channels)) self.convs.append(GNN(hidden_channels, hidden_channels))
if ngnn_type in ["output", "all"]: if ngnn_type in ["output", "all"]:
self.convs.append(NGNN(hidden_channels, hidden_channels, 1)) self.convs.append(
NGNN(hidden_channels, hidden_channels, 1, self.num_ngnn_layers)
)
else: else:
self.convs.append(GNN(hidden_channels, 1)) self.convs.append(GNN(hidden_channels, 1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment