Unverified Commit 7c788f53 authored by Ereboas's avatar Ereboas Committed by GitHub
Browse files

[Example] Update SEAL+NGNN for ogbl-citation2 (#4776)



* Use black for formatting

* limit line width to 80 characters.

* Use a backslash instead of directly concatenating

* file structure adjustment.

* file structure adjustment(2)

* codes for citation2

* format slight adjustment

* adjust format in models.py

* now it runs normally for all datasets.

* add comments; adjust code order.

* adjust indenting.
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent b2d38ca8
...@@ -48,17 +48,16 @@ class SEALOGBLDataset(Dataset): ...@@ -48,17 +48,16 @@ class SEALOGBLDataset(Dataset):
else: else:
self.node_features = None self.node_features = None
if not self.dynamic:
self.g_list, tensor_dict = self.load_cached()
self.labels = tensor_dict["y"]
return
pos_edge, neg_edge = get_pos_neg_edges( pos_edge, neg_edge = get_pos_neg_edges(
split, self.split_edge, self.graph, self.percent self.split, self.split_edge, self.graph, self.percent
) )
self.links = torch.cat([pos_edge, neg_edge], 0) # [Np + Nn, 2] self.links = torch.cat([pos_edge, neg_edge], 0) # [Np + Nn, 2]
self.labels = np.array([1] * len(pos_edge) + [0] * len(neg_edge)) self.labels = np.array([1] * len(pos_edge) + [0] * len(neg_edge))
if not self.dynamic:
self.g_list, tensor_dict = self.load_cached()
self.labels = tensor_dict["y"]
def __len__(self): def __len__(self):
return len(self.labels) return len(self.labels)
...@@ -129,12 +128,6 @@ class SEALOGBLDataset(Dataset): ...@@ -129,12 +128,6 @@ class SEALOGBLDataset(Dataset):
if not os.path.exists(self.root): if not os.path.exists(self.root):
os.makedirs(self.root) os.makedirs(self.root)
pos_edge, neg_edge = get_pos_neg_edges(
self.split, self.split_edge, self.graph, self.percent
)
self.links = torch.cat([pos_edge, neg_edge], 0) # [Np + Nn, 2]
self.labels = np.array([1] * len(pos_edge) + [0] * len(neg_edge))
g_list, labels = self.process() g_list, labels = self.process()
save_graphs(path, g_list, labels) save_graphs(path, g_list, labels)
return g_list, labels return g_list, labels
...@@ -508,6 +501,12 @@ if __name__ == "__main__": ...@@ -508,6 +501,12 @@ if __name__ == "__main__":
if 0 < args.sortpool_k <= 1: # Transform percentile to number. if 0 < args.sortpool_k <= 1: # Transform percentile to number.
if args.dataset.startswith("ogbl-citation"): if args.dataset.startswith("ogbl-citation"):
# For this dataset, subgraphs extracted around positive edges are
# rather larger than negative edges. Thus we sample from 1000
# positive and 1000 negative edges to estimate the k (number of
# nodes to hold for each graph) used in SortPooling.
# You can certainly set k manually, instead of estimating from
# a percentage of sampled subgraphs.
_sampled_indices = list(range(1000)) + list( _sampled_indices = list(range(1000)) + list(
range(len(train_dataset) - 1000, len(train_dataset)) range(len(train_dataset) - 1000, len(train_dataset))
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment