"docs/source/api/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "37bd092587430540a9b3eae2ee86c75b9a05e614"
Unverified Commit c8ec6ee5 authored by Hao Xiong's avatar Hao Xiong Committed by GitHub
Browse files

[Example] Experimental results of deepwalk on ogb datasets (#1729)

* ogb-deepwalk

* ?

* ?

* remove pyg
parent b1e83b6e
...@@ -3,6 +3,11 @@ To load ogb dataset, you need to run the following command, which will output a ...@@ -3,6 +3,11 @@ To load ogb dataset, you need to run the following command, which will output a
``` ```
python3 load_dataset.py --name ogbl-collab python3 load_dataset.py --name ogbl-collab
``` ```
Or you can run the code directly with:
```
python3 deepwalk --ogbl_name xxx --load_from_ogbl
```
However, ogb.linkproppred might not be compatible with mixed training with multi-gpu. If you want to do mixed training, please use no more than 1 gpu by the command above.
## Evaluation ## Evaluation
For evaluatation we follow the code mlp.py provided by ogb [here](https://github.com/snap-stanford/ogb/blob/master/examples/linkproppred/collab/mlp.py). For evaluatation we follow the code mlp.py provided by ogb [here](https://github.com/snap-stanford/ogb/blob/master/examples/linkproppred/collab/mlp.py).
...@@ -10,7 +15,7 @@ For evaluatation we follow the code mlp.py provided by ogb [here](https://github ...@@ -10,7 +15,7 @@ For evaluatation we follow the code mlp.py provided by ogb [here](https://github
## Used config ## Used config
ogbl-collab ogbl-collab
``` ```
python3 deepwalk.py --data_file ogbl-collab-net.txt --save_in_pt --output_emb_file embedding.pt --num_walks 50 --window_size 20 --walk_length 40 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.005 --mix --adam --gpus 0 1 2 3 --num_threads 4 --print_interval 2000 --print_loss --batch_size 32 python3 deepwalk.py --ogbl_name ogbl-collab --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --num_walks 50 --window_size 20 --walk_length 40 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.005 --mix --adam --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 32
cd ./ogb/blob/master/examples/linkproppred/collab/ cd ./ogb/blob/master/examples/linkproppred/collab/
cp embedding_pt_file_path ./ cp embedding_pt_file_path ./
python3 mlp.py --device 0 --runs 10 --use_node_embedding python3 mlp.py --device 0 --runs 10 --use_node_embedding
...@@ -18,28 +23,29 @@ python3 mlp.py --device 0 --runs 10 --use_node_embedding ...@@ -18,28 +23,29 @@ python3 mlp.py --device 0 --runs 10 --use_node_embedding
ogbl-ddi ogbl-ddi
``` ```
python3 deepwalk.py --data_file ogbl-ddi-net.txt --save_in_pt --output_emb_file ddi-embedding.pt --num_walks 50 --window_size 2 --walk_length 80 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.05 --only_gpu --adam --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 16 --use_context_weight python3 deepwalk.py --ogbl_name ogbl-ddi --load_from_ogbl --save_in_pt --output_emb_file ddi-embedding.pt --num_walks 50 --window_size 2 --walk_length 80 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.05 --only_gpu --adam --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 16 --use_context_weight
cd ./ogb/blob/master/examples/linkproppred/ddi/ cd ./ogb/blob/master/examples/linkproppred/ddi/
cp embedding_pt_file_path ./ cp embedding_pt_file_path ./
python3 mlp.py --device 0 --runs 5 python3 mlp.py --device 0 --runs 10 --epochs 100
``` ```
ogbl-ppa ogbl-ppa
``` ```
python3 deepwalk.py --data_file ogbl-ppa-net.txt --save_in_pt --output_emb_file ppa-embedding.pt --negative 1 --neg_weight 1 --batch_size 64 --print_interval 2000 --print_loss --window_size 2 --num_walks 30 --walk_length 80 --lr 0.1 --lap_norm 0.02 --adam --mix --gpus 0 1 --use_context_weight --num_threads 4 python3 deepwalk.py --ogbl_name ogbl-ppa --load_from_ogbl --save_in_pt --output_emb_file ppa-embedding.pt --negative 1 --neg_weight 1 --batch_size 64 --print_interval 2000 --print_loss --window_size 2 --num_walks 30 --walk_length 80 --lr 0.1 --lap_norm 0.02 --adam --mix --gpus 0 --use_context_weight --num_threads 4
cp embedding_pt_file_path ./ cp embedding_pt_file_path ./
python3 mlp.py --device 2 --runs 10 python3 mlp.py --device 2 --runs 10
``` ```
ogbl-citation ogbl-citation
``` ```
python3 deepwalk.py --data_file ogbl-citation-net.txt --save_in_pt --output_emb_file embedding.pt --window_size 2 --num_walks 10 --negative 1 --neg_weight 1 --walk_length 80 --batch_size 128 --print_loss --print_interval 1000 --mix --adam --gpus 0 1 2 3 --use_context_weight --num_threads 4 --lap_norm 0.05 --lr 0.1 python3 deepwalk.py --ogbl_name ogbl-citation --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --window_size 2 --num_walks 10 --negative 1 --neg_weight 1 --walk_length 80 --batch_size 128 --print_loss --print_interval 1000 --mix --adam --gpus 0 --use_context_weight --num_threads 4 --lap_norm 0.05 --lr 0.1
cp embedding_pt_file_path ./ cp embedding_pt_file_path ./
python3 mlp.py --device 2 --runs 5 --use_node_embedding python3 mlp.py --device 2 --runs 10 --use_node_embedding
``` ```
## Score ## Result
ogbl-collab ogbl-collab
<br>#params: 61258346(model) + 131841(mlp) = 61390187
<br>Hits@10 <br>Hits@10
<br>&emsp;Highest Train: 74.83 ± 4.79 <br>&emsp;Highest Train: 74.83 ± 4.79
<br>&emsp;Highest Valid: 40.03 ± 2.98 <br>&emsp;Highest Valid: 40.03 ± 2.98
...@@ -57,24 +63,26 @@ ogbl-collab ...@@ -57,24 +63,26 @@ ogbl-collab
<br>&emsp;&emsp;Final Test: 56.88 ± 0.37 <br>&emsp;&emsp;Final Test: 56.88 ± 0.37
<br>obgl-ddi <br>obgl-ddi
<br>Hits@10 <br>#params: 1444840(model) + 99073(mlp) = 1543913
<br>&emsp;Highest Train: 35.05 ± 3.68 <br>&emsp;Hits@10
<br>&emsp;Highest Valid: 31.72 ± 3.52 <br>&emsp;Highest Train: 36.09 ± 2.47
<br>&emsp;&emsp;Final Train: 35.05 ± 3.68 <br>&emsp;Highest Valid: 32.83 ± 2.30
<br>&emsp;&emsp;Final Test: 12.68 ± 3.19 <br>&emsp;&emsp;Final Train: 36.06 ± 2.45
<br>Hits@20 <br>&emsp;&emsp;Final Test: 11.76 ± 3.91
<br>&emsp;Highest Train: 44.85 ± 1.26 <br>&emsp;Hits@20
<br>&emsp;Highest Valid: 41.20 ± 1.41 <br>&emsp;Highest Train: 45.59 ± 2.45
<br>&emsp;&emsp;Final Train: 44.85 ± 1.26 <br>&emsp;Highest Valid: 42.00 ± 2.36
<br>&emsp;&emsp;Final Test: 21.69 ± 3.14 <br>&emsp;&emsp;Final Train: 45.56 ± 2.50
<br>Hits@30 <br>&emsp;&emsp;Final Test: 22.46 ± 2.90
<br>&emsp;Highest Train: 52.28 ± 1.21 <br>&emsp;Hits@30
<br>&emsp;Highest Valid: 48.49 ± 1.09 <br>&emsp;Highest Train: 51.58 ± 2.41
<br>&emsp;&emsp;Final Train: 52.28 ± 1.21 <br>&emsp;Highest Valid: 47.82 ± 2.19
<br>&emsp;&emsp;Final Test: 29.13 ± 3.46 <br>&emsp;&emsp;Final Train: 51.58 ± 2.42
<br>&emsp;&emsp;Final Test: 30.17 ± 3.39
<br>ogbl-ppa <br>ogbl-ppa
<br>#params: 150024820(model) + 113921(mlp) = 150138741
<br>Hits@10 <br>Hits@10
<br>&emsp;Highest Train: 3.58 ± 0.90 <br>&emsp;Highest Train: 3.58 ± 0.90
<br>&emsp;Highest Valid: 2.88 ± 0.76 <br>&emsp;Highest Valid: 2.88 ± 0.76
...@@ -92,8 +100,9 @@ ogbl-collab ...@@ -92,8 +100,9 @@ ogbl-collab
<br>&emsp;&emsp;Final Test: 23.02 ± 1.63 <br>&emsp;&emsp;Final Test: 23.02 ± 1.63
<br>ogbl-citation <br>ogbl-citation
<br>#params: 757811178(model) + 131841(mlp) = 757943019
<br>MRR <br>MRR
<br>&emsp;Highest Train: 0.8796 ± 0.0007 <br>&emsp;Highest Train: 0.8797 ± 0.0007
<br>&emsp;Highest Valid: 0.8141 ± 0.0007 <br>&emsp;Highest Valid: 0.8139 ± 0.0005
<br>&emsp;&emsp;Final Train: 0.8793 ± 0.0008 <br>&emsp;&emsp;Final Train: 0.8792 ± 0.0008
<br>&emsp;&emsp;Final Test: 0.8159 ± 0.0006 <br>&emsp;&emsp;Final Test: 0.8148 ± 0.0004
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
from reading_data import DeepwalkDataset from reading_data import DeepwalkDataset
from model import SkipGramModel from model import SkipGramModel
from utils import thread_wrapped_func, shuffle_walks from utils import thread_wrapped_func, shuffle_walks, sum_up_params
class DeepwalkTrainer: class DeepwalkTrainer:
def __init__(self, args): def __init__(self, args):
...@@ -26,8 +26,10 @@ class DeepwalkTrainer: ...@@ -26,8 +26,10 @@ class DeepwalkTrainer:
negative=args.negative, negative=args.negative,
gpus=args.gpus, gpus=args.gpus,
fast_neg=args.fast_neg, fast_neg=args.fast_neg,
ogbl_name=args.ogbl_name,
load_from_ogbl=args.load_from_ogbl,
) )
self.emb_size = len(self.dataset.net) self.emb_size = self.dataset.G.number_of_nodes()
self.emb_model = None self.emb_model = None
def init_device_emb(self): def init_device_emb(self):
...@@ -71,7 +73,7 @@ class DeepwalkTrainer: ...@@ -71,7 +73,7 @@ class DeepwalkTrainer:
print("Mix CPU with %d GPU" % len(self.args.gpus)) print("Mix CPU with %d GPU" % len(self.args.gpus))
if len(self.args.gpus) == 1: if len(self.args.gpus) == 1:
assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have abaliable GPU' assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have abaliable GPU'
self.emb_model.set_device(self.args.gpus[0]) #self.emb_model.set_device(self.args.gpus[0])
else: else:
print("Run in CPU process") print("Run in CPU process")
self.args.gpus = [torch.device('cpu')] self.args.gpus = [torch.device('cpu')]
...@@ -89,6 +91,9 @@ class DeepwalkTrainer: ...@@ -89,6 +91,9 @@ class DeepwalkTrainer:
self.init_device_emb() self.init_device_emb()
self.emb_model.share_memory() self.emb_model.share_memory()
if self.args.count_params:
sum_up_params(self.emb_model)
start_all = time.time() start_all = time.time()
ps = [] ps = []
...@@ -170,6 +175,9 @@ class DeepwalkTrainer: ...@@ -170,6 +175,9 @@ class DeepwalkTrainer:
self.init_device_emb() self.init_device_emb()
if self.args.count_params:
sum_up_params(self.emb_model)
sampler = self.dataset.create_sampler(0) sampler = self.dataset.create_sampler(0)
dataloader = DataLoader( dataloader = DataLoader(
...@@ -227,8 +235,17 @@ class DeepwalkTrainer: ...@@ -227,8 +235,17 @@ class DeepwalkTrainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description="DeepWalk") parser = argparse.ArgumentParser(description="DeepWalk")
# input files
## personal datasets
parser.add_argument('--data_file', type=str, parser.add_argument('--data_file', type=str,
help="path of the txt network file, builtin dataset include youtube-net and blog-net") help="path of the txt network file, builtin dataset include youtube-net and blog-net")
## ogbl datasets
parser.add_argument('--ogbl_name', type=str,
help="name of ogbl dataset, e.g. ogbl-ddi")
parser.add_argument('--load_from_ogbl', default=False, action="store_true",
help="whether load dataset from ogbl")
# output files
parser.add_argument('--save_in_txt', default=False, action="store_true", parser.add_argument('--save_in_txt', default=False, action="store_true",
help='Whether save dat in txt format or npy') help='Whether save dat in txt format or npy')
parser.add_argument('--save_in_pt', default=False, action="store_true", parser.add_argument('--save_in_pt', default=False, action="store_true",
...@@ -237,52 +254,63 @@ if __name__ == '__main__': ...@@ -237,52 +254,63 @@ if __name__ == '__main__':
help='path of the output npy embedding file') help='path of the output npy embedding file')
parser.add_argument('--map_file', type=str, default="nodeid_to_index.pickle", parser.add_argument('--map_file', type=str, default="nodeid_to_index.pickle",
help='path of the mapping dict that maps node ids to embedding index') help='path of the mapping dict that maps node ids to embedding index')
parser.add_argument('--norm', default=False, action="store_true",
help="whether to do normalization over node embedding after training")
# model parameters
parser.add_argument('--dim', default=128, type=int, parser.add_argument('--dim', default=128, type=int,
help="embedding dimensions") help="embedding dimensions")
parser.add_argument('--window_size', default=5, type=int, parser.add_argument('--window_size', default=5, type=int,
help="context window size") help="context window size")
parser.add_argument('--use_context_weight', default=False, action="store_true",
help="whether to add weights over nodes in the context window")
parser.add_argument('--num_walks', default=10, type=int, parser.add_argument('--num_walks', default=10, type=int,
help="number of walks for each node") help="number of walks for each node")
parser.add_argument('--negative', default=5, type=int, parser.add_argument('--negative', default=5, type=int,
help="negative samples for each positve node pair") help="negative samples for each positve node pair")
parser.add_argument('--iterations', default=1, type=int,
help="iterations")
parser.add_argument('--batch_size', default=10, type=int, parser.add_argument('--batch_size', default=10, type=int,
help="number of node sequences in each batch") help="number of node sequences in each batch")
parser.add_argument('--print_interval', default=100, type=int,
help="number of batches between printing")
parser.add_argument('--print_loss', default=False, action="store_true",
help="whether print loss during training")
parser.add_argument('--walk_length', default=80, type=int, parser.add_argument('--walk_length', default=80, type=int,
help="number of nodes in a sequence") help="number of nodes in a sequence")
parser.add_argument('--lr', default=0.2, type=float,
help="learning rate")
parser.add_argument('--neg_weight', default=1., type=float, parser.add_argument('--neg_weight', default=1., type=float,
help="negative weight") help="negative weight")
parser.add_argument('--lap_norm', default=0.01, type=float, parser.add_argument('--lap_norm', default=0.01, type=float,
help="weight of laplacian normalization") help="weight of laplacian normalization")
# training parameters
parser.add_argument('--iterations', default=1, type=int,
help="iterations")
parser.add_argument('--print_interval', default=100, type=int,
help="number of batches between printing")
parser.add_argument('--print_loss', default=False, action="store_true",
help="whether print loss during training")
parser.add_argument('--lr', default=0.2, type=float,
help="learning rate")
# optimization settings
parser.add_argument('--mix', default=False, action="store_true", parser.add_argument('--mix', default=False, action="store_true",
help="mixed training with CPU and GPU") help="mixed training with CPU and GPU")
parser.add_argument('--gpus', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0, used with --mix')
parser.add_argument('--only_cpu', default=False, action="store_true", parser.add_argument('--only_cpu', default=False, action="store_true",
help="training with CPU") help="training with CPU")
parser.add_argument('--only_gpu', default=False, action="store_true", parser.add_argument('--only_gpu', default=False, action="store_true",
help="training with GPU") help="training with GPU")
parser.add_argument('--fast_neg', default=True, action="store_true",
help="do negative sampling inside a batch")
parser.add_argument('--adam', default=False, action="store_true", parser.add_argument('--adam', default=False, action="store_true",
help="use adam for embedding updation, recommended") help="use adam for embedding updation, recommended")
parser.add_argument('--sgd', default=False, action="store_true", parser.add_argument('--sgd', default=False, action="store_true",
help="use sgd for embedding updation") help="use sgd for embedding updation")
parser.add_argument('--avg_sgd', default=False, action="store_true", parser.add_argument('--avg_sgd', default=False, action="store_true",
help="average gradients of sgd for embedding updation") help="average gradients of sgd for embedding updation")
parser.add_argument('--norm', default=False, action="store_true", parser.add_argument('--fast_neg', default=False, action="store_true",
help="whether to do normalization over node embedding after training") help="do negative sampling inside a batch")
parser.add_argument('--use_context_weight', default=False, action="store_true",
help="whether to add weights over nodes in the context window")
parser.add_argument('--num_threads', default=2, type=int, parser.add_argument('--num_threads', default=2, type=int,
help="number of threads used for each CPU-core/GPU") help="number of threads used for each CPU-core/GPU")
parser.add_argument('--gpus', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0') parser.add_argument('--count_params', default=False, action="store_true",
help="count the params, then exit")
args = parser.parse_args() args = parser.parse_args()
start_time = time.time() start_time = time.time()
......
""" load dataset from ogb """ """ load dataset from ogb """
from ogb.linkproppred import PygLinkPropPredDataset
import argparse import argparse
from ogb.linkproppred import DglLinkPropPredDataset
parser = argparse.ArgumentParser() def load_from_ogbl_with_name(name):
parser.add_argument('--name', type=str, choices = ['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation']
choices=['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation'], assert name in choices, "name must be selected from " + str(choices)
default='ogbl-collab', dataset = DglLinkPropPredDataset(name)
help="name of datasets by ogb") return dataset[0]
args = parser.parse_args()
name = args.name if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str,
choices=['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation'],
default='ogbl-collab',
help="name of datasets by ogb")
args = parser.parse_args()
dataset = PygLinkPropPredDataset(name=name) name = args.name
data = dataset[0] g = load_from_ogbl_with_name(name=name)
try: try:
weighted = data.edge_weight w = g.edata['edge_weight']
weighted = True weighted = True
except: except:
weighted = False weighted = False
with open(name + "-net.txt", "w") as f: with open(name + "-net.txt", "w") as f:
for i in range(data.edge_index.shape[1]): for i in range(g.edges()[0].shape[0]):
if weighted: if weighted:
f.write(str(data.edge_index[0][i].item()) + " "\ f.write(str(g.edges()[0][i].item()) + " "\
+str(data.edge_index[1][i].item()) + " "\ +str(g.edges()[1][i].item()) + " "\
+str(data.edge_weight[i].item()) + "\n") +str(g.edata['edge_weight'][i]) + "\n")
else: else:
f.write(str(data.edge_index[0][i].item()) + " "\ f.write(str(g.edges()[0][i].item()) + " "\
+str(data.edge_index[1][i].item()) + " "\ +str(g.edges()[1][i].item()) + " "\
+"1\n") +"1\n")
\ No newline at end of file
...@@ -496,13 +496,29 @@ class SkipGramModel(nn.Module): ...@@ -496,13 +496,29 @@ class SkipGramModel(nn.Module):
def save_embedding_pt(self, dataset, file_name): def save_embedding_pt(self, dataset, file_name):
""" For ogb leaderboard. """ For ogb leaderboard.
""" """
max_node_id = max(dataset.node2id.keys()) try:
if max_node_id + 1 != self.emb_size: max_node_id = max(dataset.node2id.keys())
print("WARNING: The node ids are not serial.") if max_node_id + 1 != self.emb_size:
print("WARNING: The node ids are not serial.")
embedding = torch.zeros(max_node_id + 1, self.emb_dimension)
index = torch.LongTensor(list(map(lambda id: dataset.id2node[id], list(range(self.emb_size))))) embedding = torch.zeros(max_node_id + 1, self.emb_dimension)
embedding.index_add_(0, index, self.u_embeddings.weight.cpu().data) index = torch.LongTensor(list(map(lambda id: dataset.id2node[id], list(range(self.emb_size)))))
embedding.index_add_(0, index, self.u_embeddings.weight.cpu().data)
if self.norm:
embedding /= torch.sqrt(torch.sum(embedding.mul(embedding), 1) + 1e-6).unsqueeze(1)
torch.save(embedding, file_name)
except:
self.save_embedding_pt_dgl_graph(dataset, file_name)
def save_embedding_pt_dgl_graph(self, dataset, file_name):
""" For ogb leaderboard.
"""
embedding = torch.zeros_like(self.u_embeddings.weight.cpu().data)
valid_seeds = torch.LongTensor(dataset.valid_seeds)
valid_embedding = self.u_embeddings.weight.cpu().data.index_select(0,
valid_seeds)
embedding.index_add_(0, valid_seeds, self.u_embeddings.weight.cpu().data)
if self.norm: if self.norm:
embedding /= torch.sqrt(torch.sum(embedding.mul(embedding), 1) + 1e-6).unsqueeze(1) embedding /= torch.sqrt(torch.sum(embedding.mul(embedding), 1) + 1e-6).unsqueeze(1)
......
...@@ -8,8 +8,8 @@ from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_arc ...@@ -8,8 +8,8 @@ from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_arc
import random import random
import time import time
import dgl import dgl
from utils import shuffle_walks from utils import shuffle_walks
#np.random.seed(3141592653)
def ReadTxtNet(file_path="", undirected=True): def ReadTxtNet(file_path="", undirected=True):
""" Read the txt network file. """ Read the txt network file.
...@@ -110,17 +110,31 @@ def net2graph(net_sm): ...@@ -110,17 +110,31 @@ def net2graph(net_sm):
print("Building DGLGraph in %.2fs" % t) print("Building DGLGraph in %.2fs" % t)
return G return G
def make_undirected(G):
G.readonly(False)
G.add_edges(G.edges()[1], G.edges()[0])
return G
def find_connected_nodes(G):
nodes = []
for n in G.nodes():
if G.out_degree(n) > 0:
nodes.append(n.item())
return nodes
class DeepwalkDataset: class DeepwalkDataset:
def __init__(self, def __init__(self,
net_file, net_file,
map_file, map_file,
walk_length=80, walk_length,
window_size=5, window_size,
num_walks=10, num_walks,
batch_size=32, batch_size,
negative=5, negative=5,
gpus=[0], gpus=[0],
fast_neg=True, fast_neg=True,
ogbl_name="",
load_from_ogbl=False,
): ):
""" This class has the following functions: """ This class has the following functions:
1. Transform the txt network file into DGL graph; 1. Transform the txt network file into DGL graph;
...@@ -145,26 +159,42 @@ class DeepwalkDataset: ...@@ -145,26 +159,42 @@ class DeepwalkDataset:
self.negative = negative self.negative = negative
self.num_procs = len(gpus) self.num_procs = len(gpus)
self.fast_neg = fast_neg self.fast_neg = fast_neg
self.net, self.node2id, self.id2node, self.sm = ReadTxtNet(net_file)
self.save_mapping(map_file) if load_from_ogbl:
self.G = net2graph(self.sm) assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training."
from load_dataset import load_from_ogbl_with_name
self.G = load_from_ogbl_with_name(ogbl_name)
self.G = make_undirected(self.G)
else:
self.net, self.node2id, self.id2node, self.sm = ReadTxtNet(net_file)
self.save_mapping(map_file)
self.G = net2graph(self.sm)
self.num_nodes = self.G.number_of_nodes()
# random walk seeds # random walk seeds
start = time.time() start = time.time()
seeds = torch.cat([torch.LongTensor(self.G.nodes())] * num_walks) self.valid_seeds = find_connected_nodes(self.G)
self.seeds = torch.split(shuffle_walks(seeds), int(np.ceil(len(self.net) * self.num_walks / self.num_procs)), 0) if len(self.valid_seeds) != self.num_nodes:
print("WARNING: The node ids are not serial. Some nodes are invalid.")
seeds = torch.cat([torch.LongTensor(self.valid_seeds)] * num_walks)
self.seeds = torch.split(shuffle_walks(seeds),
int(np.ceil(len(self.valid_seeds) * self.num_walks / self.num_procs)),
0)
end = time.time() end = time.time()
t = end - start t = end - start
print("%d seeds in %.2fs" % (len(seeds), t)) print("%d seeds in %.2fs" % (len(seeds), t))
# negative table for true negative sampling # negative table for true negative sampling
if not fast_neg: if not fast_neg:
node_degree = np.array(list(map(lambda x: len(self.net[x]), self.net.keys()))) node_degree = np.array(list(map(lambda x: self.G.out_degree(x), self.valid_seeds)))
node_degree = np.power(node_degree, 0.75) node_degree = np.power(node_degree, 0.75)
node_degree /= np.sum(node_degree) node_degree /= np.sum(node_degree)
node_degree = np.array(node_degree * 1e8, dtype=np.int) node_degree = np.array(node_degree * 1e8, dtype=np.int)
self.neg_table = [] self.neg_table = []
for idx, node in enumerate(self.net.keys()):
for idx, node in enumerate(self.valid_seeds):
self.neg_table += [node] * node_degree[idx] self.neg_table += [node] * node_degree[idx]
self.neg_table_size = len(self.neg_table) self.neg_table_size = len(self.neg_table)
self.neg_table = np.array(self.neg_table, dtype=np.long) self.neg_table = np.array(self.neg_table, dtype=np.long)
......
...@@ -34,4 +34,32 @@ def thread_wrapped_func(func): ...@@ -34,4 +34,32 @@ def thread_wrapped_func(func):
def shuffle_walks(walks): def shuffle_walks(walks):
seeds = torch.randperm(walks.size()[0]) seeds = torch.randperm(walks.size()[0])
return walks[seeds] return walks[seeds]
\ No newline at end of file
def sum_up_params(model):
""" Count the model parameters """
n = []
n.append(model.u_embeddings.weight.cpu().data.numel() * 2)
n.append(model.lookup_table.cpu().numel())
n.append(model.index_emb_posu.cpu().numel() * 2)
n.append(model.grad_u.cpu().numel() * 2)
try:
n.append(model.index_emb_negu.cpu().numel() * 2)
except:
pass
try:
n.append(model.state_sum_u.cpu().numel() * 2)
except:
pass
try:
n.append(model.grad_avg.cpu().numel())
except:
pass
try:
n.append(model.context_weight.cpu().numel())
except:
pass
print("#params " + str(sum(n)))
exit()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment