Unverified Commit 0fb615b9 authored by Hao Xiong, committed by GitHub

[Example] Several updates on deepwalk (#1845)



* fix bug

* fix bugs

* add args

* add args

* fix bug

* fix a typo

* results

* results

* results
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 33abd275
......@@ -30,7 +30,7 @@ For other datasets please pass the full path to the trainer through --data_file
## How to run the code
To run the code:
```
python3 deepwalk.py --data_file youtube --output_emb_file emb.txt --adam --mix --lr 0.2 --gpus 0 1 2 3 --batch_size 100 --negative 5
python3 deepwalk.py --data_file youtube --output_emb_file emb.txt --mix --lr 0.2 --gpus 0 1 2 3 --batch_size 100 --negative 5
```
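After training, the embedding can be inspected directly; a minimal sketch, assuming the default npy output (written with `np.save` when neither `--save_in_txt` nor `--save_in_pt` is passed):
```
import numpy as np

# Load the trained node embeddings; shape is (num_nodes, dim), where dim defaults to 128.
emb = np.load("emb.npy")
print(emb.shape)
```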
## How to save the embedding
......
......@@ -10,7 +10,7 @@ import numpy as np
from reading_data import DeepwalkDataset
from model import SkipGramModel
from utils import thread_wrapped_func, shuffle_walks
from utils import thread_wrapped_func, shuffle_walks, sum_up_params
class DeepwalkTrainer:
def __init__(self, args):
......@@ -26,8 +26,10 @@ class DeepwalkTrainer:
negative=args.negative,
gpus=args.gpus,
fast_neg=args.fast_neg,
ogbl_name=args.ogbl_name,
load_from_ogbl=args.load_from_ogbl,
)
self.emb_size = len(self.dataset.net)
self.emb_size = self.dataset.G.number_of_nodes()
self.emb_model = None
def init_device_emb(self):
......@@ -36,8 +38,6 @@ class DeepwalkTrainer:
"""
choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
assert choices == 1, "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"
choices = sum([self.args.sgd, self.args.adam, self.args.avg_sgd])
assert choices == 1, "Must choose only *one* gradient descent strategy in [sgd, avg_sgd, adam]"
# initializing embedding on CPU
self.emb_model = SkipGramModel(
......@@ -53,10 +53,12 @@ class DeepwalkTrainer:
negative=self.args.negative,
lr=self.args.lr,
lap_norm=self.args.lap_norm,
adam=self.args.adam,
sgd=self.args.sgd,
avg_sgd=self.args.avg_sgd,
fast_neg=self.args.fast_neg,
record_loss=self.args.print_loss,
norm=self.args.norm,
use_context_weight=self.args.use_context_weight,
async_update=self.args.async_update,
num_threads=self.args.num_threads,
)
torch.set_num_threads(self.args.num_threads)
......@@ -67,7 +69,7 @@ class DeepwalkTrainer:
elif self.args.mix:
print("Mix CPU with %d GPU" % len(self.args.gpus))
if len(self.args.gpus) == 1:
assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have abaliable GPU'
assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have available GPU'
self.emb_model.set_device(self.args.gpus[0])
else:
print("Run in CPU process")
......@@ -86,11 +88,14 @@ class DeepwalkTrainer:
self.init_device_emb()
self.emb_model.share_memory()
if self.args.count_params:
sum_up_params(self.emb_model)
start_all = time.time()
ps = []
for i in range(len(self.args.gpus)):
p = mp.Process(target=self.fast_train_sp, args=(self.args.gpus[i],))
p = mp.Process(target=self.fast_train_sp, args=(i, self.args.gpus[i]))
ps.append(p)
p.start()
......@@ -100,17 +105,22 @@ class DeepwalkTrainer:
print("Used time: %.2fs" % (time.time()-start_all))
if self.args.save_in_txt:
self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file)
elif self.args.save_in_pt:
self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
else:
self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
@thread_wrapped_func
def fast_train_sp(self, gpu_id):
def fast_train_sp(self, rank, gpu_id):
""" a subprocess for fast_train_mp """
if self.args.mix:
self.emb_model.set_device(gpu_id)
torch.set_num_threads(self.args.num_threads)
if self.args.async_update:
self.emb_model.create_async_update()
sampler = self.dataset.create_sampler(gpu_id)
sampler = self.dataset.create_sampler(rank)
dataloader = DataLoader(
dataset=sampler.seeds,
......@@ -118,26 +128,19 @@ class DeepwalkTrainer:
collate_fn=sampler.sample,
shuffle=False,
drop_last=False,
num_workers=4,
num_workers=self.args.num_sampler_threads,
)
num_batches = len(dataloader)
print("num batchs: %d in subprocess [%d]" % (num_batches, gpu_id))
print("num batchs: %d in process [%d] GPU [%d]" % (num_batches, rank, gpu_id))
# number of positive node pairs in a sequence
num_pos = int(2 * self.args.walk_length * self.args.window_size\
- self.args.window_size * (self.args.window_size + 1))
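# Editorial note (not part of this commit): with the default walk_length=80 and window_size=5
# this gives 2*80*5 - 5*(5+1) = 770 positive pairs per walk; the subtracted term
# window_size*(window_size+1) accounts for the truncated context windows of the first and
# last window_size nodes in each walk.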
start = time.time()
with torch.no_grad():
max_i = self.args.iterations * num_batches
for i, walks in enumerate(dataloader):
# decay learning rate for SGD
lr = self.args.lr * (max_i - i) / max_i
if lr < 0.00001:
lr = 0.00001
if self.args.fast_neg:
self.emb_model.fast_learn(walks, lr)
self.emb_model.fast_learn(walks)
else:
# do negative sampling
bs = len(walks)
......@@ -145,14 +148,22 @@ class DeepwalkTrainer:
np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative,
replace=True))
self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes)
self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0:
print("Solver [%d] batch %d tt: %.2fs" % (gpu_id, i, time.time()-start))
if self.args.print_loss:
print("GPU-[%d] batch %d time: %.2fs loss: %.4f" \
% (gpu_id, i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
self.emb_model.loss = []
else:
print("GPU-[%d] batch %d time: %.2fs" % (gpu_id, i, time.time()-start))
start = time.time()
if self.args.async_update:
self.emb_model.finish_async_update()
def fast_train(self):
""" fast train with dataloader """
""" fast train with dataloader with only gpu / only cpu"""
# the number of positive node pairs of a node sequence
num_pos = 2 * self.args.walk_length * self.args.window_size\
- self.args.window_size * (self.args.window_size + 1)
......@@ -160,6 +171,13 @@ class DeepwalkTrainer:
self.init_device_emb()
if self.args.async_update:
self.emb_model.share_memory()
self.emb_model.create_async_update()
if self.args.count_params:
sum_up_params(self.emb_model)
sampler = self.dataset.create_sampler(0)
dataloader = DataLoader(
......@@ -168,27 +186,19 @@ class DeepwalkTrainer:
collate_fn=sampler.sample,
shuffle=False,
drop_last=False,
num_workers=4,
num_workers=self.args.num_sampler_threads,
)
num_batches = len(dataloader)
print("num batchs: %d" % num_batches)
print("num batchs: %d\n" % num_batches)
start_all = time.time()
start = time.time()
with torch.no_grad():
max_i = self.args.iterations * num_batches
for iteration in range(self.args.iterations):
print("\nIteration: " + str(iteration + 1))
max_i = num_batches
for i, walks in enumerate(dataloader):
# decay learning rate for SGD
lr = self.args.lr * (max_i - i) / max_i
if lr < 0.00001:
lr = 0.00001
if self.args.fast_neg:
self.emb_model.fast_learn(walks, lr)
self.emb_model.fast_learn(walks)
else:
# do negative sampling
bs = len(walks)
......@@ -196,70 +206,107 @@ class DeepwalkTrainer:
np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative,
replace=True))
self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes)
self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0:
if self.args.print_loss:
print("Batch %d training time: %.2fs loss: %.4f" \
% (i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
self.emb_model.loss = []
else:
print("Batch %d, training time: %.2fs" % (i, time.time()-start))
start = time.time()
if self.args.async_update:
self.emb_model.finish_async_update()
print("Training used time: %.2fs" % (time.time()-start_all))
if self.args.save_in_txt:
self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file)
elif self.args.save_in_pt:
self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
else:
self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="DeepWalk")
# input files
## personal datasets
parser.add_argument('--data_file', type=str,
help="path of the txt network file, builtin dataset include youtube-net and blog-net")
## ogbl datasets
parser.add_argument('--ogbl_name', type=str,
help="name of ogbl dataset, e.g. ogbl-ddi")
parser.add_argument('--load_from_ogbl', default=False, action="store_true",
help="whether load dataset from ogbl")
# output files
parser.add_argument('--save_in_txt', default=False, action="store_true",
help='whether to save the embedding in txt format instead of npy')
parser.add_argument('--save_in_pt', default=False, action="store_true",
help='whether to save the embedding in pt format instead of npy')
parser.add_argument('--output_emb_file', type=str, default="emb.npy",
help='path of the output npy embedding file')
parser.add_argument('--map_file', type=str, default="nodeid_to_index.pickle",
help='path of the mapping dict that maps node ids to embedding index')
parser.add_argument('--norm', default=False, action="store_true",
help="whether to do normalization over node embedding after training")
# model parameters
parser.add_argument('--dim', default=128, type=int,
help="embedding dimensions")
parser.add_argument('--window_size', default=5, type=int,
help="context window size")
parser.add_argument('--use_context_weight', default=False, action="store_true",
help="whether to add weights over nodes in the context window")
parser.add_argument('--num_walks', default=10, type=int,
help="number of walks for each node")
parser.add_argument('--negative', default=5, type=int,
parser.add_argument('--negative', default=1, type=int,
help="negative samples for each positve node pair")
parser.add_argument('--iterations', default=1, type=int,
help="iterations")
parser.add_argument('--batch_size', default=10, type=int,
parser.add_argument('--batch_size', default=128, type=int,
help="number of node sequences in each batch")
parser.add_argument('--print_interval', default=1000, type=int,
help="number of batches between printing")
parser.add_argument('--walk_length', default=80, type=int,
help="number of nodes in a sequence")
parser.add_argument('--lr', default=0.2, type=float,
help="learning rate")
parser.add_argument('--neg_weight', default=1., type=float,
help="negative weight")
parser.add_argument('--lap_norm', default=0.01, type=float,
help="weight of laplacian normalization")
help="weight of laplacian normalization, recommend to set as 0.1 / windoe_size")
# training parameters
parser.add_argument('--print_interval', default=100, type=int,
help="number of batches between printing")
parser.add_argument('--print_loss', default=False, action="store_true",
help="whether print loss during training")
parser.add_argument('--lr', default=0.2, type=float,
help="learning rate")
# optimization settings
parser.add_argument('--mix', default=False, action="store_true",
help="mixed training with CPU and GPU")
parser.add_argument('--gpus', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0, used with --mix')
parser.add_argument('--only_cpu', default=False, action="store_true",
help="training with CPU")
parser.add_argument('--only_gpu', default=False, action="store_true",
help="training with GPU")
parser.add_argument('--fast_neg', default=True, action="store_true",
parser.add_argument('--async_update', default=False, action="store_true",
help="mixed training asynchronously, not recommended")
parser.add_argument('--fast_neg', default=False, action="store_true",
help="do negative sampling inside a batch")
parser.add_argument('--adam', default=False, action="store_true",
help="use adam for embedding updation, recommended")
parser.add_argument('--sgd', default=False, action="store_true",
help="use sgd for embedding updation")
parser.add_argument('--avg_sgd', default=False, action="store_true",
help="average gradients of sgd for embedding updation")
parser.add_argument('--num_threads', default=2, type=int,
parser.add_argument('--num_threads', default=8, type=int,
help="number of threads used for each CPU-core/GPU")
parser.add_argument('--gpus', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0')
parser.add_argument('--num_sampler_threads', default=2, type=int,
help="number of threads used for sampling")
parser.add_argument('--count_params', default=False, action="store_true",
help="count the params, exit once counting over")
args = parser.parse_args()
if args.async_update:
assert args.mix, "--async_update only with --mix"
start_time = time.time()
trainer = DeepwalkTrainer(args)
trainer.train()
......
""" load dataset from ogb """
import argparse
from ogb.linkproppred import DglLinkPropPredDataset
def load_from_ogbl_with_name(name):
choices = ['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation']
assert name in choices, "name must be selected from " + str(choices)
dataset = DglLinkPropPredDataset(name)
return dataset[0]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str,
choices=['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation'],
default='ogbl-collab',
help="name of datasets by ogb")
args = parser.parse_args()
name = args.name
g = load_from_ogbl_with_name(name=name)
try:
w = g.edata['edge_weight']
weighted = True
except:
weighted = False
with open(name + "-net.txt", "w") as f:
for i in range(g.edges()[0].shape[0]):
if weighted:
f.write(str(g.edges()[0][i].item()) + " "\
+str(g.edges()[1][i].item()) + " "\
+str(g.edata['edge_weight'][i]) + "\n")
else:
f.write(str(g.edges()[0][i].item()) + " "\
+str(g.edges()[1][i].item()) + " "\
+"1\n")
\ No newline at end of file
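As a quick reference, a minimal sketch of using this loader from Python (assuming `ogb` is installed and the dataset name is one of the supported choices):
```
from load_dataset import load_from_ogbl_with_name

# Returns the DGLGraph held by the ogb link-prediction dataset.
g = load_from_ogbl_with_name("ogbl-collab")
print(g.number_of_nodes(), g.number_of_edges())

# Alternatively, running `python3 load_dataset.py --name ogbl-collab` dumps the edge list
# to ogbl-collab-net.txt in the "src dst weight" format accepted by --data_file.
```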
......@@ -4,6 +4,10 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from utils import thread_wrapped_func
def init_emb2pos_index(walk_length, window_size, batch_size):
''' select embedding of positive nodes from a batch of node embeddings
......@@ -72,25 +76,20 @@ def init_emb2neg_index(walk_length, window_size, negative, batch_size):
return index_emb_negu, index_emb_negv
def init_grad_avg(walk_length, window_size, batch_size):
'''select nodes' gradients from gradient matrix
Usage
-----
'''
grad_avg = []
def init_weight(walk_length, window_size, batch_size):
''' init context weight '''
weight = []
for b in range(batch_size):
for i in range(walk_length):
if i < window_size:
grad_avg.append(1. / float(i+window_size))
elif i >= walk_length - window_size:
grad_avg.append(1. / float(walk_length - i - 1 + window_size))
else:
grad_avg.append(0.5 / window_size)
for j in range(i-window_size, i):
if j >= 0:
weight.append(1. - float(i - j - 1)/float(window_size))
for j in range(i + 1, i + 1 + window_size):
if j < walk_length:
weight.append(1. - float(j - i - 1)/float(window_size))
# [num_pos * batch_size]
return torch.Tensor(grad_avg).unsqueeze(1)
return torch.Tensor(weight).unsqueeze(1)
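# Editorial note (not part of this commit): the weight decays linearly with the distance from the
# center node. With window_size=5, contexts at distances 1..5 receive weights 1.0, 0.8, 0.6, 0.4, 0.2
# on both sides; boundary nodes simply contribute fewer context entries.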
def init_empty_grad(emb_dimension, walk_length, batch_size):
""" initialize gradient matrix """
......@@ -111,6 +110,20 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu):
return grad
@thread_wrapped_func
def async_update(num_threads, model, queue):
""" asynchronous embedding update """
torch.set_num_threads(num_threads)
while True:
(grad_u, grad_v, grad_v_neg, nodes, neg_nodes) = queue.get()
if grad_u is None:
return
with torch.no_grad():
model.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
model.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
if neg_nodes is not None:
model.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), grad_v_neg)
class SkipGramModel(nn.Module):
""" Negative sampling based skip-gram """
def __init__(self,
......@@ -126,10 +139,12 @@ class SkipGramModel(nn.Module):
negative,
lr,
lap_norm,
adam,
sgd,
avg_sgd,
fast_neg,
record_loss,
norm,
use_context_weight,
async_update,
num_threads,
):
""" initialize embedding on CPU
......@@ -147,10 +162,11 @@ class SkipGramModel(nn.Module):
neg_weight float : negative weight
lr float : initial learning rate
lap_norm float : weight of laplacian normalization
adam bool : use adam for embedding updation
sgd bool : use sgd for embedding updation
avg_sgd bool : average gradients of sgd for embedding updation
fast_neg bool : do negative sampling inside a batch
record_loss bool : print the loss during training
norm bool : normalize the embedding after training
use_context_weight : give different weights to the nodes in a context window
async_update : asynchronous training
"""
super(SkipGramModel, self).__init__()
self.emb_size = emb_size
......@@ -165,10 +181,12 @@ class SkipGramModel(nn.Module):
self.negative = negative
self.lr = lr
self.lap_norm = lap_norm
self.adam = adam
self.sgd = sgd
self.avg_sgd = avg_sgd
self.fast_neg = fast_neg
self.record_loss = record_loss
self.norm = norm
self.use_context_weight = use_context_weight
self.async_update = async_update
self.num_threads = num_threads
# initialize the device as cpu
self.device = torch.device("cpu")
......@@ -188,27 +206,28 @@ class SkipGramModel(nn.Module):
self.lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))
self.lookup_table[0] = 0.
self.lookup_table[-1] = 1.
if self.record_loss:
self.logsigmoid_table = torch.log(torch.sigmoid(torch.arange(-6.01, 6.01, 0.01)))
self.loss = []
# indexes to select positive/negative node pairs from batch_walks
self.index_emb_posu, self.index_emb_posv = init_emb2pos_index(
self.walk_length,
self.window_size,
self.batch_size)
if self.fast_neg:
self.index_emb_negu, self.index_emb_negv = init_emb2neg_index(
self.walk_length,
self.window_size,
self.negative,
self.batch_size)
# coefficients for averaging the gradients
if self.avg_sgd:
self.grad_avg = init_grad_avg(
if self.use_context_weight:
self.context_weight = init_weight(
self.walk_length,
self.window_size,
self.batch_size)
# adam
if self.adam:
self.state_sum_u = torch.zeros(self.emb_size)
self.state_sum_v = torch.zeros(self.emb_size)
......@@ -218,11 +237,23 @@ class SkipGramModel(nn.Module):
self.walk_length,
self.batch_size)
def create_async_update(self):
""" Set up the async update subprocess.
"""
self.async_q = Queue(1)
self.async_p = mp.Process(target=async_update, args=(self.num_threads, self, self.async_q))
self.async_p.start()
def finish_async_update(self):
""" Notify the async update subprocess to quit.
"""
self.async_q.put((None, None, None, None, None))
self.async_p.join()
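# Editorial sketch (not part of this commit): the intended lifecycle when --async_update is set,
# assuming a SkipGramModel instance `m`:
#   m.share_memory()          # embeddings and optimizer state must be shared before forking
#   m.create_async_update()   # spawns the async_update subprocess that drains self.async_q
#   m.fast_learn(walks, ...)  # pushes (grad_u, grad_v, grad_v_neg, nodes, neg_nodes) onto the queue
#   m.finish_async_update()   # sends the all-None sentinel and joins the subprocess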
def share_memory(self):
""" share the parameters across subprocesses """
self.u_embeddings.weight.share_memory_()
self.v_embeddings.weight.share_memory_()
if self.adam:
self.state_sum_u.share_memory_()
self.state_sum_v.share_memory_()
......@@ -231,15 +262,16 @@ class SkipGramModel(nn.Module):
self.device = torch.device("cuda:%d" % gpu_id)
print("The device is", self.device)
self.lookup_table = self.lookup_table.to(self.device)
if self.record_loss:
self.logsigmoid_table = self.logsigmoid_table.to(self.device)
self.index_emb_posu = self.index_emb_posu.to(self.device)
self.index_emb_posv = self.index_emb_posv.to(self.device)
if self.fast_neg:
self.index_emb_negu = self.index_emb_negu.to(self.device)
self.index_emb_negv = self.index_emb_negv.to(self.device)
self.grad_u = self.grad_u.to(self.device)
self.grad_v = self.grad_v.to(self.device)
if self.avg_sgd:
self.grad_avg = self.grad_avg.to(self.device)
if self.use_context_weight:
self.context_weight = self.context_weight.to(self.device)
def all_to_device(self, gpu_id):
""" move all of the parameters to a single GPU """
......@@ -247,7 +279,6 @@ class SkipGramModel(nn.Module):
self.set_device(gpu_id)
self.u_embeddings = self.u_embeddings.cuda(gpu_id)
self.v_embeddings = self.v_embeddings.cuda(gpu_id)
if self.adam:
self.state_sum_u = self.state_sum_u.to(self.device)
self.state_sum_v = self.state_sum_v.to(self.device)
......@@ -256,7 +287,12 @@ class SkipGramModel(nn.Module):
idx = torch.floor((score + 6.01) / 0.01).long()
return self.lookup_table[idx]
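# Editorial note (not part of this commit): scores are clamped to [-6, 6] before lookup, so the
# index always falls inside the 1202-entry table built from torch.arange(-6.01, 6.01, 0.01);
# e.g. a score of 0 maps to index ~601 and a value of about sigmoid(0) = 0.5, while the forced
# endpoints table[0] = 0 and table[-1] = 1 cover the clamped extremes.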
def fast_learn(self, batch_walks, lr, neg_nodes=None):
def fast_logsigmoid(self, score):
""" do fast logsigmoid by looking up in a pre-defined table """
idx = torch.floor((score + 6.01) / 0.01).long()
return self.logsigmoid_table[idx]
def fast_learn(self, batch_walks, neg_nodes=None):
""" Learn a batch of random walks in a fast way. It has the following features:
1. It calculates the gradients directly without a forward pass.
2. It computes sigmoid via a lookup table.
......@@ -281,7 +317,6 @@ class SkipGramModel(nn.Module):
lr = 0.01
neg_nodes = None
"""
if self.adam:
lr = self.lr
# [batch_size, walk_length]
......@@ -318,6 +353,8 @@ class SkipGramModel(nn.Module):
pos_score = torch.clamp(pos_score, max=6, min=-6)
# [batch_size * num_pos, 1]
score = (1 - self.fast_sigmoid(pos_score)).unsqueeze(1)
if self.record_loss:
self.loss.append(torch.mean(self.fast_logsigmoid(pos_score)).item())
# [batch_size * num_pos, dim]
if self.lap_norm > 0:
......@@ -326,6 +363,18 @@ class SkipGramModel(nn.Module):
else:
grad_u_pos = score * emb_pos_v
grad_v_pos = score * emb_pos_u
if self.use_context_weight:
if bs < self.batch_size:
context_weight = init_weight(
self.walk_length,
self.window_size,
bs).to(self.device)
else:
context_weight = self.context_weight
grad_u_pos *= context_weight
grad_v_pos *= context_weight
# [batch_size * walk_length, dim]
if bs < self.batch_size:
grad_u, grad_v = init_empty_grad(
......@@ -365,6 +414,8 @@ class SkipGramModel(nn.Module):
neg_score = torch.clamp(neg_score, max=6, min=-6)
# [batch_size * walk_length * negative, 1]
score = - self.fast_sigmoid(neg_score).unsqueeze(1)
if self.record_loss:
self.loss.append(self.negative * self.neg_weight * torch.mean(self.fast_logsigmoid(-neg_score)).item())
grad_u_neg = self.neg_weight * score * emb_neg_v
grad_v_neg = self.neg_weight * score * emb_neg_u
......@@ -375,40 +426,35 @@ class SkipGramModel(nn.Module):
## Update
nodes = nodes.view(-1)
if self.avg_sgd:
# since the times that a node are performed backward propagation are different,
# we need to average the gradients by different weight.
# e.g. for sequence [1, 2, 3, ...] with window_size = 5, we have positive node
# pairs [(1,2), (1, 3), (1,4), ...]. To average the gradients for each node, we
# perform weighting on the gradients of node pairs.
# The weights are: [1/5, 1/5, ..., 1/6, ..., 1/10, ..., 1/6, ..., 1/5].
if bs < self.batch_size:
grad_avg = init_grad_avg(
self.walk_length,
self.window_size,
bs).to(self.device)
else:
grad_avg = self.grad_avg
grad_u = grad_avg * grad_u * lr
grad_v = grad_avg * grad_v * lr
elif self.sgd:
grad_u = grad_u * lr
grad_v = grad_v * lr
elif self.adam:
# use adam optimizer
grad_u = adam(grad_u, self.state_sum_u, nodes, lr, self.device, self.only_gpu)
grad_v = adam(grad_v, self.state_sum_v, nodes, lr, self.device, self.only_gpu)
if neg_nodes is not None:
grad_v_neg = adam(grad_v_neg, self.state_sum_v, neg_nodes, lr, self.device, self.only_gpu)
if self.mixed_train:
grad_u = grad_u.cpu()
grad_v = grad_v.cpu()
if neg_nodes is not None:
grad_v_neg = grad_v_neg.cpu()
else:
grad_v_neg = None
if self.async_update:
grad_u.share_memory_()
grad_v.share_memory_()
nodes.share_memory_()
if neg_nodes is not None:
neg_nodes.share_memory_()
grad_v_neg.share_memory_()
self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes))
if not self.async_update:
self.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
self.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
if neg_nodes is not None:
self.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), lr * grad_v_neg)
self.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), grad_v_neg)
return
def forward(self, pos_u, pos_v, neg_v):
......@@ -429,7 +475,7 @@ class SkipGramModel(nn.Module):
return torch.sum(score), torch.sum(neg_score)
def save_embedding(self, dataset, file_name):
""" Write embedding to local file.
""" Write embedding to local file. Only used when node ids are numbers.
Parameter
---------
......@@ -437,8 +483,41 @@ class SkipGramModel(nn.Module):
file_name str : the file name
"""
embedding = self.u_embeddings.weight.cpu().data.numpy()
if self.norm:
embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(-1, 1)
np.save(file_name, embedding)
def save_embedding_pt(self, dataset, file_name):
""" For ogb leaderboard.
"""
try:
max_node_id = max(dataset.node2id.keys())
if max_node_id + 1 != self.emb_size:
print("WARNING: The node ids are not serial.")
embedding = torch.zeros(max_node_id + 1, self.emb_dimension)
index = torch.LongTensor(list(map(lambda id: dataset.id2node[id], list(range(self.emb_size)))))
embedding.index_add_(0, index, self.u_embeddings.weight.cpu().data)
if self.norm:
embedding /= torch.sqrt(torch.sum(embedding.mul(embedding), 1) + 1e-6).unsqueeze(1)
torch.save(embedding, file_name)
except:
self.save_embedding_pt_dgl_graph(dataset, file_name)
def save_embedding_pt_dgl_graph(self, dataset, file_name):
""" For ogb leaderboard """
embedding = torch.zeros_like(self.u_embeddings.weight.cpu().data)
valid_seeds = torch.LongTensor(dataset.valid_seeds)
valid_embedding = self.u_embeddings.weight.cpu().data.index_select(0,
valid_seeds)
embedding.index_add_(0, valid_seeds, self.u_embeddings.weight.cpu().data)
if self.norm:
embedding /= torch.sqrt(torch.sum(embedding.mul(embedding), 1) + 1e-6).unsqueeze(1)
torch.save(embedding, file_name)
def save_embedding_txt(self, dataset, file_name):
""" Write embedding to local file. For future use.
......@@ -448,6 +527,8 @@ class SkipGramModel(nn.Module):
file_name str : the file name
"""
embedding = self.u_embeddings.weight.cpu().data.numpy()
if self.norm:
embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(-1, 1)
with open(file_name, 'w') as f:
f.write('%d %d\n' % (self.emb_size, self.emb_dimension))
for wid in range(self.emb_size):
......
......@@ -8,8 +8,8 @@ from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_arc
import random
import time
import dgl
from utils import shuffle_walks
#np.random.seed(3141592653)
def ReadTxtNet(file_path="", undirected=True):
""" Read the txt network file.
......@@ -41,10 +41,17 @@ def ReadTxtNet(file_path="", undirected=True):
src = []
dst = []
weight = []
net = {}
with open(file_path, "r") as f:
for line in f.readlines():
n1, n2 = list(map(int, line.strip().split(" ")[:2]))
tup = list(map(int, line.strip().split(" ")))
assert len(tup) in [2, 3], "The format of the network file is not recognized."
if len(tup) == 3:
n1, n2, w = tup
elif len(tup) == 2:
n1, n2 = tup
w = 1
if n1 not in node2id:
node2id[n1] = cid
id2node[cid] = n1
......@@ -57,30 +64,34 @@ def ReadTxtNet(file_path="", undirected=True):
n1 = node2id[n1]
n2 = node2id[n2]
if n1 not in net:
net[n1] = {n2: 1}
net[n1] = {n2: w}
src.append(n1)
dst.append(n2)
weight.append(w)
elif n2 not in net[n1]:
net[n1][n2] = 1
net[n1][n2] = w
src.append(n1)
dst.append(n2)
weight.append(w)
if undirected:
if n2 not in net:
net[n2] = {n1: 1}
net[n2] = {n1: w}
src.append(n2)
dst.append(n1)
weight.append(w)
elif n1 not in net[n2]:
net[n2][n1] = 1
net[n2][n1] = w
src.append(n2)
dst.append(n1)
weight.append(w)
print("node num: %d" % len(net))
print("edge num: %d" % len(src))
assert max(net.keys()) == len(net) - 1, "error reading net, quit"
sm = sp.coo_matrix(
(np.ones(len(src)), (src, dst)),
(np.array(weight), (src, dst)),
dtype=np.float32)
return net, node2id, id2node, sm
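# Editorial note (not part of this commit): ReadTxtNet expects one edge per line with integer node
# ids, either "src dst" or "src dst weight", e.g.
#   0 1
#   1 2 3
# Unweighted lines default to w = 1, and with undirected=True the reverse edge is added as well.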
......@@ -99,17 +110,31 @@ def net2graph(net_sm):
print("Building DGLGraph in %.2fs" % t)
return G
def make_undirected(G):
G.readonly(False)
G.add_edges(G.edges()[1], G.edges()[0])
return G
def find_connected_nodes(G):
nodes = []
for n in G.nodes():
if G.out_degree(n) > 0:
nodes.append(n.item())
return nodes
class DeepwalkDataset:
def __init__(self,
net_file,
map_file,
walk_length=80,
window_size=5,
num_walks=10,
batch_size=32,
walk_length,
window_size,
num_walks,
batch_size,
negative=5,
gpus=[0],
fast_neg=True,
ogbl_name="",
load_from_ogbl=False,
):
""" This class has the following functions:
1. Transform the txt network file into DGL graph;
......@@ -134,51 +159,66 @@ class DeepwalkDataset:
self.negative = negative
self.num_procs = len(gpus)
self.fast_neg = fast_neg
if load_from_ogbl:
assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training (CUDA error)."
from load_dataset import load_from_ogbl_with_name
self.G = load_from_ogbl_with_name(ogbl_name)
self.G = make_undirected(self.G)
else:
self.net, self.node2id, self.id2node, self.sm = ReadTxtNet(net_file)
self.save_mapping(map_file)
self.G = net2graph(self.sm)
self.num_nodes = self.G.number_of_nodes()
# random walk seeds
start = time.time()
seeds = torch.cat([torch.LongTensor(self.G.nodes())] * num_walks)
self.seeds = torch.split(shuffle_walks(seeds), int(np.ceil(len(self.net) * self.num_walks / self.num_procs)), 0)
self.valid_seeds = find_connected_nodes(self.G)
if len(self.valid_seeds) != self.num_nodes:
print("WARNING: The node ids are not serial. Some nodes are invalid.")
seeds = torch.cat([torch.LongTensor(self.valid_seeds)] * num_walks)
self.seeds = torch.split(shuffle_walks(seeds),
int(np.ceil(len(self.valid_seeds) * self.num_walks / self.num_procs)),
0)
end = time.time()
t = end - start
print("%d seeds in %.2fs" % (len(seeds), t))
# negative table for true negative sampling
if not fast_neg:
node_degree = np.array(list(map(lambda x: len(self.net[x]), self.net.keys())))
node_degree = np.array(list(map(lambda x: self.G.out_degree(x), self.valid_seeds)))
node_degree = np.power(node_degree, 0.75)
node_degree /= np.sum(node_degree)
node_degree = np.array(node_degree * 1e8, dtype=np.int)
self.neg_table = []
for idx, node in enumerate(self.net.keys()):
for idx, node in enumerate(self.valid_seeds):
self.neg_table += [node] * node_degree[idx]
self.neg_table_size = len(self.neg_table)
self.neg_table = np.array(self.neg_table, dtype=np.long)
del node_degree
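# Editorial note (not part of this commit): this materializes the word2vec-style unigram table,
# with sampling probability proportional to out_degree**0.75, as roughly 1e8 integer entries, so
# that a uniform np.random.choice over neg_table approximates that distribution.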
def create_sampler(self, gpu_id):
""" Still in construction...
Several mode:
1. do true negative sampling.
1.1 from random walk sequence
1.2 from node degree distribution
return the sampled node ids
2. do false negative sampling from random walk sequence
save GPU, faster
return the node indices in the sequences
"""
return DeepwalkSampler(self.G, self.seeds[gpu_id], self.walk_length)
def create_sampler(self, i):
""" create random walk sampler """
return DeepwalkSampler(self.G, self.seeds[i], self.walk_length)
def save_mapping(self, map_file):
""" save the mapping dict that maps node IDs to embedding indices """
with open(map_file, "wb") as f:
pickle.dump(self.node2id, f)
class DeepwalkSampler(object):
def __init__(self, G, seeds, walk_length):
""" random walk sampler
Parameter
---------
G dgl.Graph : the input graph
seeds torch.LongTensor : starting nodes
walk_length int : walk length
"""
self.G = G
self.seeds = seeds
self.walk_length = walk_length
......
......@@ -35,3 +35,31 @@ def thread_wrapped_func(func):
def shuffle_walks(walks):
seeds = torch.randperm(walks.size()[0])
return walks[seeds]
def sum_up_params(model):
""" Count the model parameters """
n = []
n.append(model.u_embeddings.weight.cpu().data.numel() * 2)
n.append(model.lookup_table.cpu().numel())
n.append(model.index_emb_posu.cpu().numel() * 2)
n.append(model.grad_u.cpu().numel() * 2)
try:
n.append(model.index_emb_negu.cpu().numel() * 2)
except:
pass
try:
n.append(model.state_sum_u.cpu().numel() * 2)
except:
pass
try:
n.append(model.grad_avg.cpu().numel())
except:
pass
try:
n.append(model.context_weight.cpu().numel())
except:
pass
print("#params " + str(sum(n)))
exit()
\ No newline at end of file
......@@ -15,7 +15,7 @@ For evaluation we follow the mlp.py code provided by ogb [here](https://github
## Used config
ogbl-collab
```
python3 deepwalk.py --ogbl_name ogbl-collab --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --num_walks 50 --window_size 20 --walk_length 40 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.005 --mix --adam --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 32
python3 deepwalk.py --ogbl_name ogbl-collab --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --num_walks 50 --window_size 20 --walk_length 40 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.005 --mix --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 32
cd ./ogb/blob/master/examples/linkproppred/collab/
cp embedding_pt_file_path ./
python3 mlp.py --device 0 --runs 10 --use_node_embedding
......@@ -23,7 +23,7 @@ python3 mlp.py --device 0 --runs 10 --use_node_embedding
ogbl-ddi
```
python3 deepwalk.py --ogbl_name ogbl-ddi --load_from_ogbl --save_in_pt --output_emb_file ddi-embedding.pt --num_walks 50 --window_size 2 --walk_length 80 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.05 --only_gpu --adam --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 16 --use_context_weight
python3 deepwalk.py --ogbl_name ogbl-ddi --load_from_ogbl --save_in_pt --output_emb_file ddi-embedding.pt --num_walks 50 --window_size 2 --walk_length 80 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.05 --only_gpu --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 16 --use_context_weight
cd ./ogb/blob/master/examples/linkproppred/ddi/
cp embedding_pt_file_path ./
python3 mlp.py --device 0 --runs 10 --epochs 100
......@@ -31,14 +31,14 @@ python3 mlp.py --device 0 --runs 10 --epochs 100
ogbl-ppa
```
python3 deepwalk.py --ogbl_name ogbl-ppa --load_from_ogbl --save_in_pt --output_emb_file ppa-embedding.pt --negative 1 --neg_weight 1 --batch_size 64 --print_interval 2000 --print_loss --window_size 2 --num_walks 30 --walk_length 80 --lr 0.1 --lap_norm 0.02 --adam --mix --gpus 0 --use_context_weight --num_threads 4
python3 deepwalk.py --ogbl_name ogbl-ppa --load_from_ogbl --save_in_pt --output_emb_file ppa-embedding.pt --negative 1 --neg_weight 1 --batch_size 64 --print_interval 2000 --print_loss --window_size 1 --num_walks 30 --walk_length 80 --lr 0.1 --lap_norm 0.02 --mix --gpus 0 --num_threads 4
cp embedding_pt_file_path ./
python3 mlp.py --device 2 --runs 10
```
ogbl-citation
```
python3 deepwalk.py --ogbl_name ogbl-citation --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --window_size 2 --num_walks 10 --negative 1 --neg_weight 1 --walk_length 80 --batch_size 128 --print_loss --print_interval 1000 --mix --adam --gpus 0 --use_context_weight --num_threads 4 --lap_norm 0.05 --lr 0.1
python3 deepwalk.py --ogbl_name ogbl-citation --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --window_size 2 --num_walks 10 --negative 1 --neg_weight 1 --walk_length 80 --batch_size 128 --print_loss --print_interval 1000 --mix --gpus 0 --use_context_weight --num_threads 4 --lap_norm 0.05 --lr 0.1
cp embedding_pt_file_path ./
python3 mlp.py --device 2 --runs 10 --use_node_embedding
```
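The `--save_in_pt` output used above is a plain tensor written with `torch.save`; a minimal sketch of inspecting it before handing it to `mlp.py` (file name as in the commands above):
```
import torch

# Shape is (num_nodes, dim); rows correspond to the original ogb node ids.
emb = torch.load("embedding.pt")
print(emb.shape)
```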
......@@ -64,45 +64,45 @@ ogbl-collab
<br>ogbl-ddi
<br>#params: 1444840(model) + 99073(mlp) = 1543913
<br>&emsp;Hits@10
<br>&emsp;Highest Train: 36.09 ± 2.47
<br>&emsp;Highest Valid: 32.83 ± 2.30
<br>&emsp;&emsp;Final Train: 36.06 ± 2.45
<br>&emsp;&emsp;Final Test: 11.76 ± 3.91
<br>&emsp;Hits@20
<br>&emsp;Highest Train: 45.59 ± 2.45
<br>&emsp;Highest Valid: 42.00 ± 2.36
<br>&emsp;&emsp;Final Train: 45.56 ± 2.50
<br>&emsp;&emsp;Final Test: 22.46 ± 2.90
<br>&emsp;Hits@30
<br>&emsp;Highest Train: 51.58 ± 2.41
<br>&emsp;Highest Valid: 47.82 ± 2.19
<br>&emsp;&emsp;Final Train: 51.58 ± 2.42
<br>&emsp;&emsp;Final Test: 30.17 ± 3.39
<br>Hits@10
<br>&emsp;Highest Train: 33.91 ± 2.01
<br>&emsp;Highest Valid: 30.96 ± 1.89
<br>&emsp;&emsp;Final Train: 33.90 ± 2.00
<br>&emsp;&emsp;Final Test: 15.16 ± 4.28
<br>Hits@20
<br>&emsp;Highest Train: 44.64 ± 1.71
<br>&emsp;Highest Valid: 41.32 ± 1.69
<br>&emsp;&emsp;Final Train: 44.62 ± 1.69
<br>&emsp;&emsp;Final Test: 26.42 ± 6.10
<br>Hits@30
<br>&emsp;Highest Train: 51.01 ± 1.72
<br>&emsp;Highest Valid: 47.64 ± 1.71
<br>&emsp;&emsp;Final Train: 50.99 ± 1.72
<br>&emsp;&emsp;Final Test: 33.56 ± 3.95
<br>ogbl-ppa
<br>#params: 150024820(model) + 113921(mlp) = 150138741
<br>Hits@10
<br>&emsp;Highest Train: 3.58 ± 0.90
<br>&emsp;Highest Valid: 2.88 ± 0.76
<br>&emsp;&emsp;Final Train: 3.58 ± 0.90
<br>&emsp;&emsp;Final Test: 1.45 ± 0.65
<br>&emsp;Hits@50
<br>&emsp;Highest Train: 18.21 ± 2.29
<br>&emsp;Highest Valid: 15.75 ± 2.10
<br>&emsp;&emsp;Final Train: 18.21 ± 2.29
<br>&emsp;&emsp;Final Test: 11.70 ± 0.97
<br>&emsp;Hits@100
<br>&emsp;Highest Train: 31.16 ± 2.23
<br>&emsp;Highest Valid: 27.52 ± 2.07
<br>&emsp;&emsp;Final Train: 31.16 ± 2.23
<br>&emsp;&emsp;Final Test: 23.02 ± 1.63
<br>&emsp;Highest Train: 4.78 ± 0.73
<br>&emsp;Highest Valid: 4.30 ± 0.68
<br>&emsp;&emsp;Final Train: 4.77 ± 0.73
<br>&emsp;&emsp;Final Test: 2.67 ± 0.42
<br>Hits@50
<br>&emsp;Highest Train: 18.82 ± 1.07
<br>&emsp;Highest Valid: 17.26 ± 1.01
<br>&emsp;&emsp;Final Train: 18.82 ± 1.07
<br>&emsp;&emsp;Final Test: 17.34 ± 2.09
<br>Hits@100
<br>&emsp;Highest Train: 31.29 ± 2.11
<br>&emsp;Highest Valid: 28.97 ± 1.92
<br>&emsp;&emsp;Final Train: 31.28 ± 2.12
<br>&emsp;&emsp;Final Test: 28.88 ± 1.53
<br>ogbl-citation
<br>#params: 757811178(model) + 131841(mlp) = 757943019
<br>MRR
<br>&emsp;Highest Train: 0.8797 ± 0.0007
<br>&emsp;Highest Valid: 0.8139 ± 0.0005
<br>&emsp;&emsp;Final Train: 0.8792 ± 0.0008
<br>&emsp;&emsp;Final Test: 0.8148 ± 0.0004
<br>&emsp;Highest Train: 0.8994 ± 0.0004
<br>&emsp;Highest Valid: 0.8271 ± 0.0003
<br>&emsp;&emsp;Final Train: 0.8991 ± 0.0007
<br>&emsp;&emsp;Final Test: 0.8284 ± 0.0005
......@@ -38,8 +38,6 @@ class DeepwalkTrainer:
"""
choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
assert choices == 1, "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"
choices = sum([self.args.sgd, self.args.adam, self.args.avg_sgd])
assert choices == 1, "Must choose only *one* gradient descent strategy in [sgd, avg_sgd, adam]"
# initializing embedding on CPU
self.emb_model = SkipGramModel(
......@@ -55,13 +53,12 @@ class DeepwalkTrainer:
negative=self.args.negative,
lr=self.args.lr,
lap_norm=self.args.lap_norm,
adam=self.args.adam,
sgd=self.args.sgd,
avg_sgd=self.args.avg_sgd,
fast_neg=self.args.fast_neg,
record_loss=self.args.print_loss,
norm=self.args.norm,
use_context_weight=self.args.use_context_weight,
async_update=self.args.async_update,
num_threads=self.args.num_threads,
)
torch.set_num_threads(self.args.num_threads)
......@@ -72,8 +69,8 @@ class DeepwalkTrainer:
elif self.args.mix:
print("Mix CPU with %d GPU" % len(self.args.gpus))
if len(self.args.gpus) == 1:
assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have abaliable GPU'
#self.emb_model.set_device(self.args.gpus[0])
assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have available GPU'
self.emb_model.set_device(self.args.gpus[0])
else:
print("Run in CPU process")
self.args.gpus = [torch.device('cpu')]
......@@ -98,7 +95,7 @@ class DeepwalkTrainer:
ps = []
for i in range(len(self.args.gpus)):
p = mp.Process(target=self.fast_train_sp, args=(self.args.gpus[i],))
p = mp.Process(target=self.fast_train_sp, args=(i, self.args.gpus[i]))
ps.append(p)
p.start()
......@@ -114,13 +111,16 @@ class DeepwalkTrainer:
self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
@thread_wrapped_func
def fast_train_sp(self, gpu_id):
def fast_train_sp(self, rank, gpu_id):
""" a subprocess for fast_train_mp """
if self.args.mix:
self.emb_model.set_device(gpu_id)
torch.set_num_threads(self.args.num_threads)
if self.args.async_update:
self.emb_model.create_async_update()
sampler = self.dataset.create_sampler(gpu_id)
sampler = self.dataset.create_sampler(rank)
dataloader = DataLoader(
dataset=sampler.seeds,
......@@ -128,26 +128,19 @@ class DeepwalkTrainer:
collate_fn=sampler.sample,
shuffle=False,
drop_last=False,
num_workers=4,
num_workers=self.args.num_sampler_threads,
)
num_batches = len(dataloader)
print("num batchs: %d in subprocess [%d]" % (num_batches, gpu_id))
print("num batchs: %d in process [%d] GPU [%d]" % (num_batches, rank, gpu_id))
# number of positive node pairs in a sequence
num_pos = int(2 * self.args.walk_length * self.args.window_size\
- self.args.window_size * (self.args.window_size + 1))
start = time.time()
with torch.no_grad():
max_i = self.args.iterations * num_batches
for i, walks in enumerate(dataloader):
# decay learning rate for SGD
lr = self.args.lr * (max_i - i) / max_i
if lr < 0.00001:
lr = 0.00001
if self.args.fast_neg:
self.emb_model.fast_learn(walks, lr)
self.emb_model.fast_learn(walks)
else:
# do negative sampling
bs = len(walks)
......@@ -155,19 +148,22 @@ class DeepwalkTrainer:
np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative,
replace=True))
self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes)
self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0:
if self.args.print_loss:
print("Solver [%d] batch %d tt: %.2fs loss: %.4f" \
print("GPU-[%d] batch %d time: %.2fs loss: %.4f" \
% (gpu_id, i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
self.emb_model.loss = []
else:
print("Solver [%d] batch %d tt: %.2fs" % (gpu_id, i, time.time()-start))
print("GPU-[%d] batch %d time: %.2fs" % (gpu_id, i, time.time()-start))
start = time.time()
if self.args.async_update:
self.emb_model.finish_async_update()
def fast_train(self):
""" fast train with dataloader """
""" fast train with dataloader with only gpu / only cpu"""
# the number of positive node pairs of a node sequence
num_pos = 2 * self.args.walk_length * self.args.window_size\
- self.args.window_size * (self.args.window_size + 1)
......@@ -175,6 +171,10 @@ class DeepwalkTrainer:
self.init_device_emb()
if self.args.async_update:
self.emb_model.share_memory()
self.emb_model.create_async_update()
if self.args.count_params:
sum_up_params(self.emb_model)
......@@ -186,27 +186,19 @@ class DeepwalkTrainer:
collate_fn=sampler.sample,
shuffle=False,
drop_last=False,
num_workers=4,
num_workers=self.args.num_sampler_threads,
)
num_batches = len(dataloader)
print("num batchs: %d" % num_batches)
print("num batchs: %d\n" % num_batches)
start_all = time.time()
start = time.time()
with torch.no_grad():
max_i = self.args.iterations * num_batches
for iteration in range(self.args.iterations):
print("\nIteration: " + str(iteration + 1))
max_i = num_batches
for i, walks in enumerate(dataloader):
# decay learning rate for SGD
lr = self.args.lr * (max_i - i) / max_i
if lr < 0.00001:
lr = 0.00001
if self.args.fast_neg:
self.emb_model.fast_learn(walks, lr)
self.emb_model.fast_learn(walks)
else:
# do negative sampling
bs = len(walks)
......@@ -214,7 +206,7 @@ class DeepwalkTrainer:
np.random.choice(self.dataset.neg_table,
bs * num_pos * self.args.negative,
replace=True))
self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes)
self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0:
if self.args.print_loss:
......@@ -225,6 +217,9 @@ class DeepwalkTrainer:
print("Batch %d, training time: %.2fs" % (i, time.time()-start))
start = time.time()
if self.args.async_update:
self.emb_model.finish_async_update()
print("Training used time: %.2fs" % (time.time()-start_all))
if self.args.save_in_txt:
self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file)
......@@ -266,20 +261,18 @@ if __name__ == '__main__':
help="whether to add weights over nodes in the context window")
parser.add_argument('--num_walks', default=10, type=int,
help="number of walks for each node")
parser.add_argument('--negative', default=5, type=int,
parser.add_argument('--negative', default=1, type=int,
help="negative samples for each positve node pair")
parser.add_argument('--batch_size', default=10, type=int,
parser.add_argument('--batch_size', default=128, type=int,
help="number of node sequences in each batch")
parser.add_argument('--walk_length', default=80, type=int,
help="number of nodes in a sequence")
parser.add_argument('--neg_weight', default=1., type=float,
help="negative weight")
parser.add_argument('--lap_norm', default=0.01, type=float,
help="weight of laplacian normalization")
help="weight of laplacian normalization, recommend to set as 0.1 / windoe_size")
# training parameters
parser.add_argument('--iterations', default=1, type=int,
help="iterations")
parser.add_argument('--print_interval', default=100, type=int,
help="number of batches between printing")
parser.add_argument('--print_loss', default=False, action="store_true",
......@@ -296,23 +289,24 @@ if __name__ == '__main__':
help="training with CPU")
parser.add_argument('--only_gpu', default=False, action="store_true",
help="training with GPU")
parser.add_argument('--async_update', default=False, action="store_true",
help="mixed training asynchronously, not recommended")
parser.add_argument('--adam', default=False, action="store_true",
help="use adam for embedding updation, recommended")
parser.add_argument('--sgd', default=False, action="store_true",
help="use sgd for embedding updation")
parser.add_argument('--avg_sgd', default=False, action="store_true",
help="average gradients of sgd for embedding updation")
parser.add_argument('--fast_neg', default=False, action="store_true",
help="do negative sampling inside a batch")
parser.add_argument('--num_threads', default=2, type=int,
parser.add_argument('--num_threads', default=8, type=int,
help="number of threads used for each CPU-core/GPU")
parser.add_argument('--num_sampler_threads', default=2, type=int,
help="number of threads used for sampling")
parser.add_argument('--count_params', default=False, action="store_true",
help="count the params, then exit")
help="count the params, exit once counting over")
args = parser.parse_args()
if args.async_update:
assert args.mix, "--async_update only with --mix"
start_time = time.time()
trainer = DeepwalkTrainer(args)
trainer.train()
......
......@@ -4,6 +4,10 @@ import torch.nn.functional as F
from torch.nn import init
import random
import numpy as np
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from utils import thread_wrapped_func
def init_emb2pos_index(walk_length, window_size, batch_size):
''' select embedding of positive nodes from a batch of node embeddings
......@@ -72,25 +76,8 @@ def init_emb2neg_index(walk_length, window_size, negative, batch_size):
return index_emb_negu, index_emb_negv
def init_grad_avg(walk_length, window_size, batch_size):
''' averaging graidents by specific weights
'''
grad_avg = []
for b in range(batch_size):
for i in range(walk_length):
if i < window_size:
grad_avg.append(1. / float(i+window_size))
elif i >= walk_length - window_size:
grad_avg.append(1. / float(walk_length - i - 1 + window_size))
else:
grad_avg.append(0.5 / window_size)
# [num_pos * batch_size]
return torch.Tensor(grad_avg).unsqueeze(1)
def init_weight(walk_length, window_size, batch_size):
''' select nodes' gradients from gradient matrix
'''
''' init context weight '''
weight = []
for b in range(batch_size):
for i in range(walk_length):
......@@ -123,6 +110,20 @@ def adam(grad, state_sum, nodes, lr, device, only_gpu):
return grad
@thread_wrapped_func
def async_update(num_threads, model, queue):
""" asynchronous embedding update """
torch.set_num_threads(num_threads)
while True:
(grad_u, grad_v, grad_v_neg, nodes, neg_nodes) = queue.get()
if grad_u is None:
return
with torch.no_grad():
model.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
model.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
if neg_nodes is not None:
model.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), grad_v_neg)
class SkipGramModel(nn.Module):
""" Negative sampling based skip-gram """
def __init__(self,
......@@ -138,13 +139,12 @@ class SkipGramModel(nn.Module):
negative,
lr,
lap_norm,
adam,
sgd,
avg_sgd,
fast_neg,
record_loss,
norm,
use_context_weight,
async_update,
num_threads,
):
""" initialize embedding on CPU
......@@ -162,10 +162,11 @@ class SkipGramModel(nn.Module):
neg_weight float : negative weight
lr float : initial learning rate
lap_norm float : weight of laplacian normalization
adam bool : use adam for embedding updation
sgd bool : use sgd for embedding updation
avg_sgd bool : average gradients of sgd for embedding updation
fast_neg bool : do negative sampling inside a batch
record_loss bool : print the loss during training
norm bool : normalize the embedding after training
use_context_weight : give different weights to the nodes in a context window
async_update : asynchronous training
"""
super(SkipGramModel, self).__init__()
self.emb_size = emb_size
......@@ -180,13 +181,12 @@ class SkipGramModel(nn.Module):
self.negative = negative
self.lr = lr
self.lap_norm = lap_norm
self.adam = adam
self.sgd = sgd
self.avg_sgd = avg_sgd
self.fast_neg = fast_neg
self.record_loss = record_loss
self.norm = norm
self.use_context_weight = use_context_weight
self.async_update = async_update
self.num_threads = num_threads
# initialize the device as cpu
self.device = torch.device("cpu")
......@@ -215,7 +215,6 @@ class SkipGramModel(nn.Module):
self.walk_length,
self.window_size,
self.batch_size)
if self.fast_neg:
self.index_emb_negu, self.index_emb_negv = init_emb2neg_index(
self.walk_length,
self.window_size,
......@@ -228,14 +227,7 @@ class SkipGramModel(nn.Module):
self.window_size,
self.batch_size)
# coefficients for averaging the gradients
if self.avg_sgd:
self.grad_avg = init_grad_avg(
self.walk_length,
self.window_size,
self.batch_size)
# adam
if self.adam:
self.state_sum_u = torch.zeros(self.emb_size)
self.state_sum_v = torch.zeros(self.emb_size)
......@@ -245,11 +237,23 @@ class SkipGramModel(nn.Module):
self.walk_length,
self.batch_size)
def create_async_update(self):
""" Set up the async update subprocess.
"""
self.async_q = Queue(1)
self.async_p = mp.Process(target=async_update, args=(self.num_threads, self, self.async_q))
self.async_p.start()
def finish_async_update(self):
""" Notify the async update subprocess to quit.
"""
self.async_q.put((None, None, None, None, None))
self.async_p.join()
def share_memory(self):
""" share the parameters across subprocesses """
self.u_embeddings.weight.share_memory_()
self.v_embeddings.weight.share_memory_()
if self.adam:
self.state_sum_u.share_memory_()
self.state_sum_v.share_memory_()
......@@ -262,13 +266,10 @@ class SkipGramModel(nn.Module):
self.logsigmoid_table = self.logsigmoid_table.to(self.device)
self.index_emb_posu = self.index_emb_posu.to(self.device)
self.index_emb_posv = self.index_emb_posv.to(self.device)
if self.fast_neg:
self.index_emb_negu = self.index_emb_negu.to(self.device)
self.index_emb_negv = self.index_emb_negv.to(self.device)
self.grad_u = self.grad_u.to(self.device)
self.grad_v = self.grad_v.to(self.device)
if self.avg_sgd:
self.grad_avg = self.grad_avg.to(self.device)
if self.use_context_weight:
self.context_weight = self.context_weight.to(self.device)
......@@ -278,7 +279,6 @@ class SkipGramModel(nn.Module):
self.set_device(gpu_id)
self.u_embeddings = self.u_embeddings.cuda(gpu_id)
self.v_embeddings = self.v_embeddings.cuda(gpu_id)
if self.adam:
self.state_sum_u = self.state_sum_u.to(self.device)
self.state_sum_v = self.state_sum_v.to(self.device)
......@@ -292,7 +292,7 @@ class SkipGramModel(nn.Module):
idx = torch.floor((score + 6.01) / 0.01).long()
return self.logsigmoid_table[idx]
def fast_learn(self, batch_walks, lr, neg_nodes=None):
def fast_learn(self, batch_walks, neg_nodes=None):
""" Learn a batch of random walks in a fast way. It has the following features:
1. It calculates the gradients directly without a forward pass.
2. It computes sigmoid via a lookup table.
......@@ -317,7 +317,6 @@ class SkipGramModel(nn.Module):
lr = 0.01
neg_nodes = None
"""
if self.adam:
lr = self.lr
# [batch_size, walk_length]
......@@ -427,40 +426,35 @@ class SkipGramModel(nn.Module):
## Update
nodes = nodes.view(-1)
if self.avg_sgd:
# since the times that a node are performed backward propagation are different,
# we need to average the gradients by different weight.
# e.g. for sequence [1, 2, 3, ...] with window_size = 5, we have positive node
# pairs [(1,2), (1, 3), (1,4), ...]. To average the gradients for each node, we
# perform weighting on the gradients of node pairs.
# The weights are: [1/5, 1/5, ..., 1/6, ..., 1/10, ..., 1/6, ..., 1/5].
if bs < self.batch_size:
grad_avg = init_grad_avg(
self.walk_length,
self.window_size,
bs).to(self.device)
else:
grad_avg = self.grad_avg
grad_u = grad_avg * grad_u * lr
grad_v = grad_avg * grad_v * lr
elif self.sgd:
grad_u = grad_u * lr
grad_v = grad_v * lr
elif self.adam:
# use adam optimizer
grad_u = adam(grad_u, self.state_sum_u, nodes, lr, self.device, self.only_gpu)
grad_v = adam(grad_v, self.state_sum_v, nodes, lr, self.device, self.only_gpu)
if neg_nodes is not None:
grad_v_neg = adam(grad_v_neg, self.state_sum_v, neg_nodes, lr, self.device, self.only_gpu)
if self.mixed_train:
grad_u = grad_u.cpu()
grad_v = grad_v.cpu()
if neg_nodes is not None:
grad_v_neg = grad_v_neg.cpu()
else:
grad_v_neg = None
if self.async_update:
grad_u.share_memory_()
grad_v.share_memory_()
nodes.share_memory_()
if neg_nodes is not None:
neg_nodes.share_memory_()
grad_v_neg.share_memory_()
self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes))
if not self.async_update:
self.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
self.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
if neg_nodes is not None:
self.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), lr * grad_v_neg)
self.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), grad_v_neg)
return
def forward(self, pos_u, pos_v, neg_v):
......@@ -512,8 +506,7 @@ class SkipGramModel(nn.Module):
self.save_embedding_pt_dgl_graph(dataset, file_name)
def save_embedding_pt_dgl_graph(self, dataset, file_name):
""" For ogb leaderboard.
"""
""" For ogb leaderboard """
embedding = torch.zeros_like(self.u_embeddings.weight.cpu().data)
valid_seeds = torch.LongTensor(dataset.valid_seeds)
valid_embedding = self.u_embeddings.weight.cpu().data.index_select(0,
......
......@@ -161,7 +161,7 @@ class DeepwalkDataset:
self.fast_neg = fast_neg
if load_from_ogbl:
assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training."
assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training (CUDA error)."
from load_dataset import load_from_ogbl_with_name
self.G = load_from_ogbl_with_name(ogbl_name)
self.G = make_undirected(self.G)
......@@ -200,26 +200,25 @@ class DeepwalkDataset:
self.neg_table = np.array(self.neg_table, dtype=np.long)
del node_degree
def create_sampler(self, gpu_id):
""" Still in construction...
Several mode:
1. do true negative sampling.
1.1 from random walk sequence
1.2 from node degree distribution
return the sampled node ids
2. do false negative sampling from random walk sequence
save GPU, faster
return the node indices in the sequences
"""
return DeepwalkSampler(self.G, self.seeds[gpu_id], self.walk_length)
def create_sampler(self, i):
""" create random walk sampler """
return DeepwalkSampler(self.G, self.seeds[i], self.walk_length)
def save_mapping(self, map_file):
""" save the mapping dict that maps node IDs to embedding indices """
with open(map_file, "wb") as f:
pickle.dump(self.node2id, f)
class DeepwalkSampler(object):
def __init__(self, G, seeds, walk_length):
""" random walk sampler
Parameter
---------
G dgl.Graph : the input graph
seeds torch.LongTensor : starting nodes
walk_length int : walk length
"""
self.G = G
self.seeds = seeds
self.walk_length = walk_length
......