Unverified commit 0fb615b9, authored by Hao Xiong and committed by GitHub

[Example] Several updates on deepwalk (#1845)



* fix bug

* fix bugs

* add args

* add args

* fix bug

* fix a typo

* results

* results

* results
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 33abd275
@@ -30,7 +30,7 @@ For other datasets please pass the full path to the trainer through --data_file
## How to run the code
To run the code:
```
-python3 deepwalk.py --data_file youtube --output_emb_file emb.txt --adam --mix --lr 0.2 --gpus 0 1 2 3 --batch_size 100 --negative 5
+python3 deepwalk.py --data_file youtube --output_emb_file emb.txt --mix --lr 0.2 --gpus 0 1 2 3 --batch_size 100 --negative 5
```
## How to save the embedding
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
from reading_data import DeepwalkDataset from reading_data import DeepwalkDataset
from model import SkipGramModel from model import SkipGramModel
from utils import thread_wrapped_func, shuffle_walks from utils import thread_wrapped_func, shuffle_walks, sum_up_params
class DeepwalkTrainer: class DeepwalkTrainer:
def __init__(self, args): def __init__(self, args):
...@@ -26,8 +26,10 @@ class DeepwalkTrainer: ...@@ -26,8 +26,10 @@ class DeepwalkTrainer:
negative=args.negative, negative=args.negative,
gpus=args.gpus, gpus=args.gpus,
fast_neg=args.fast_neg, fast_neg=args.fast_neg,
ogbl_name=args.ogbl_name,
load_from_ogbl=args.load_from_ogbl,
) )
self.emb_size = len(self.dataset.net) self.emb_size = self.dataset.G.number_of_nodes()
self.emb_model = None self.emb_model = None
def init_device_emb(self): def init_device_emb(self):
...@@ -36,8 +38,6 @@ class DeepwalkTrainer: ...@@ -36,8 +38,6 @@ class DeepwalkTrainer:
""" """
choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix]) choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
assert choices == 1, "Must choose only *one* training mode in [only_cpu, only_gpu, mix]" assert choices == 1, "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"
choices = sum([self.args.sgd, self.args.adam, self.args.avg_sgd])
assert choices == 1, "Must choose only *one* gradient descent strategy in [sgd, avg_sgd, adam]"
# initializing embedding on CPU # initializing embedding on CPU
self.emb_model = SkipGramModel( self.emb_model = SkipGramModel(
...@@ -53,10 +53,12 @@ class DeepwalkTrainer: ...@@ -53,10 +53,12 @@ class DeepwalkTrainer:
negative=self.args.negative, negative=self.args.negative,
lr=self.args.lr, lr=self.args.lr,
lap_norm=self.args.lap_norm, lap_norm=self.args.lap_norm,
adam=self.args.adam,
sgd=self.args.sgd,
avg_sgd=self.args.avg_sgd,
fast_neg=self.args.fast_neg, fast_neg=self.args.fast_neg,
record_loss=self.args.print_loss,
norm=self.args.norm,
use_context_weight=self.args.use_context_weight,
async_update=self.args.async_update,
num_threads=self.args.num_threads,
) )
torch.set_num_threads(self.args.num_threads) torch.set_num_threads(self.args.num_threads)
...@@ -67,7 +69,7 @@ class DeepwalkTrainer: ...@@ -67,7 +69,7 @@ class DeepwalkTrainer:
elif self.args.mix: elif self.args.mix:
print("Mix CPU with %d GPU" % len(self.args.gpus)) print("Mix CPU with %d GPU" % len(self.args.gpus))
if len(self.args.gpus) == 1: if len(self.args.gpus) == 1:
assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have abaliable GPU' assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have available GPU'
self.emb_model.set_device(self.args.gpus[0]) self.emb_model.set_device(self.args.gpus[0])
else: else:
print("Run in CPU process") print("Run in CPU process")
...@@ -86,11 +88,14 @@ class DeepwalkTrainer: ...@@ -86,11 +88,14 @@ class DeepwalkTrainer:
self.init_device_emb() self.init_device_emb()
self.emb_model.share_memory() self.emb_model.share_memory()
if self.args.count_params:
sum_up_params(self.emb_model)
start_all = time.time() start_all = time.time()
ps = [] ps = []
for i in range(len(self.args.gpus)): for i in range(len(self.args.gpus)):
p = mp.Process(target=self.fast_train_sp, args=(self.args.gpus[i],)) p = mp.Process(target=self.fast_train_sp, args=(i, self.args.gpus[i]))
ps.append(p) ps.append(p)
p.start() p.start()
...@@ -100,17 +105,22 @@ class DeepwalkTrainer: ...@@ -100,17 +105,22 @@ class DeepwalkTrainer:
print("Used time: %.2fs" % (time.time()-start_all)) print("Used time: %.2fs" % (time.time()-start_all))
if self.args.save_in_txt: if self.args.save_in_txt:
self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file)
elif self.args.save_in_pt:
self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
else: else:
self.emb_model.save_embedding(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
@thread_wrapped_func @thread_wrapped_func
def fast_train_sp(self, gpu_id): def fast_train_sp(self, rank, gpu_id):
""" a subprocess for fast_train_mp """ """ a subprocess for fast_train_mp """
if self.args.mix: if self.args.mix:
self.emb_model.set_device(gpu_id) self.emb_model.set_device(gpu_id)
torch.set_num_threads(self.args.num_threads) torch.set_num_threads(self.args.num_threads)
if self.args.async_update:
self.emb_model.create_async_update()
sampler = self.dataset.create_sampler(gpu_id) sampler = self.dataset.create_sampler(rank)
dataloader = DataLoader( dataloader = DataLoader(
dataset=sampler.seeds, dataset=sampler.seeds,
...@@ -118,26 +128,19 @@ class DeepwalkTrainer: ...@@ -118,26 +128,19 @@ class DeepwalkTrainer:
collate_fn=sampler.sample, collate_fn=sampler.sample,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
num_workers=4, num_workers=self.args.num_sampler_threads,
) )
num_batches = len(dataloader) num_batches = len(dataloader)
print("num batchs: %d in subprocess [%d]" % (num_batches, gpu_id)) print("num batchs: %d in process [%d] GPU [%d]" % (num_batches, rank, gpu_id))
# number of positive node pairs in a sequence
num_pos = int(2 * self.args.walk_length * self.args.window_size\
- self.args.window_size * (self.args.window_size + 1))
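# e.g. walk_length=80 and window_size=5 give 2*80*5 - 5*6 = 770 positive pairs per walk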
start = time.time() start = time.time()
with torch.no_grad(): with torch.no_grad():
max_i = self.args.iterations * num_batches
for i, walks in enumerate(dataloader): for i, walks in enumerate(dataloader):
# decay learning rate for SGD
lr = self.args.lr * (max_i - i) / max_i
if lr < 0.00001:
lr = 0.00001
if self.args.fast_neg: if self.args.fast_neg:
self.emb_model.fast_learn(walks, lr) self.emb_model.fast_learn(walks)
else: else:
# do negative sampling # do negative sampling
bs = len(walks) bs = len(walks)
...@@ -145,14 +148,22 @@ class DeepwalkTrainer: ...@@ -145,14 +148,22 @@ class DeepwalkTrainer:
np.random.choice(self.dataset.neg_table, np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative, bs * num_pos * self.args.negative,
replace=True)) replace=True))
self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes) self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0: if i > 0 and i % self.args.print_interval == 0:
print("Solver [%d] batch %d tt: %.2fs" % (gpu_id, i, time.time()-start)) if self.args.print_loss:
print("GPU-[%d] batch %d time: %.2fs loss: %.4f" \
% (gpu_id, i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
self.emb_model.loss = []
else:
print("GPU-[%d] batch %d time: %.2fs" % (gpu_id, i, time.time()-start))
start = time.time() start = time.time()
if self.args.async_update:
self.emb_model.finish_async_update()
def fast_train(self): def fast_train(self):
""" fast train with dataloader """ """ fast train with dataloader with only gpu / only cpu"""
# the number of positive node pairs of a node sequence
num_pos = 2 * self.args.walk_length * self.args.window_size\ num_pos = 2 * self.args.walk_length * self.args.window_size\
- self.args.window_size * (self.args.window_size + 1) - self.args.window_size * (self.args.window_size + 1)
...@@ -160,6 +171,13 @@ class DeepwalkTrainer: ...@@ -160,6 +171,13 @@ class DeepwalkTrainer:
self.init_device_emb() self.init_device_emb()
if self.args.async_update:
self.emb_model.share_memory()
self.emb_model.create_async_update()
if self.args.count_params:
sum_up_params(self.emb_model)
sampler = self.dataset.create_sampler(0) sampler = self.dataset.create_sampler(0)
dataloader = DataLoader( dataloader = DataLoader(
...@@ -168,98 +186,127 @@ class DeepwalkTrainer: ...@@ -168,98 +186,127 @@ class DeepwalkTrainer:
collate_fn=sampler.sample, collate_fn=sampler.sample,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
num_workers=4, num_workers=self.args.num_sampler_threads,
) )
num_batches = len(dataloader) num_batches = len(dataloader)
print("num batchs: %d" % num_batches) print("num batchs: %d\n" % num_batches)
start_all = time.time() start_all = time.time()
start = time.time() start = time.time()
with torch.no_grad(): with torch.no_grad():
max_i = self.args.iterations * num_batches max_i = num_batches
for iteration in range(self.args.iterations): for i, walks in enumerate(dataloader):
print("\nIteration: " + str(iteration + 1)) if self.args.fast_neg:
self.emb_model.fast_learn(walks)
for i, walks in enumerate(dataloader): else:
# decay learning rate for SGD # do negative sampling
lr = self.args.lr * (max_i - i) / max_i bs = len(walks)
if lr < 0.00001: neg_nodes = torch.LongTensor(
lr = 0.00001 np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative,
replace=True))
self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if self.args.fast_neg: if i > 0 and i % self.args.print_interval == 0:
self.emb_model.fast_learn(walks, lr) if self.args.print_loss:
print("Batch %d training time: %.2fs loss: %.4f" \
% (i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
self.emb_model.loss = []
else: else:
# do negative sampling
bs = len(walks)
neg_nodes = torch.LongTensor(
np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative,
replace=True))
self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0:
print("Batch %d, training time: %.2fs" % (i, time.time()-start)) print("Batch %d, training time: %.2fs" % (i, time.time()-start))
start = time.time() start = time.time()
if self.args.async_update:
self.emb_model.finish_async_update()
print("Training used time: %.2fs" % (time.time()-start_all)) print("Training used time: %.2fs" % (time.time()-start_all))
if self.args.save_in_txt: if self.args.save_in_txt:
self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file)
elif self.args.save_in_pt:
self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
else: else:
self.emb_model.save_embedding(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description="DeepWalk") parser = argparse.ArgumentParser(description="DeepWalk")
# input files
## personal datasets
parser.add_argument('--data_file', type=str, parser.add_argument('--data_file', type=str,
help="path of the txt network file, builtin dataset include youtube-net and blog-net") help="path of the txt network file, builtin dataset include youtube-net and blog-net")
## ogbl datasets
parser.add_argument('--ogbl_name', type=str,
help="name of ogbl dataset, e.g. ogbl-ddi")
parser.add_argument('--load_from_ogbl', default=False, action="store_true",
help="whether load dataset from ogbl")
# output files
parser.add_argument('--save_in_txt', default=False, action="store_true",
help='Whether to save the embedding in txt format rather than npy')
parser.add_argument('--save_in_pt', default=False, action="store_true",
help='Whether to save the embedding in pt format rather than npy')
parser.add_argument('--output_emb_file', type=str, default="emb.npy", parser.add_argument('--output_emb_file', type=str, default="emb.npy",
help='path of the output npy embedding file') help='path of the output npy embedding file')
parser.add_argument('--map_file', type=str, default="nodeid_to_index.pickle", parser.add_argument('--map_file', type=str, default="nodeid_to_index.pickle",
help='path of the mapping dict that maps node ids to embedding index') help='path of the mapping dict that maps node ids to embedding index')
parser.add_argument('--norm', default=False, action="store_true",
help="whether to do normalization over node embedding after training")
# model parameters
parser.add_argument('--dim', default=128, type=int, parser.add_argument('--dim', default=128, type=int,
help="embedding dimensions") help="embedding dimensions")
parser.add_argument('--window_size', default=5, type=int, parser.add_argument('--window_size', default=5, type=int,
help="context window size") help="context window size")
parser.add_argument('--use_context_weight', default=False, action="store_true",
help="whether to add weights over nodes in the context window")
parser.add_argument('--num_walks', default=10, type=int, parser.add_argument('--num_walks', default=10, type=int,
help="number of walks for each node") help="number of walks for each node")
-parser.add_argument('--negative', default=5, type=int,
+parser.add_argument('--negative', default=1, type=int,
help="negative samples for each positive node pair")
parser.add_argument('--iterations', default=1, type=int, parser.add_argument('--batch_size', default=128, type=int,
help="iterations")
parser.add_argument('--batch_size', default=10, type=int,
help="number of node sequences in each batch") help="number of node sequences in each batch")
parser.add_argument('--print_interval', default=1000, type=int,
help="number of batches between printing")
parser.add_argument('--walk_length', default=80, type=int, parser.add_argument('--walk_length', default=80, type=int,
help="number of nodes in a sequence") help="number of nodes in a sequence")
parser.add_argument('--lr', default=0.2, type=float,
help="learning rate")
parser.add_argument('--neg_weight', default=1., type=float, parser.add_argument('--neg_weight', default=1., type=float,
help="negative weight") help="negative weight")
parser.add_argument('--lap_norm', default=0.01, type=float, parser.add_argument('--lap_norm', default=0.01, type=float,
help="weight of laplacian normalization") help="weight of laplacian normalization, recommend to set as 0.1 / windoe_size")
# training parameters
parser.add_argument('--print_interval', default=100, type=int,
help="number of batches between printing")
parser.add_argument('--print_loss', default=False, action="store_true",
help="whether print loss during training")
parser.add_argument('--lr', default=0.2, type=float,
help="learning rate")
# optimization settings
parser.add_argument('--mix', default=False, action="store_true", parser.add_argument('--mix', default=False, action="store_true",
help="mixed training with CPU and GPU") help="mixed training with CPU and GPU")
parser.add_argument('--gpus', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0, used with --mix')
parser.add_argument('--only_cpu', default=False, action="store_true", parser.add_argument('--only_cpu', default=False, action="store_true",
help="training with CPU") help="training with CPU")
parser.add_argument('--only_gpu', default=False, action="store_true", parser.add_argument('--only_gpu', default=False, action="store_true",
help="training with GPU") help="training with GPU")
parser.add_argument('--fast_neg', default=True, action="store_true", parser.add_argument('--async_update', default=False, action="store_true",
help="mixed training asynchronously, not recommended")
parser.add_argument('--fast_neg', default=False, action="store_true",
help="do negative sampling inside a batch") help="do negative sampling inside a batch")
parser.add_argument('--adam', default=False, action="store_true", parser.add_argument('--num_threads', default=8, type=int,
help="use adam for embedding updation, recommended")
parser.add_argument('--sgd', default=False, action="store_true",
help="use sgd for embedding updation")
parser.add_argument('--avg_sgd', default=False, action="store_true",
help="average gradients of sgd for embedding updation")
parser.add_argument('--num_threads', default=2, type=int,
help="number of threads used for each CPU-core/GPU") help="number of threads used for each CPU-core/GPU")
parser.add_argument('--gpus', type=int, default=[-1], nargs='+', parser.add_argument('--num_sampler_threads', default=2, type=int,
help='a list of active gpu ids, e.g. 0') help="number of threads used for sampling")
parser.add_argument('--count_params', default=False, action="store_true",
help="count the params, exit once counting over")
args = parser.parse_args()
if args.async_update:
assert args.mix, "--async_update only with --mix"
start_time = time.time()
trainer = DeepwalkTrainer(args)
trainer.train()
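After training finishes, the saved artifacts are just a NumPy array plus the pickled id mapping. A minimal loading sketch, assuming the default `emb.npy` and `nodeid_to_index.pickle` file names (the node id chosen below is illustrative):
```
import pickle

import numpy as np

emb = np.load("emb.npy")                      # [num_nodes, dim] u_embeddings
with open("nodeid_to_index.pickle", "rb") as f:
    node2id = pickle.load(f)                  # original node id -> embedding row index

some_node = next(iter(node2id))               # any node id present in the input graph
vec = emb[node2id[some_node]]                 # its learned embedding (128-d with the default --dim)
```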
......
""" load dataset from ogb """
import argparse
from ogb.linkproppred import DglLinkPropPredDataset
def load_from_ogbl_with_name(name):
choices = ['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation']
assert name in choices, "name must be selected from " + str(choices)
dataset = DglLinkPropPredDataset(name)
return dataset[0]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str,
choices=['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation'],
default='ogbl-collab',
help="name of datasets by ogb")
args = parser.parse_args()
name = args.name
g = load_from_ogbl_with_name(name=name)
try:
w = g.edata['edge_weight']
weighted = True
except:
weighted = False
with open(name + "-net.txt", "w") as f:
for i in range(g.edges()[0].shape[0]):
if weighted:
f.write(str(g.edges()[0][i].item()) + " "\
+str(g.edges()[1][i].item()) + " "\
+str(g.edata['edge_weight'][i]) + "\n")
else:
f.write(str(g.edges()[0][i].item()) + " "\
+str(g.edges()[1][i].item()) + " "\
+"1\n")
\ No newline at end of file
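As a usage sketch for the converter above (assuming `load_dataset.py` sits next to `deepwalk.py`; the dataset choice and flags are illustrative, not part of this commit):
```
# hypothetical illustration of the two ways to consume this helper
from load_dataset import load_from_ogbl_with_name

g = load_from_ogbl_with_name("ogbl-ddi")        # DGLGraph returned by DglLinkPropPredDataset
print(g.number_of_nodes(), g.number_of_edges())

# or from the shell, writing an edge list that deepwalk.py can read via --data_file:
#   python3 load_dataset.py --name ogbl-ddi      # produces ogbl-ddi-net.txt
#   python3 deepwalk.py --data_file ogbl-ddi-net.txt --output_emb_file emb.npy --only_cpu
```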
...@@ -4,6 +4,10 @@ import torch.nn.functional as F ...@@ -4,6 +4,10 @@ import torch.nn.functional as F
from torch.nn import init from torch.nn import init
import random import random
import numpy as np import numpy as np
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from utils import thread_wrapped_func
def init_emb2pos_index(walk_length, window_size, batch_size): def init_emb2pos_index(walk_length, window_size, batch_size):
''' select embedding of positive nodes from a batch of node embeddings ''' select embedding of positive nodes from a batch of node embeddings
...@@ -72,25 +76,20 @@ def init_emb2neg_index(walk_length, window_size, negative, batch_size): ...@@ -72,25 +76,20 @@ def init_emb2neg_index(walk_length, window_size, negative, batch_size):
return index_emb_negu, index_emb_negv return index_emb_negu, index_emb_negv
def init_grad_avg(walk_length, window_size, batch_size): def init_weight(walk_length, window_size, batch_size):
'''select nodes' gradients from gradient matrix ''' init context weight '''
weight = []
Usage
-----
'''
grad_avg = []
for b in range(batch_size): for b in range(batch_size):
for i in range(walk_length): for i in range(walk_length):
if i < window_size: for j in range(i-window_size, i):
grad_avg.append(1. / float(i+window_size)) if j >= 0:
elif i >= walk_length - window_size: weight.append(1. - float(i - j - 1)/float(window_size))
grad_avg.append(1. / float(walk_length - i - 1 + window_size)) for j in range(i + 1, i + 1 + window_size):
else: if j < walk_length:
grad_avg.append(0.5 / window_size) weight.append(1. - float(j - i - 1)/float(window_size))
# [num_pos * batch_size] # [num_pos * batch_size]
return torch.Tensor(grad_avg).unsqueeze(1) return torch.Tensor(weight).unsqueeze(1)
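# e.g. window_size=5 yields weights 1.0, 0.8, 0.6, 0.4, 0.2 for context distances 1..5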
def init_empty_grad(emb_dimension, walk_length, batch_size): def init_empty_grad(emb_dimension, walk_length, batch_size):
""" initialize gradient matrix """ """ initialize gradient matrix """
...@@ -111,6 +110,20 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu): ...@@ -111,6 +110,20 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu):
return grad return grad
@thread_wrapped_func
def async_update(num_threads, model, queue):
""" asynchronous embedding update """
torch.set_num_threads(num_threads)
while True:
(grad_u, grad_v, grad_v_neg, nodes, neg_nodes) = queue.get()
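# a tuple of Nones is the shutdown signal sent by finish_async_update()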
if grad_u is None:
return
with torch.no_grad():
model.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
model.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
if neg_nodes is not None:
model.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), grad_v_neg)
class SkipGramModel(nn.Module): class SkipGramModel(nn.Module):
""" Negative sampling based skip-gram """ """ Negative sampling based skip-gram """
def __init__(self, def __init__(self,
...@@ -126,10 +139,12 @@ class SkipGramModel(nn.Module): ...@@ -126,10 +139,12 @@ class SkipGramModel(nn.Module):
negative, negative,
lr, lr,
lap_norm, lap_norm,
adam,
sgd,
avg_sgd,
fast_neg, fast_neg,
record_loss,
norm,
use_context_weight,
async_update,
num_threads,
): ):
""" initialize embedding on CPU """ initialize embedding on CPU
...@@ -147,10 +162,11 @@ class SkipGramModel(nn.Module): ...@@ -147,10 +162,11 @@ class SkipGramModel(nn.Module):
neg_weight float : negative weight neg_weight float : negative weight
lr float : initial learning rate lr float : initial learning rate
lap_norm float : weight of laplacian normalization lap_norm float : weight of laplacian normalization
adam bool : use adam for embedding updation
sgd bool : use sgd for embedding updation
avg_sgd bool : average gradients of sgd for embedding updation
fast_neg bool : do negative sampling inside a batch fast_neg bool : do negative sampling inside a batch
record_loss bool : print the loss during training
norm bool : do normalization on the embedding after training
use_context_weight : give different weights to the nodes in a context window
async_update : asynchronous training
""" """
super(SkipGramModel, self).__init__() super(SkipGramModel, self).__init__()
self.emb_size = emb_size self.emb_size = emb_size
...@@ -165,10 +181,12 @@ class SkipGramModel(nn.Module): ...@@ -165,10 +181,12 @@ class SkipGramModel(nn.Module):
self.negative = negative self.negative = negative
self.lr = lr self.lr = lr
self.lap_norm = lap_norm self.lap_norm = lap_norm
self.adam = adam
self.sgd = sgd
self.avg_sgd = avg_sgd
self.fast_neg = fast_neg self.fast_neg = fast_neg
self.record_loss = record_loss
self.norm = norm
self.use_context_weight = use_context_weight
self.async_update = async_update
self.num_threads = num_threads
# initialize the device as cpu # initialize the device as cpu
self.device = torch.device("cpu") self.device = torch.device("cpu")
...@@ -188,29 +206,30 @@ class SkipGramModel(nn.Module): ...@@ -188,29 +206,30 @@ class SkipGramModel(nn.Module):
self.lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01)) self.lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))
self.lookup_table[0] = 0. self.lookup_table[0] = 0.
self.lookup_table[-1] = 1. self.lookup_table[-1] = 1.
if self.record_loss:
self.logsigmoid_table = torch.log(torch.sigmoid(torch.arange(-6.01, 6.01, 0.01)))
self.loss = []
# indexes to select positive/negative node pairs from batch_walks # indexes to select positive/negative node pairs from batch_walks
self.index_emb_posu, self.index_emb_posv = init_emb2pos_index( self.index_emb_posu, self.index_emb_posv = init_emb2pos_index(
self.walk_length, self.walk_length,
self.window_size, self.window_size,
self.batch_size) self.batch_size)
if self.fast_neg: self.index_emb_negu, self.index_emb_negv = init_emb2neg_index(
self.index_emb_negu, self.index_emb_negv = init_emb2neg_index( self.walk_length,
self.walk_length, self.window_size,
self.window_size, self.negative,
self.negative, self.batch_size)
self.batch_size)
# coefficients for averaging the gradients if self.use_context_weight:
if self.avg_sgd: self.context_weight = init_weight(
self.grad_avg = init_grad_avg(
self.walk_length, self.walk_length,
self.window_size, self.window_size,
self.batch_size) self.batch_size)
# adam # adam
if self.adam: self.state_sum_u = torch.zeros(self.emb_size)
self.state_sum_u = torch.zeros(self.emb_size) self.state_sum_v = torch.zeros(self.emb_size)
self.state_sum_v = torch.zeros(self.emb_size)
# gradients of nodes in batch_walks # gradients of nodes in batch_walks
self.grad_u, self.grad_v = init_empty_grad( self.grad_u, self.grad_v = init_empty_grad(
...@@ -218,28 +237,41 @@ class SkipGramModel(nn.Module): ...@@ -218,28 +237,41 @@ class SkipGramModel(nn.Module):
self.walk_length, self.walk_length,
self.batch_size) self.batch_size)
def create_async_update(self):
""" Set up the async update subprocess.
"""
self.async_q = Queue(1)
self.async_p = mp.Process(target=async_update, args=(self.num_threads, self, self.async_q))
self.async_p.start()
def finish_async_update(self):
""" Notify the async update subprocess to quit.
"""
self.async_q.put((None, None, None, None, None))
self.async_p.join()
def share_memory(self): def share_memory(self):
""" share the parameters across subprocesses """ """ share the parameters across subprocesses """
self.u_embeddings.weight.share_memory_() self.u_embeddings.weight.share_memory_()
self.v_embeddings.weight.share_memory_() self.v_embeddings.weight.share_memory_()
if self.adam: self.state_sum_u.share_memory_()
self.state_sum_u.share_memory_() self.state_sum_v.share_memory_()
self.state_sum_v.share_memory_()
def set_device(self, gpu_id): def set_device(self, gpu_id):
""" set gpu device """ """ set gpu device """
self.device = torch.device("cuda:%d" % gpu_id) self.device = torch.device("cuda:%d" % gpu_id)
print("The device is", self.device) print("The device is", self.device)
self.lookup_table = self.lookup_table.to(self.device) self.lookup_table = self.lookup_table.to(self.device)
if self.record_loss:
self.logsigmoid_table = self.logsigmoid_table.to(self.device)
self.index_emb_posu = self.index_emb_posu.to(self.device) self.index_emb_posu = self.index_emb_posu.to(self.device)
self.index_emb_posv = self.index_emb_posv.to(self.device) self.index_emb_posv = self.index_emb_posv.to(self.device)
if self.fast_neg: self.index_emb_negu = self.index_emb_negu.to(self.device)
self.index_emb_negu = self.index_emb_negu.to(self.device) self.index_emb_negv = self.index_emb_negv.to(self.device)
self.index_emb_negv = self.index_emb_negv.to(self.device)
self.grad_u = self.grad_u.to(self.device) self.grad_u = self.grad_u.to(self.device)
self.grad_v = self.grad_v.to(self.device) self.grad_v = self.grad_v.to(self.device)
if self.avg_sgd: if self.use_context_weight:
self.grad_avg = self.grad_avg.to(self.device) self.context_weight = self.context_weight.to(self.device)
def all_to_device(self, gpu_id): def all_to_device(self, gpu_id):
""" move all of the parameters to a single GPU """ """ move all of the parameters to a single GPU """
...@@ -247,16 +279,20 @@ class SkipGramModel(nn.Module): ...@@ -247,16 +279,20 @@ class SkipGramModel(nn.Module):
self.set_device(gpu_id) self.set_device(gpu_id)
self.u_embeddings = self.u_embeddings.cuda(gpu_id) self.u_embeddings = self.u_embeddings.cuda(gpu_id)
self.v_embeddings = self.v_embeddings.cuda(gpu_id) self.v_embeddings = self.v_embeddings.cuda(gpu_id)
if self.adam: self.state_sum_u = self.state_sum_u.to(self.device)
self.state_sum_u = self.state_sum_u.to(self.device) self.state_sum_v = self.state_sum_v.to(self.device)
self.state_sum_v = self.state_sum_v.to(self.device)
def fast_sigmoid(self, score): def fast_sigmoid(self, score):
""" do fast sigmoid by looking up in a pre-defined table """ """ do fast sigmoid by looking up in a pre-defined table """
idx = torch.floor((score + 6.01) / 0.01).long() idx = torch.floor((score + 6.01) / 0.01).long()
return self.lookup_table[idx] return self.lookup_table[idx]
def fast_learn(self, batch_walks, lr, neg_nodes=None): def fast_logsigmoid(self, score):
""" do fast logsigmoid by looking up in a pre-defined table """
idx = torch.floor((score + 6.01) / 0.01).long()
return self.logsigmoid_table[idx]
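# both tables hold precomputed sigmoid/logsigmoid values on [-6.01, 6.01) at 0.01 steps;
# scores are clamped to [-6, 6] before lookup so the index stays inside the table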
def fast_learn(self, batch_walks, neg_nodes=None):
""" Learn a batch of random walks in a fast way. It has the following features: """ Learn a batch of random walks in a fast way. It has the following features:
1. It calculates the gradients directly without the forward operation.
2. It does sigmoid by looking up a pre-computed table.
...@@ -281,8 +317,7 @@ class SkipGramModel(nn.Module): ...@@ -281,8 +317,7 @@ class SkipGramModel(nn.Module):
lr = 0.01 lr = 0.01
neg_nodes = None neg_nodes = None
""" """
if self.adam: lr = self.lr
lr = self.lr
# [batch_size, walk_length] # [batch_size, walk_length]
if isinstance(batch_walks, list): if isinstance(batch_walks, list):
...@@ -318,6 +353,8 @@ class SkipGramModel(nn.Module): ...@@ -318,6 +353,8 @@ class SkipGramModel(nn.Module):
pos_score = torch.clamp(pos_score, max=6, min=-6) pos_score = torch.clamp(pos_score, max=6, min=-6)
# [batch_size * num_pos, 1] # [batch_size * num_pos, 1]
score = (1 - self.fast_sigmoid(pos_score)).unsqueeze(1) score = (1 - self.fast_sigmoid(pos_score)).unsqueeze(1)
if self.record_loss:
self.loss.append(torch.mean(self.fast_logsigmoid(pos_score)).item())
# [batch_size * num_pos, dim] # [batch_size * num_pos, dim]
if self.lap_norm > 0: if self.lap_norm > 0:
...@@ -326,6 +363,18 @@ class SkipGramModel(nn.Module): ...@@ -326,6 +363,18 @@ class SkipGramModel(nn.Module):
else: else:
grad_u_pos = score * emb_pos_v grad_u_pos = score * emb_pos_v
grad_v_pos = score * emb_pos_u grad_v_pos = score * emb_pos_u
if self.use_context_weight:
if bs < self.batch_size:
context_weight = init_weight(
self.walk_length,
self.window_size,
bs).to(self.device)
else:
context_weight = self.context_weight
grad_u_pos *= context_weight
grad_v_pos *= context_weight
# [batch_size * walk_length, dim] # [batch_size * walk_length, dim]
if bs < self.batch_size: if bs < self.batch_size:
grad_u, grad_v = init_empty_grad( grad_u, grad_v = init_empty_grad(
...@@ -365,6 +414,8 @@ class SkipGramModel(nn.Module): ...@@ -365,6 +414,8 @@ class SkipGramModel(nn.Module):
neg_score = torch.clamp(neg_score, max=6, min=-6) neg_score = torch.clamp(neg_score, max=6, min=-6)
# [batch_size * walk_length * negative, 1] # [batch_size * walk_length * negative, 1]
score = - self.fast_sigmoid(neg_score).unsqueeze(1) score = - self.fast_sigmoid(neg_score).unsqueeze(1)
if self.record_loss:
self.loss.append(self.negative * self.neg_weight * torch.mean(self.fast_logsigmoid(-neg_score)).item())
grad_u_neg = self.neg_weight * score * emb_neg_v grad_u_neg = self.neg_weight * score * emb_neg_v
grad_v_neg = self.neg_weight * score * emb_neg_u grad_v_neg = self.neg_weight * score * emb_neg_u
...@@ -375,40 +426,35 @@ class SkipGramModel(nn.Module): ...@@ -375,40 +426,35 @@ class SkipGramModel(nn.Module):
## Update ## Update
nodes = nodes.view(-1) nodes = nodes.view(-1)
if self.avg_sgd:
# since the times that a node are performed backward propagation are different, # use adam optimizer
# we need to average the gradients by different weight. grad_u = adam(grad_u, self.state_sum_u, nodes, lr, self.device, self.only_gpu)
# e.g. for sequence [1, 2, 3, ...] with window_size = 5, we have positive node grad_v = adam(grad_v, self.state_sum_v, nodes, lr, self.device, self.only_gpu)
# pairs [(1,2), (1, 3), (1,4), ...]. To average the gradients for each node, we if neg_nodes is not None:
# perform weighting on the gradients of node pairs. grad_v_neg = adam(grad_v_neg, self.state_sum_v, neg_nodes, lr, self.device, self.only_gpu)
# The weights are: [1/5, 1/5, ..., 1/6, ..., 1/10, ..., 1/6, ..., 1/5].
if bs < self.batch_size:
grad_avg = init_grad_avg(
self.walk_length,
self.window_size,
bs).to(self.device)
else:
grad_avg = self.grad_avg
grad_u = grad_avg * grad_u * lr
grad_v = grad_avg * grad_v * lr
elif self.sgd:
grad_u = grad_u * lr
grad_v = grad_v * lr
elif self.adam:
# use adam optimizer
grad_u = adam(grad_u, self.state_sum_u, nodes, lr, self.device, self.only_gpu)
grad_v = adam(grad_v, self.state_sum_v, nodes, lr, self.device, self.only_gpu)
if self.mixed_train: if self.mixed_train:
grad_u = grad_u.cpu() grad_u = grad_u.cpu()
grad_v = grad_v.cpu() grad_v = grad_v.cpu()
if neg_nodes is not None: if neg_nodes is not None:
grad_v_neg = grad_v_neg.cpu() grad_v_neg = grad_v_neg.cpu()
else:
grad_v_neg = None
if self.async_update:
grad_u.share_memory_()
grad_v.share_memory_()
nodes.share_memory_()
if neg_nodes is not None:
neg_nodes.share_memory_()
grad_v_neg.share_memory_()
self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes))
self.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u) if not self.async_update:
self.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v) self.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
if neg_nodes is not None: self.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
self.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), lr * grad_v_neg) if neg_nodes is not None:
self.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), grad_v_neg)
return return
def forward(self, pos_u, pos_v, neg_v): def forward(self, pos_u, pos_v, neg_v):
...@@ -429,7 +475,7 @@ class SkipGramModel(nn.Module): ...@@ -429,7 +475,7 @@ class SkipGramModel(nn.Module):
return torch.sum(score), torch.sum(neg_score) return torch.sum(score), torch.sum(neg_score)
def save_embedding(self, dataset, file_name): def save_embedding(self, dataset, file_name):
""" Write embedding to local file. """ Write embedding to local file. Only used when node ids are numbers.
Parameter Parameter
--------- ---------
...@@ -437,8 +483,41 @@ class SkipGramModel(nn.Module): ...@@ -437,8 +483,41 @@ class SkipGramModel(nn.Module):
file_name str : the file name file_name str : the file name
""" """
embedding = self.u_embeddings.weight.cpu().data.numpy() embedding = self.u_embeddings.weight.cpu().data.numpy()
if self.norm:
embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(-1, 1)
np.save(file_name, embedding) np.save(file_name, embedding)
def save_embedding_pt(self, dataset, file_name):
""" For ogb leaderboard.
"""
try:
max_node_id = max(dataset.node2id.keys())
if max_node_id + 1 != self.emb_size:
print("WARNING: The node ids are not serial.")
embedding = torch.zeros(max_node_id + 1, self.emb_dimension)
index = torch.LongTensor(list(map(lambda id: dataset.id2node[id], list(range(self.emb_size)))))
embedding.index_add_(0, index, self.u_embeddings.weight.cpu().data)
if self.norm:
embedding /= torch.sqrt(torch.sum(embedding.mul(embedding), 1) + 1e-6).unsqueeze(1)
torch.save(embedding, file_name)
except:
self.save_embedding_pt_dgl_graph(dataset, file_name)
def save_embedding_pt_dgl_graph(self, dataset, file_name):
""" For ogb leaderboard """
embedding = torch.zeros_like(self.u_embeddings.weight.cpu().data)
valid_seeds = torch.LongTensor(dataset.valid_seeds)
valid_embedding = self.u_embeddings.weight.cpu().data.index_select(0,
valid_seeds)
embedding.index_add_(0, valid_seeds, self.u_embeddings.weight.cpu().data)
if self.norm:
embedding /= torch.sqrt(torch.sum(embedding.mul(embedding), 1) + 1e-6).unsqueeze(1)
torch.save(embedding, file_name)
def save_embedding_txt(self, dataset, file_name): def save_embedding_txt(self, dataset, file_name):
""" Write embedding to local file. For future use. """ Write embedding to local file. For future use.
...@@ -448,8 +527,10 @@ class SkipGramModel(nn.Module): ...@@ -448,8 +527,10 @@ class SkipGramModel(nn.Module):
file_name str : the file name file_name str : the file name
""" """
embedding = self.u_embeddings.weight.cpu().data.numpy() embedding = self.u_embeddings.weight.cpu().data.numpy()
if self.norm:
embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(-1, 1)
with open(file_name, 'w') as f: with open(file_name, 'w') as f:
f.write('%d %d\n' % (self.emb_size, self.emb_dimension)) f.write('%d %d\n' % (self.emb_size, self.emb_dimension))
for wid in range(self.emb_size): for wid in range(self.emb_size):
e = ' '.join(map(lambda x: str(x), embedding[wid])) e = ' '.join(map(lambda x: str(x), embedding[wid]))
f.write('%s %s\n' % (str(dataset.id2node[wid]), e)) f.write('%s %s\n' % (str(dataset.id2node[wid]), e))
\ No newline at end of file
...@@ -8,8 +8,8 @@ from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_arc ...@@ -8,8 +8,8 @@ from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_arc
import random import random
import time import time
import dgl import dgl
from utils import shuffle_walks from utils import shuffle_walks
#np.random.seed(3141592653)
def ReadTxtNet(file_path="", undirected=True): def ReadTxtNet(file_path="", undirected=True):
""" Read the txt network file. """ Read the txt network file.
...@@ -41,10 +41,17 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -41,10 +41,17 @@ def ReadTxtNet(file_path="", undirected=True):
src = [] src = []
dst = [] dst = []
weight = []
net = {} net = {}
with open(file_path, "r") as f: with open(file_path, "r") as f:
for line in f.readlines(): for line in f.readlines():
n1, n2 = list(map(int, line.strip().split(" ")[:2])) tup = list(map(int, line.strip().split(" ")))
assert len(tup) in [2, 3], "The format of network file is unrecognizable."
if len(tup) == 3:
n1, n2, w = tup
elif len(tup) == 2:
n1, n2 = tup
w = 1
if n1 not in node2id: if n1 not in node2id:
node2id[n1] = cid node2id[n1] = cid
id2node[cid] = n1 id2node[cid] = n1
...@@ -57,30 +64,34 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -57,30 +64,34 @@ def ReadTxtNet(file_path="", undirected=True):
n1 = node2id[n1] n1 = node2id[n1]
n2 = node2id[n2] n2 = node2id[n2]
if n1 not in net: if n1 not in net:
net[n1] = {n2: 1} net[n1] = {n2: w}
src.append(n1) src.append(n1)
dst.append(n2) dst.append(n2)
weight.append(w)
elif n2 not in net[n1]: elif n2 not in net[n1]:
net[n1][n2] = 1 net[n1][n2] = w
src.append(n1) src.append(n1)
dst.append(n2) dst.append(n2)
weight.append(w)
if undirected: if undirected:
if n2 not in net: if n2 not in net:
net[n2] = {n1: 1} net[n2] = {n1: w}
src.append(n2) src.append(n2)
dst.append(n1) dst.append(n1)
weight.append(w)
elif n1 not in net[n2]: elif n1 not in net[n2]:
net[n2][n1] = 1 net[n2][n1] = w
src.append(n2) src.append(n2)
dst.append(n1) dst.append(n1)
weight.append(w)
print("node num: %d" % len(net)) print("node num: %d" % len(net))
print("edge num: %d" % len(src)) print("edge num: %d" % len(src))
assert max(net.keys()) == len(net) - 1, "error reading net, quit" assert max(net.keys()) == len(net) - 1, "error reading net, quit"
sm = sp.coo_matrix( sm = sp.coo_matrix(
(np.ones(len(src)), (src, dst)), (np.array(weight), (src, dst)),
dtype=np.float32) dtype=np.float32)
return net, node2id, id2node, sm return net, node2id, id2node, sm
...@@ -99,17 +110,31 @@ def net2graph(net_sm): ...@@ -99,17 +110,31 @@ def net2graph(net_sm):
print("Building DGLGraph in %.2fs" % t) print("Building DGLGraph in %.2fs" % t)
return G return G
def make_undirected(G):
G.readonly(False)
G.add_edges(G.edges()[1], G.edges()[0])
return G
def find_connected_nodes(G):
nodes = []
for n in G.nodes():
if G.out_degree(n) > 0:
nodes.append(n.item())
return nodes
class DeepwalkDataset: class DeepwalkDataset:
def __init__(self, def __init__(self,
net_file, net_file,
map_file, map_file,
walk_length=80, walk_length,
window_size=5, window_size,
num_walks=10, num_walks,
batch_size=32, batch_size,
negative=5, negative=5,
gpus=[0], gpus=[0],
fast_neg=True, fast_neg=True,
ogbl_name="",
load_from_ogbl=False,
): ):
""" This class has the following functions: """ This class has the following functions:
1. Transform the txt network file into DGL graph; 1. Transform the txt network file into DGL graph;
...@@ -134,51 +159,66 @@ class DeepwalkDataset: ...@@ -134,51 +159,66 @@ class DeepwalkDataset:
self.negative = negative self.negative = negative
self.num_procs = len(gpus) self.num_procs = len(gpus)
self.fast_neg = fast_neg self.fast_neg = fast_neg
self.net, self.node2id, self.id2node, self.sm = ReadTxtNet(net_file)
self.save_mapping(map_file) if load_from_ogbl:
self.G = net2graph(self.sm) assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training (CUDA error)."
from load_dataset import load_from_ogbl_with_name
self.G = load_from_ogbl_with_name(ogbl_name)
self.G = make_undirected(self.G)
else:
self.net, self.node2id, self.id2node, self.sm = ReadTxtNet(net_file)
self.save_mapping(map_file)
self.G = net2graph(self.sm)
self.num_nodes = self.G.number_of_nodes()
# random walk seeds # random walk seeds
start = time.time() start = time.time()
seeds = torch.cat([torch.LongTensor(self.G.nodes())] * num_walks) self.valid_seeds = find_connected_nodes(self.G)
self.seeds = torch.split(shuffle_walks(seeds), int(np.ceil(len(self.net) * self.num_walks / self.num_procs)), 0) if len(self.valid_seeds) != self.num_nodes:
print("WARNING: The node ids are not serial. Some nodes are invalid.")
seeds = torch.cat([torch.LongTensor(self.valid_seeds)] * num_walks)
self.seeds = torch.split(shuffle_walks(seeds),
int(np.ceil(len(self.valid_seeds) * self.num_walks / self.num_procs)),
0)
end = time.time() end = time.time()
t = end - start t = end - start
print("%d seeds in %.2fs" % (len(seeds), t)) print("%d seeds in %.2fs" % (len(seeds), t))
# negative table for true negative sampling # negative table for true negative sampling
if not fast_neg: if not fast_neg:
node_degree = np.array(list(map(lambda x: len(self.net[x]), self.net.keys()))) node_degree = np.array(list(map(lambda x: self.G.out_degree(x), self.valid_seeds)))
node_degree = np.power(node_degree, 0.75) node_degree = np.power(node_degree, 0.75)
node_degree /= np.sum(node_degree) node_degree /= np.sum(node_degree)
node_degree = np.array(node_degree * 1e8, dtype=np.int) node_degree = np.array(node_degree * 1e8, dtype=np.int)
self.neg_table = [] self.neg_table = []
for idx, node in enumerate(self.net.keys()):
for idx, node in enumerate(self.valid_seeds):
self.neg_table += [node] * node_degree[idx] self.neg_table += [node] * node_degree[idx]
self.neg_table_size = len(self.neg_table) self.neg_table_size = len(self.neg_table)
self.neg_table = np.array(self.neg_table, dtype=np.long) self.neg_table = np.array(self.neg_table, dtype=np.long)
del node_degree del node_degree
def create_sampler(self, gpu_id): def create_sampler(self, i):
""" Still in construction... """ create random walk sampler """
return DeepwalkSampler(self.G, self.seeds[i], self.walk_length)
Several mode:
1. do true negative sampling.
1.1 from random walk sequence
1.2 from node degree distribution
return the sampled node ids
2. do false negative sampling from random walk sequence
save GPU, faster
return the node indices in the sequences
"""
return DeepwalkSampler(self.G, self.seeds[gpu_id], self.walk_length)
def save_mapping(self, map_file): def save_mapping(self, map_file):
""" save the mapping dict that maps node IDs to embedding indices """
with open(map_file, "wb") as f: with open(map_file, "wb") as f:
pickle.dump(self.node2id, f) pickle.dump(self.node2id, f)
class DeepwalkSampler(object): class DeepwalkSampler(object):
def __init__(self, G, seeds, walk_length): def __init__(self, G, seeds, walk_length):
""" random walk sampler
Parameter
---------
G dgl.Graph : the input graph
seeds torch.LongTensor : starting nodes
walk_length int : walk length
"""
self.G = G self.G = G
self.seeds = seeds self.seeds = seeds
self.walk_length = walk_length self.walk_length = walk_length
......
...@@ -34,4 +34,32 @@ def thread_wrapped_func(func): ...@@ -34,4 +34,32 @@ def thread_wrapped_func(func):
def shuffle_walks(walks): def shuffle_walks(walks):
seeds = torch.randperm(walks.size()[0]) seeds = torch.randperm(walks.size()[0])
return walks[seeds] return walks[seeds]
\ No newline at end of file
def sum_up_params(model):
""" Count the model parameters """
n = []
n.append(model.u_embeddings.weight.cpu().data.numel() * 2)
n.append(model.lookup_table.cpu().numel())
n.append(model.index_emb_posu.cpu().numel() * 2)
n.append(model.grad_u.cpu().numel() * 2)
try:
n.append(model.index_emb_negu.cpu().numel() * 2)
except:
pass
try:
n.append(model.state_sum_u.cpu().numel() * 2)
except:
pass
try:
n.append(model.grad_avg.cpu().numel())
except:
pass
try:
n.append(model.context_weight.cpu().numel())
except:
pass
print("#params " + str(sum(n)))
exit()
\ No newline at end of file
@@ -15,7 +15,7 @@ For evaluation we follow the code mlp.py provided by ogb [here](https://github
## Used config
ogbl-collab
```
-python3 deepwalk.py --ogbl_name ogbl-collab --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --num_walks 50 --window_size 20 --walk_length 40 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.005 --mix --adam --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 32
+python3 deepwalk.py --ogbl_name ogbl-collab --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --num_walks 50 --window_size 20 --walk_length 40 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.005 --mix --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 32
cd ./ogb/blob/master/examples/linkproppred/collab/
cp embedding_pt_file_path ./
python3 mlp.py --device 0 --runs 10 --use_node_embedding
@@ -23,7 +23,7 @@ python3 mlp.py --device 0 --runs 10 --use_node_embedding
ogbl-ddi
```
-python3 deepwalk.py --ogbl_name ogbl-ddi --load_from_ogbl --save_in_pt --output_emb_file ddi-embedding.pt --num_walks 50 --window_size 2 --walk_length 80 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.05 --only_gpu --adam --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 16 --use_context_weight
+python3 deepwalk.py --ogbl_name ogbl-ddi --load_from_ogbl --save_in_pt --output_emb_file ddi-embedding.pt --num_walks 50 --window_size 2 --walk_length 80 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.05 --only_gpu --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 16 --use_context_weight
cd ./ogb/blob/master/examples/linkproppred/ddi/
cp embedding_pt_file_path ./
python3 mlp.py --device 0 --runs 10 --epochs 100
@@ -31,14 +31,14 @@ python3 mlp.py --device 0 --runs 10 --epochs 100
ogbl-ppa
```
-python3 deepwalk.py --ogbl_name ogbl-ppa --load_from_ogbl --save_in_pt --output_emb_file ppa-embedding.pt --negative 1 --neg_weight 1 --batch_size 64 --print_interval 2000 --print_loss --window_size 2 --num_walks 30 --walk_length 80 --lr 0.1 --lap_norm 0.02 --adam --mix --gpus 0 --use_context_weight --num_threads 4
+python3 deepwalk.py --ogbl_name ogbl-ppa --load_from_ogbl --save_in_pt --output_emb_file ppa-embedding.pt --negative 1 --neg_weight 1 --batch_size 64 --print_interval 2000 --print_loss --window_size 1 --num_walks 30 --walk_length 80 --lr 0.1 --lap_norm 0.02 --mix --gpus 0 --num_threads 4
cp embedding_pt_file_path ./
python3 mlp.py --device 2 --runs 10
```
ogbl-citation
```
-python3 deepwalk.py --ogbl_name ogbl-citation --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --window_size 2 --num_walks 10 --negative 1 --neg_weight 1 --walk_length 80 --batch_size 128 --print_loss --print_interval 1000 --mix --adam --gpus 0 --use_context_weight --num_threads 4 --lap_norm 0.05 --lr 0.1
+python3 deepwalk.py --ogbl_name ogbl-citation --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --window_size 2 --num_walks 10 --negative 1 --neg_weight 1 --walk_length 80 --batch_size 128 --print_loss --print_interval 1000 --mix --gpus 0 --use_context_weight --num_threads 4 --lap_norm 0.05 --lr 0.1
cp embedding_pt_file_path ./
python3 mlp.py --device 2 --runs 10 --use_node_embedding
```
@@ -64,45 +64,45 @@ ogbl-collab
<br>ogbl-ddi
<br>#params: 1444840(model) + 99073(mlp) = 1543913
-<br>&emsp;Hits@10
+<br>Hits@10
-<br>&emsp;Highest Train: 36.09 ± 2.47
+<br>&emsp;Highest Train: 33.91 ± 2.01
-<br>&emsp;Highest Valid: 32.83 ± 2.30
+<br>&emsp;Highest Valid: 30.96 ± 1.89
-<br>&emsp;&emsp;Final Train: 36.06 ± 2.45
+<br>&emsp;&emsp;Final Train: 33.90 ± 2.00
-<br>&emsp;&emsp;Final Test: 11.76 ± 3.91
+<br>&emsp;&emsp;Final Test: 15.16 ± 4.28
-<br>&emsp;Hits@20
+<br>Hits@20
-<br>&emsp;Highest Train: 45.59 ± 2.45
+<br>&emsp;Highest Train: 44.64 ± 1.71
-<br>&emsp;Highest Valid: 42.00 ± 2.36
+<br>&emsp;Highest Valid: 41.32 ± 1.69
-<br>&emsp;&emsp;Final Train: 45.56 ± 2.50
+<br>&emsp;&emsp;Final Train: 44.62 ± 1.69
-<br>&emsp;&emsp;Final Test: 22.46 ± 2.90
+<br>&emsp;&emsp;Final Test: 26.42 ± 6.10
-<br>&emsp;Hits@30
+<br>Hits@30
-<br>&emsp;Highest Train: 51.58 ± 2.41
+<br>&emsp;Highest Train: 51.01 ± 1.72
-<br>&emsp;Highest Valid: 47.82 ± 2.19
+<br>&emsp;Highest Valid: 47.64 ± 1.71
-<br>&emsp;&emsp;Final Train: 51.58 ± 2.42
+<br>&emsp;&emsp;Final Train: 50.99 ± 1.72
-<br>&emsp;&emsp;Final Test: 30.17 ± 3.39
+<br>&emsp;&emsp;Final Test: 33.56 ± 3.95
<br>ogbl-ppa
<br>#params: 150024820(model) + 113921(mlp) = 150138741
<br>Hits@10
-<br>&emsp;Highest Train: 3.58 ± 0.90
+<br>&emsp;Highest Train: 4.78 ± 0.73
-<br>&emsp;Highest Valid: 2.88 ± 0.76
+<br>&emsp;Highest Valid: 4.30 ± 0.68
-<br>&emsp;&emsp;Final Train: 3.58 ± 0.90
+<br>&emsp;&emsp;Final Train: 4.77 ± 0.73
-<br>&emsp;&emsp;Final Test: 1.45 ± 0.65
+<br>&emsp;&emsp;Final Test: 2.67 ± 0.42
-<br>&emsp;Hits@50
+<br>Hits@50
-<br>&emsp;Highest Train: 18.21 ± 2.29
+<br>&emsp;Highest Train: 18.82 ± 1.07
-<br>&emsp;Highest Valid: 15.75 ± 2.10
+<br>&emsp;Highest Valid: 17.26 ± 1.01
-<br>&emsp;&emsp;Final Train: 18.21 ± 2.29
+<br>&emsp;&emsp;Final Train: 18.82 ± 1.07
-<br>&emsp;&emsp;Final Test: 11.70 ± 0.97
+<br>&emsp;&emsp;Final Test: 17.34 ± 2.09
-<br>&emsp;Hits@100
+<br>Hits@100
-<br>&emsp;Highest Train: 31.16 ± 2.23
+<br>&emsp;Highest Train: 31.29 ± 2.11
-<br>&emsp;Highest Valid: 27.52 ± 2.07
+<br>&emsp;Highest Valid: 28.97 ± 1.92
-<br>&emsp;&emsp;Final Train: 31.16 ± 2.23
+<br>&emsp;&emsp;Final Train: 31.28 ± 2.12
-<br>&emsp;&emsp;Final Test: 23.02 ± 1.63
+<br>&emsp;&emsp;Final Test: 28.88 ± 1.53
<br>ogbl-citation
<br>#params: 757811178(model) + 131841(mlp) = 757943019
<br>MRR
-<br>&emsp;Highest Train: 0.8797 ± 0.0007
+<br>&emsp;Highest Train: 0.8994 ± 0.0004
-<br>&emsp;Highest Valid: 0.8139 ± 0.0005
+<br>&emsp;Highest Valid: 0.8271 ± 0.0003
-<br>&emsp;&emsp;Final Train: 0.8792 ± 0.0008
+<br>&emsp;&emsp;Final Train: 0.8991 ± 0.0007
-<br>&emsp;&emsp;Final Test: 0.8148 ± 0.0004
+<br>&emsp;&emsp;Final Test: 0.8284 ± 0.0005
...@@ -38,8 +38,6 @@ class DeepwalkTrainer: ...@@ -38,8 +38,6 @@ class DeepwalkTrainer:
""" """
choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix]) choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
assert choices == 1, "Must choose only *one* training mode in [only_cpu, only_gpu, mix]" assert choices == 1, "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"
choices = sum([self.args.sgd, self.args.adam, self.args.avg_sgd])
assert choices == 1, "Must choose only *one* gradient descent strategy in [sgd, avg_sgd, adam]"
# initializing embedding on CPU # initializing embedding on CPU
self.emb_model = SkipGramModel( self.emb_model = SkipGramModel(
...@@ -55,13 +53,12 @@ class DeepwalkTrainer: ...@@ -55,13 +53,12 @@ class DeepwalkTrainer:
negative=self.args.negative, negative=self.args.negative,
lr=self.args.lr, lr=self.args.lr,
lap_norm=self.args.lap_norm, lap_norm=self.args.lap_norm,
adam=self.args.adam,
sgd=self.args.sgd,
avg_sgd=self.args.avg_sgd,
fast_neg=self.args.fast_neg, fast_neg=self.args.fast_neg,
record_loss=self.args.print_loss, record_loss=self.args.print_loss,
norm=self.args.norm, norm=self.args.norm,
use_context_weight=self.args.use_context_weight, use_context_weight=self.args.use_context_weight,
async_update=self.args.async_update,
num_threads=self.args.num_threads,
) )
torch.set_num_threads(self.args.num_threads) torch.set_num_threads(self.args.num_threads)
...@@ -72,8 +69,8 @@ class DeepwalkTrainer: ...@@ -72,8 +69,8 @@ class DeepwalkTrainer:
elif self.args.mix: elif self.args.mix:
print("Mix CPU with %d GPU" % len(self.args.gpus)) print("Mix CPU with %d GPU" % len(self.args.gpus))
if len(self.args.gpus) == 1: if len(self.args.gpus) == 1:
assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have abaliable GPU' assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have available GPU'
#self.emb_model.set_device(self.args.gpus[0]) self.emb_model.set_device(self.args.gpus[0])
else: else:
print("Run in CPU process") print("Run in CPU process")
self.args.gpus = [torch.device('cpu')] self.args.gpus = [torch.device('cpu')]
...@@ -98,7 +95,7 @@ class DeepwalkTrainer: ...@@ -98,7 +95,7 @@ class DeepwalkTrainer:
ps = [] ps = []
for i in range(len(self.args.gpus)): for i in range(len(self.args.gpus)):
p = mp.Process(target=self.fast_train_sp, args=(self.args.gpus[i],)) p = mp.Process(target=self.fast_train_sp, args=(i, self.args.gpus[i]))
ps.append(p) ps.append(p)
p.start() p.start()
...@@ -114,13 +111,16 @@ class DeepwalkTrainer: ...@@ -114,13 +111,16 @@ class DeepwalkTrainer:
self.emb_model.save_embedding(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
@thread_wrapped_func @thread_wrapped_func
def fast_train_sp(self, gpu_id): def fast_train_sp(self, rank, gpu_id):
""" a subprocess for fast_train_mp """ """ a subprocess for fast_train_mp """
if self.args.mix: if self.args.mix:
self.emb_model.set_device(gpu_id) self.emb_model.set_device(gpu_id)
torch.set_num_threads(self.args.num_threads) torch.set_num_threads(self.args.num_threads)
if self.args.async_update:
self.emb_model.create_async_update()
sampler = self.dataset.create_sampler(gpu_id) sampler = self.dataset.create_sampler(rank)
dataloader = DataLoader( dataloader = DataLoader(
dataset=sampler.seeds, dataset=sampler.seeds,
...@@ -128,26 +128,19 @@ class DeepwalkTrainer: ...@@ -128,26 +128,19 @@ class DeepwalkTrainer:
collate_fn=sampler.sample, collate_fn=sampler.sample,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
num_workers=4, num_workers=self.args.num_sampler_threads,
) )
num_batches = len(dataloader) num_batches = len(dataloader)
print("num batchs: %d in subprocess [%d]" % (num_batches, gpu_id)) print("num batchs: %d in process [%d] GPU [%d]" % (num_batches, rank, gpu_id))
# number of positive node pairs in a sequence # number of positive node pairs in a sequence
num_pos = int(2 * self.args.walk_length * self.args.window_size\ num_pos = int(2 * self.args.walk_length * self.args.window_size\
- self.args.window_size * (self.args.window_size + 1)) - self.args.window_size * (self.args.window_size + 1))
start = time.time() start = time.time()
with torch.no_grad(): with torch.no_grad():
max_i = self.args.iterations * num_batches
for i, walks in enumerate(dataloader): for i, walks in enumerate(dataloader):
# decay learning rate for SGD
lr = self.args.lr * (max_i - i) / max_i
if lr < 0.00001:
lr = 0.00001
if self.args.fast_neg: if self.args.fast_neg:
self.emb_model.fast_learn(walks, lr) self.emb_model.fast_learn(walks)
else: else:
# do negative sampling # do negative sampling
bs = len(walks) bs = len(walks)
...@@ -155,19 +148,22 @@ class DeepwalkTrainer: ...@@ -155,19 +148,22 @@ class DeepwalkTrainer:
np.random.choice(self.dataset.neg_table, np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative, bs * num_pos * self.args.negative,
replace=True)) replace=True))
self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes) self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0: if i > 0 and i % self.args.print_interval == 0:
if self.args.print_loss: if self.args.print_loss:
print("Solver [%d] batch %d tt: %.2fs loss: %.4f" \ print("GPU-[%d] batch %d time: %.2fs loss: %.4f" \
% (gpu_id, i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval)) % (gpu_id, i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
self.emb_model.loss = [] self.emb_model.loss = []
else: else:
print("Solver [%d] batch %d tt: %.2fs" % (gpu_id, i, time.time()-start)) print("GPU-[%d] batch %d time: %.2fs" % (gpu_id, i, time.time()-start))
start = time.time() start = time.time()
if self.args.async_update:
self.emb_model.finish_async_update()
def fast_train(self): def fast_train(self):
""" fast train with dataloader """ """ fast train with dataloader with only gpu / only cpu"""
# the number of positive node pairs of a node sequence # the number of positive node pairs of a node sequence
num_pos = 2 * self.args.walk_length * self.args.window_size\ num_pos = 2 * self.args.walk_length * self.args.window_size\
- self.args.window_size * (self.args.window_size + 1) - self.args.window_size * (self.args.window_size + 1)
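The closed form used above for the number of positive pairs, 2 * walk_length * window_size - window_size * (window_size + 1), counts every ordered skip-gram pair within the context window of a single walk, with the truncation at both ends of the walk already subtracted. A minimal brute-force check (the helper `count_pairs` below is only illustrative, not part of the repository):
```
# Count ordered pairs (i, j), i != j, |i - j| <= window_size, in a walk of
# length walk_length, and compare with the closed form used by the trainer.
def count_pairs(walk_length, window_size):
    return sum(1
               for i in range(walk_length)
               for j in range(walk_length)
               if i != j and abs(i - j) <= window_size)

for L, w in [(80, 5), (40, 10), (20, 2)]:
    assert count_pairs(L, w) == 2 * L * w - w * (w + 1)
```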
...@@ -175,6 +171,10 @@ class DeepwalkTrainer: ...@@ -175,6 +171,10 @@ class DeepwalkTrainer:
self.init_device_emb() self.init_device_emb()
if self.args.async_update:
self.emb_model.share_memory()
self.emb_model.create_async_update()
if self.args.count_params: if self.args.count_params:
sum_up_params(self.emb_model) sum_up_params(self.emb_model)
...@@ -186,44 +186,39 @@ class DeepwalkTrainer: ...@@ -186,44 +186,39 @@ class DeepwalkTrainer:
collate_fn=sampler.sample, collate_fn=sampler.sample,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
num_workers=4, num_workers=self.args.num_sampler_threads,
) )
num_batches = len(dataloader) num_batches = len(dataloader)
print("num batchs: %d" % num_batches) print("num batchs: %d\n" % num_batches)
start_all = time.time() start_all = time.time()
start = time.time() start = time.time()
with torch.no_grad(): with torch.no_grad():
max_i = self.args.iterations * num_batches max_i = num_batches
for iteration in range(self.args.iterations): for i, walks in enumerate(dataloader):
print("\nIteration: " + str(iteration + 1)) if self.args.fast_neg:
self.emb_model.fast_learn(walks)
for i, walks in enumerate(dataloader): else:
# decay learning rate for SGD # do negative sampling
lr = self.args.lr * (max_i - i) / max_i bs = len(walks)
if lr < 0.00001: neg_nodes = torch.LongTensor(
lr = 0.00001 np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative,
replace=True))
self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if self.args.fast_neg: if i > 0 and i % self.args.print_interval == 0:
self.emb_model.fast_learn(walks, lr) if self.args.print_loss:
print("Batch %d training time: %.2fs loss: %.4f" \
% (i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
self.emb_model.loss = []
else: else:
# do negative sampling print("Batch %d, training time: %.2fs" % (i, time.time()-start))
bs = len(walks) start = time.time()
neg_nodes = torch.LongTensor(
np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative,
replace=True))
self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0: if self.args.async_update:
if self.args.print_loss: self.emb_model.finish_async_update()
print("Batch %d training time: %.2fs loss: %.4f" \
% (i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
self.emb_model.loss = []
else:
print("Batch %d, training time: %.2fs" % (i, time.time()-start))
start = time.time()
print("Training used time: %.2fs" % (time.time()-start_all)) print("Training used time: %.2fs" % (time.time()-start_all))
if self.args.save_in_txt: if self.args.save_in_txt:
...@@ -266,20 +261,18 @@ if __name__ == '__main__': ...@@ -266,20 +261,18 @@ if __name__ == '__main__':
help="whether to add weights over nodes in the context window") help="whether to add weights over nodes in the context window")
parser.add_argument('--num_walks', default=10, type=int, parser.add_argument('--num_walks', default=10, type=int,
help="number of walks for each node") help="number of walks for each node")
parser.add_argument('--negative', default=5, type=int, parser.add_argument('--negative', default=1, type=int,
help="negative samples for each positve node pair") help="negative samples for each positve node pair")
parser.add_argument('--batch_size', default=10, type=int, parser.add_argument('--batch_size', default=128, type=int,
help="number of node sequences in each batch") help="number of node sequences in each batch")
parser.add_argument('--walk_length', default=80, type=int, parser.add_argument('--walk_length', default=80, type=int,
help="number of nodes in a sequence") help="number of nodes in a sequence")
parser.add_argument('--neg_weight', default=1., type=float, parser.add_argument('--neg_weight', default=1., type=float,
help="negative weight") help="negative weight")
parser.add_argument('--lap_norm', default=0.01, type=float, parser.add_argument('--lap_norm', default=0.01, type=float,
help="weight of laplacian normalization") help="weight of laplacian normalization, recommend to set as 0.1 / windoe_size")
# training parameters # training parameters
parser.add_argument('--iterations', default=1, type=int,
help="iterations")
parser.add_argument('--print_interval', default=100, type=int, parser.add_argument('--print_interval', default=100, type=int,
help="number of batches between printing") help="number of batches between printing")
parser.add_argument('--print_loss', default=False, action="store_true", parser.add_argument('--print_loss', default=False, action="store_true",
...@@ -296,23 +289,24 @@ if __name__ == '__main__': ...@@ -296,23 +289,24 @@ if __name__ == '__main__':
help="training with CPU") help="training with CPU")
parser.add_argument('--only_gpu', default=False, action="store_true", parser.add_argument('--only_gpu', default=False, action="store_true",
help="training with GPU") help="training with GPU")
parser.add_argument('--async_update', default=False, action="store_true",
help="mixed training asynchronously, not recommended")
parser.add_argument('--adam', default=False, action="store_true",
help="use adam for embedding updation, recommended")
parser.add_argument('--sgd', default=False, action="store_true",
help="use sgd for embedding updation")
parser.add_argument('--avg_sgd', default=False, action="store_true",
help="average gradients of sgd for embedding updation")
parser.add_argument('--fast_neg', default=False, action="store_true", parser.add_argument('--fast_neg', default=False, action="store_true",
help="do negative sampling inside a batch") help="do negative sampling inside a batch")
parser.add_argument('--num_threads', default=2, type=int, parser.add_argument('--num_threads', default=8, type=int,
help="number of threads used for each CPU-core/GPU") help="number of threads used for each CPU-core/GPU")
parser.add_argument('--num_sampler_threads', default=2, type=int,
help="number of threads used for sampling")
parser.add_argument('--count_params', default=False, action="store_true", parser.add_argument('--count_params', default=False, action="store_true",
help="count the params, then exit") help="count the params, exit once counting over")
args = parser.parse_args() args = parser.parse_args()
if args.async_update:
assert args.mix, "--async_update only with --mix"
start_time = time.time() start_time = time.time()
trainer = DeepwalkTrainer(args) trainer = DeepwalkTrainer(args)
trainer.train() trainer.train()
......
...@@ -4,6 +4,10 @@ import torch.nn.functional as F ...@@ -4,6 +4,10 @@ import torch.nn.functional as F
from torch.nn import init from torch.nn import init
import random import random
import numpy as np import numpy as np
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from utils import thread_wrapped_func
def init_emb2pos_index(walk_length, window_size, batch_size): def init_emb2pos_index(walk_length, window_size, batch_size):
''' select embedding of positive nodes from a batch of node embeddings ''' select embedding of positive nodes from a batch of node embeddings
...@@ -72,25 +76,8 @@ def init_emb2neg_index(walk_length, window_size, negative, batch_size): ...@@ -72,25 +76,8 @@ def init_emb2neg_index(walk_length, window_size, negative, batch_size):
return index_emb_negu, index_emb_negv return index_emb_negu, index_emb_negv
def init_grad_avg(walk_length, window_size, batch_size):
''' averaging gradients by specific weights
'''
grad_avg = []
for b in range(batch_size):
for i in range(walk_length):
if i < window_size:
grad_avg.append(1. / float(i+window_size))
elif i >= walk_length - window_size:
grad_avg.append(1. / float(walk_length - i - 1 + window_size))
else:
grad_avg.append(0.5 / window_size)
# [num_pos * batch_size]
return torch.Tensor(grad_avg).unsqueeze(1)
def init_weight(walk_length, window_size, batch_size): def init_weight(walk_length, window_size, batch_size):
''' select nodes' gradients from gradient matrix ''' init context weight '''
'''
weight = [] weight = []
for b in range(batch_size): for b in range(batch_size):
for i in range(walk_length): for i in range(walk_length):
...@@ -123,6 +110,20 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu): ...@@ -123,6 +110,20 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu):
return grad return grad
@thread_wrapped_func
def async_update(num_threads, model, queue):
""" asynchronous embedding update """
torch.set_num_threads(num_threads)
while True:
(grad_u, grad_v, grad_v_neg, nodes, neg_nodes) = queue.get()
if grad_u is None:
return
with torch.no_grad():
model.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
model.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
if neg_nodes is not None:
model.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), grad_v_neg)
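The worker above drains gradient tuples from a queue and applies them with `index_add_`, which accumulates rows that share a node id; a `None` gradient acts as the shutdown sentinel (sent by `finish_async_update` further down). A tiny self-contained illustration of that accumulation behaviour, with toy sizes rather than the model's real shapes:
```
import torch

emb = torch.zeros(5, 3)                      # stand-in for an embedding matrix
nodes = torch.tensor([1, 3, 1])              # node 1 occurs twice in the batch
grads = torch.ones(3, 3)
emb.index_add_(0, nodes, grads)              # rows with duplicate ids are summed
print(emb[1])                                # tensor([2., 2., 2.])
print(emb[3])                                # tensor([1., 1., 1.])
```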
class SkipGramModel(nn.Module): class SkipGramModel(nn.Module):
""" Negative sampling based skip-gram """ """ Negative sampling based skip-gram """
def __init__(self, def __init__(self,
...@@ -138,13 +139,12 @@ class SkipGramModel(nn.Module): ...@@ -138,13 +139,12 @@ class SkipGramModel(nn.Module):
negative, negative,
lr, lr,
lap_norm, lap_norm,
adam,
sgd,
avg_sgd,
fast_neg, fast_neg,
record_loss, record_loss,
norm, norm,
use_context_weight, use_context_weight,
async_update,
num_threads,
): ):
""" initialize embedding on CPU """ initialize embedding on CPU
...@@ -162,10 +162,11 @@ class SkipGramModel(nn.Module): ...@@ -162,10 +162,11 @@ class SkipGramModel(nn.Module):
neg_weight float : negative weight neg_weight float : negative weight
lr float : initial learning rate lr float : initial learning rate
lap_norm float : weight of laplacian normalization lap_norm float : weight of laplacian normalization
adam bool : use adam for embedding updation
sgd bool : use sgd for embedding updation
avg_sgd bool : average gradients of sgd for embedding updation
fast_neg bool : do negative sampling inside a batch fast_neg bool : do negative sampling inside a batch
record_loss bool : print the loss during training
norm bool : do normalization on the embedding after training
use_context_weight : give different weights to the nodes in a context window
async_update : asynchronous training
""" """
super(SkipGramModel, self).__init__() super(SkipGramModel, self).__init__()
self.emb_size = emb_size self.emb_size = emb_size
...@@ -180,13 +181,12 @@ class SkipGramModel(nn.Module): ...@@ -180,13 +181,12 @@ class SkipGramModel(nn.Module):
self.negative = negative self.negative = negative
self.lr = lr self.lr = lr
self.lap_norm = lap_norm self.lap_norm = lap_norm
self.adam = adam
self.sgd = sgd
self.avg_sgd = avg_sgd
self.fast_neg = fast_neg self.fast_neg = fast_neg
self.record_loss = record_loss self.record_loss = record_loss
self.norm = norm self.norm = norm
self.use_context_weight = use_context_weight self.use_context_weight = use_context_weight
self.async_update = async_update
self.num_threads = num_threads
# initialize the device as cpu # initialize the device as cpu
self.device = torch.device("cpu") self.device = torch.device("cpu")
...@@ -215,12 +215,11 @@ class SkipGramModel(nn.Module): ...@@ -215,12 +215,11 @@ class SkipGramModel(nn.Module):
self.walk_length, self.walk_length,
self.window_size, self.window_size,
self.batch_size) self.batch_size)
if self.fast_neg: self.index_emb_negu, self.index_emb_negv = init_emb2neg_index(
self.index_emb_negu, self.index_emb_negv = init_emb2neg_index( self.walk_length,
self.walk_length, self.window_size,
self.window_size, self.negative,
self.negative, self.batch_size)
self.batch_size)
if self.use_context_weight: if self.use_context_weight:
self.context_weight = init_weight( self.context_weight = init_weight(
...@@ -228,16 +227,9 @@ class SkipGramModel(nn.Module): ...@@ -228,16 +227,9 @@ class SkipGramModel(nn.Module):
self.window_size, self.window_size,
self.batch_size) self.batch_size)
# coefficients for averaging the gradients
if self.avg_sgd:
self.grad_avg = init_grad_avg(
self.walk_length,
self.window_size,
self.batch_size)
# adam # adam
if self.adam: self.state_sum_u = torch.zeros(self.emb_size)
self.state_sum_u = torch.zeros(self.emb_size) self.state_sum_v = torch.zeros(self.emb_size)
self.state_sum_v = torch.zeros(self.emb_size)
# gradients of nodes in batch_walks # gradients of nodes in batch_walks
self.grad_u, self.grad_v = init_empty_grad( self.grad_u, self.grad_v = init_empty_grad(
...@@ -245,13 +237,25 @@ class SkipGramModel(nn.Module): ...@@ -245,13 +237,25 @@ class SkipGramModel(nn.Module):
self.walk_length, self.walk_length,
self.batch_size) self.batch_size)
def create_async_update(self):
""" Set up the async update subprocess.
"""
self.async_q = Queue(1)
self.async_p = mp.Process(target=async_update, args=(self.num_threads, self, self.async_q))
self.async_p.start()
def finish_async_update(self):
""" Notify the async update subprocess to quit.
"""
self.async_q.put((None, None, None, None, None))
self.async_p.join()
def share_memory(self): def share_memory(self):
""" share the parameters across subprocesses """ """ share the parameters across subprocesses """
self.u_embeddings.weight.share_memory_() self.u_embeddings.weight.share_memory_()
self.v_embeddings.weight.share_memory_() self.v_embeddings.weight.share_memory_()
if self.adam: self.state_sum_u.share_memory_()
self.state_sum_u.share_memory_() self.state_sum_v.share_memory_()
self.state_sum_v.share_memory_()
def set_device(self, gpu_id): def set_device(self, gpu_id):
""" set gpu device """ """ set gpu device """
...@@ -262,13 +266,10 @@ class SkipGramModel(nn.Module): ...@@ -262,13 +266,10 @@ class SkipGramModel(nn.Module):
self.logsigmoid_table = self.logsigmoid_table.to(self.device) self.logsigmoid_table = self.logsigmoid_table.to(self.device)
self.index_emb_posu = self.index_emb_posu.to(self.device) self.index_emb_posu = self.index_emb_posu.to(self.device)
self.index_emb_posv = self.index_emb_posv.to(self.device) self.index_emb_posv = self.index_emb_posv.to(self.device)
if self.fast_neg: self.index_emb_negu = self.index_emb_negu.to(self.device)
self.index_emb_negu = self.index_emb_negu.to(self.device) self.index_emb_negv = self.index_emb_negv.to(self.device)
self.index_emb_negv = self.index_emb_negv.to(self.device)
self.grad_u = self.grad_u.to(self.device) self.grad_u = self.grad_u.to(self.device)
self.grad_v = self.grad_v.to(self.device) self.grad_v = self.grad_v.to(self.device)
if self.avg_sgd:
self.grad_avg = self.grad_avg.to(self.device)
if self.use_context_weight: if self.use_context_weight:
self.context_weight = self.context_weight.to(self.device) self.context_weight = self.context_weight.to(self.device)
...@@ -278,9 +279,8 @@ class SkipGramModel(nn.Module): ...@@ -278,9 +279,8 @@ class SkipGramModel(nn.Module):
self.set_device(gpu_id) self.set_device(gpu_id)
self.u_embeddings = self.u_embeddings.cuda(gpu_id) self.u_embeddings = self.u_embeddings.cuda(gpu_id)
self.v_embeddings = self.v_embeddings.cuda(gpu_id) self.v_embeddings = self.v_embeddings.cuda(gpu_id)
if self.adam: self.state_sum_u = self.state_sum_u.to(self.device)
self.state_sum_u = self.state_sum_u.to(self.device) self.state_sum_v = self.state_sum_v.to(self.device)
self.state_sum_v = self.state_sum_v.to(self.device)
def fast_sigmoid(self, score): def fast_sigmoid(self, score):
""" do fast sigmoid by looking up in a pre-defined table """ """ do fast sigmoid by looking up in a pre-defined table """
...@@ -292,7 +292,7 @@ class SkipGramModel(nn.Module): ...@@ -292,7 +292,7 @@ class SkipGramModel(nn.Module):
idx = torch.floor((score + 6.01) / 0.01).long() idx = torch.floor((score + 6.01) / 0.01).long()
return self.logsigmoid_table[idx] return self.logsigmoid_table[idx]
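Given the indexing `floor((score + 6.01) / 0.01)`, the table presumably stores log-sigmoid values sampled every 0.01 over roughly [-6.01, 6]. The construction below is a hedged reconstruction for illustration only, not necessarily how the model builds `logsigmoid_table`:
```
import torch
import torch.nn.functional as F

# Assumed table layout: entry k holds log(sigmoid(x)) at x = -6.01 + 0.01 * k.
xs = torch.arange(-6.01, 6.0, 0.01)
logsigmoid_table = F.logsigmoid(xs)

score = torch.tensor([-2.5, 0.0, 3.3])
idx = torch.floor((score + 6.01) / 0.01).long()
print(logsigmoid_table[idx])                 # table lookup
print(F.logsigmoid(score))                   # exact values, within ~0.01
```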
def fast_learn(self, batch_walks, lr, neg_nodes=None): def fast_learn(self, batch_walks, neg_nodes=None):
""" Learn a batch of random walks in a fast way. It has the following features: """ Learn a batch of random walks in a fast way. It has the following features:
1. It calculates the gradients directly without the forward operation. 1. It calculates the gradients directly without the forward operation.
2. It does sigmoid with a look-up table. 2. It does sigmoid with a look-up table.
...@@ -317,8 +317,7 @@ class SkipGramModel(nn.Module): ...@@ -317,8 +317,7 @@ class SkipGramModel(nn.Module):
lr = 0.01 lr = 0.01
neg_nodes = None neg_nodes = None
""" """
if self.adam: lr = self.lr
lr = self.lr
# [batch_size, walk_length] # [batch_size, walk_length]
if isinstance(batch_walks, list): if isinstance(batch_walks, list):
...@@ -427,40 +426,35 @@ class SkipGramModel(nn.Module): ...@@ -427,40 +426,35 @@ class SkipGramModel(nn.Module):
## Update ## Update
nodes = nodes.view(-1) nodes = nodes.view(-1)
if self.avg_sgd:
# since the number of times a node undergoes backward propagation differs, # use adam optimizer
# we need to average the gradients with different weights. grad_u = adam(grad_u, self.state_sum_u, nodes, lr, self.device, self.only_gpu)
# e.g. for sequence [1, 2, 3, ...] with window_size = 5, we have positive node grad_v = adam(grad_v, self.state_sum_v, nodes, lr, self.device, self.only_gpu)
# pairs [(1,2), (1, 3), (1,4), ...]. To average the gradients for each node, we if neg_nodes is not None:
# perform weighting on the gradients of node pairs. grad_v_neg = adam(grad_v_neg, self.state_sum_v, neg_nodes, lr, self.device, self.only_gpu)
# The weights are: [1/5, 1/5, ..., 1/6, ..., 1/10, ..., 1/6, ..., 1/5].
if bs < self.batch_size:
grad_avg = init_grad_avg(
self.walk_length,
self.window_size,
bs).to(self.device)
else:
grad_avg = self.grad_avg
grad_u = grad_avg * grad_u * lr
grad_v = grad_avg * grad_v * lr
elif self.sgd:
grad_u = grad_u * lr
grad_v = grad_v * lr
elif self.adam:
# use adam optimizer
grad_u = adam(grad_u, self.state_sum_u, nodes, lr, self.device, self.only_gpu)
grad_v = adam(grad_v, self.state_sum_v, nodes, lr, self.device, self.only_gpu)
if self.mixed_train: if self.mixed_train:
grad_u = grad_u.cpu() grad_u = grad_u.cpu()
grad_v = grad_v.cpu() grad_v = grad_v.cpu()
if neg_nodes is not None: if neg_nodes is not None:
grad_v_neg = grad_v_neg.cpu() grad_v_neg = grad_v_neg.cpu()
else:
grad_v_neg = None
if self.async_update:
grad_u.share_memory_()
grad_v.share_memory_()
nodes.share_memory_()
if neg_nodes is not None:
neg_nodes.share_memory_()
grad_v_neg.share_memory_()
self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes))
self.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u) if not self.async_update:
self.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v) self.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
if neg_nodes is not None: self.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
self.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), lr * grad_v_neg) if neg_nodes is not None:
self.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), grad_v_neg)
return return
def forward(self, pos_u, pos_v, neg_v): def forward(self, pos_u, pos_v, neg_v):
...@@ -512,8 +506,7 @@ class SkipGramModel(nn.Module): ...@@ -512,8 +506,7 @@ class SkipGramModel(nn.Module):
self.save_embedding_pt_dgl_graph(dataset, file_name) self.save_embedding_pt_dgl_graph(dataset, file_name)
def save_embedding_pt_dgl_graph(self, dataset, file_name): def save_embedding_pt_dgl_graph(self, dataset, file_name):
""" For ogb leaderboard. """ For ogb leaderboard """
"""
embedding = torch.zeros_like(self.u_embeddings.weight.cpu().data) embedding = torch.zeros_like(self.u_embeddings.weight.cpu().data)
valid_seeds = torch.LongTensor(dataset.valid_seeds) valid_seeds = torch.LongTensor(dataset.valid_seeds)
valid_embedding = self.u_embeddings.weight.cpu().data.index_select(0, valid_embedding = self.u_embeddings.weight.cpu().data.index_select(0,
...@@ -540,4 +533,4 @@ class SkipGramModel(nn.Module): ...@@ -540,4 +533,4 @@ class SkipGramModel(nn.Module):
f.write('%d %d\n' % (self.emb_size, self.emb_dimension)) f.write('%d %d\n' % (self.emb_size, self.emb_dimension))
for wid in range(self.emb_size): for wid in range(self.emb_size):
e = ' '.join(map(lambda x: str(x), embedding[wid])) e = ' '.join(map(lambda x: str(x), embedding[wid]))
f.write('%s %s\n' % (str(dataset.id2node[wid]), e)) f.write('%s %s\n' % (str(dataset.id2node[wid]), e))
\ No newline at end of file
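The text writer above produces the standard word2vec text format: a header line `<num_nodes> <dim>` followed by one `<original_node_id> <floats...>` line per node. If gensim is installed, such a file can be read back as shown below; the file name and the node id used for the lookup are only examples.
```
from gensim.models import KeyedVectors

# Load the embedding written by --output_emb_file (word2vec text format).
emb = KeyedVectors.load_word2vec_format("emb.txt", binary=False)
print(emb.vector_size)        # embedding dimension
print(emb["0"][:5])           # first 5 dims of node "0", assuming that id exists
```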
...@@ -161,7 +161,7 @@ class DeepwalkDataset: ...@@ -161,7 +161,7 @@ class DeepwalkDataset:
self.fast_neg = fast_neg self.fast_neg = fast_neg
if load_from_ogbl: if load_from_ogbl:
assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training." assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training (CUDA error)."
from load_dataset import load_from_ogbl_with_name from load_dataset import load_from_ogbl_with_name
self.G = load_from_ogbl_with_name(ogbl_name) self.G = load_from_ogbl_with_name(ogbl_name)
self.G = make_undirected(self.G) self.G = make_undirected(self.G)
...@@ -200,26 +200,25 @@ class DeepwalkDataset: ...@@ -200,26 +200,25 @@ class DeepwalkDataset:
self.neg_table = np.array(self.neg_table, dtype=np.long) self.neg_table = np.array(self.neg_table, dtype=np.long)
del node_degree del node_degree
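The negative table above is filled from node degrees before `node_degree` is freed. The exact weighting is not visible in this hunk; a common choice, and an assumption here, is the word2vec-style unigram distribution raised to the 0.75 power, sketched below with a made-up helper name:
```
import numpy as np

def build_neg_table(node_degree, table_size=1_000_000):
    # Assumption: degree ** 0.75, normalized, then expanded into a flat table
    # so that np.random.choice over the table approximates that distribution.
    prob = node_degree ** 0.75
    prob = prob / prob.sum()
    counts = np.round(prob * table_size).astype(np.int64)
    return np.repeat(np.arange(len(node_degree)), counts)

table = build_neg_table(np.array([1.0, 10.0, 100.0]), table_size=1000)
neg = np.random.choice(table, size=5, replace=True)   # 5 negative node ids
```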
def create_sampler(self, gpu_id): def create_sampler(self, i):
""" Still in construction... """ create random walk sampler """
return DeepwalkSampler(self.G, self.seeds[i], self.walk_length)
Several modes:
1. do true negative sampling.
1.1 from random walk sequence
1.2 from node degree distribution
return the sampled node ids
2. do false negative sampling from random walk sequence
save GPU, faster
return the node indices in the sequences
"""
return DeepwalkSampler(self.G, self.seeds[gpu_id], self.walk_length)
def save_mapping(self, map_file): def save_mapping(self, map_file):
""" save the mapping dict that maps node IDs to embedding indices """
with open(map_file, "wb") as f: with open(map_file, "wb") as f:
pickle.dump(self.node2id, f) pickle.dump(self.node2id, f)
class DeepwalkSampler(object): class DeepwalkSampler(object):
def __init__(self, G, seeds, walk_length): def __init__(self, G, seeds, walk_length):
""" random walk sampler
Parameter
---------
G dgl.Graph : the input graph
seeds torch.LongTensor : starting nodes
walk_length int : walk length
"""
self.G = G self.G = G
self.seeds = seeds self.seeds = seeds
self.walk_length = walk_length self.walk_length = walk_length
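`DeepwalkSampler.sample` itself is outside this hunk; a plausible sketch of what it does with DGL's random-walk API is shown below on a toy graph. Treat it as an illustration of the API, not the repository's exact implementation.
```
import torch
import dgl

# Toy 4-node cycle graph and two starting seeds.
g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
seeds = torch.LongTensor([0, 2])
traces, _ = dgl.sampling.random_walk(g, seeds, length=5)
print(traces.shape)                          # (2, 6): each trace is the seed plus 5 steps
```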
......