Commit 6ae440db authored by Linfang He, committed by GitHub

[Model] GATNE-T (#1470)



* Add an example for GATNE-T

* Update README.md

* Add links for datasets

* Update README.md with running results

* Update README.md

* Add comments

* Update main.py

* Remove node type name `user`
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
parent 2190c39d
Representation Learning for Attributed Multiplex Heterogeneous Network (GATNE)
============
- Paper link: [https://arxiv.org/abs/1905.01669](https://arxiv.org/abs/1905.01669)
- Author's code repo: [https://github.com/THUDM/GATNE](https://github.com/THUDM/GATNE). Note that only GATNE-T is implemented here.
Requirements
------------
Install the dependencies listed in `requirements.txt`:
```bash
pip install -r requirements.txt
```
Datasets
--------
* [example](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/example.zip)
* [amazon](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/amazon.zip)
* [youtube](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/youtube.zip)
* [twitter](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/twitter.zip)
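The zip archives above can be fetched and unpacked with a short script such as the following (a minimal sketch, not part of the example; adjust the paths if needed so that each dataset ends up under `data/<name>`, which is the layout the training commands below expect):
```python
# Minimal sketch: download one GATNE dataset archive and unpack it under data/.
# Assumes the archive extracts into a folder named after the dataset; adjust if not.
import io
import os
import urllib.request
import zipfile

name = "example"  # or "amazon", "youtube", "twitter"
url = f"https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/{name}.zip"
os.makedirs("data", exist_ok=True)
with urllib.request.urlopen(url) as resp:
    zipfile.ZipFile(io.BytesIO(resp.read())).extractall("data")
```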
Training
--------
Run with the following command (available datasets: "example", "youtube", "amazon"):
```bash
python src/main.py --input data/example
```
To run on "twitter" dataset, use
```bash
python src/main.py --input data/twitter --eval-type 1
```
Results
-------
All results match the [official code](https://github.com/THUDM/GATNE/blob/master/src/main_pytorch.py) when run with the same hyperparameter values, including on the twitter dataset (AUC 76.29, PR 76.17, F1 69.34).
| Dataset | AUC | PR | F1 |
| ------- | ----- | ----- | ----- |
| amazon | 96.88 | 96.31 | 92.12 |
| youtube | 82.29 | 80.35 | 74.63 |
| twitter | 72.40 | 74.40 | 65.89 |
| example | 94.65 | 94.57 | 89.99 |
tqdm
numpy
scikit-learn
networkx
gensim
requests
--pre dgl-cu101
python src/main.py --input data/example
from collections import defaultdict
import math
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from numpy import random
from torch.nn.parameter import Parameter
import dgl
import dgl.function as fn
from utils import *
def get_graph(network_data, vocab):
''' Build graph, treat all nodes as the same type
Parameters
----------
network_data: a dict
keys describing the edge types, values representing edges
vocab: a dict
mapping node IDs to node indices
Output
------
DGLHeteroGraph
    a heterogeneous graph with one node type and multiple edge types
'''
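    # Illustrative example (not executed): with
    #   network_data = {'1': [('a', 'b')], '2': [('b', 'c')]} and vocab = {'a': 0, 'b': 1, 'c': 2}
    # the result has etypes ['1', '2'], 3 nodes of a single type, and both directions of each edge.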
graphs = []
num_nodes = len(vocab)
for edge_type in network_data:
tmp_data = network_data[edge_type]
edges = []
for edge in tmp_data:
edges.append((vocab[edge[0]], vocab[edge[1]]))
edges.append((vocab[edge[1]], vocab[edge[0]]))
g = dgl.graph(edges, etype=edge_type, num_nodes=num_nodes)
graphs.append(g)
graph = dgl.hetero_from_relations(graphs)
return graph
class NeighborSampler(object):
def __init__(self, g, num_fanouts):
self.g = g
self.num_fanouts = num_fanouts
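    # `sample` is used later as a DataLoader collate_fn over (head, tail, edge_type) pairs:
    # the heads are deduplicated into `seeds` (with `head_invmap` as the inverse index), and
    # one message-passing block is built per fanout, ordered input-layer first.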
def sample(self, pairs):
heads, tails, types = zip(*pairs)
seeds, head_invmap = torch.unique(torch.LongTensor(heads), return_inverse=True)
blocks = []
for fanout in reversed(self.num_fanouts):
sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
sampled_block = dgl.to_block(sampled_graph, seeds)
seeds = sampled_block.srcdata[dgl.NID]
blocks.insert(0, sampled_block)
return blocks, torch.LongTensor(head_invmap), torch.LongTensor(tails), torch.LongTensor(types)
class DGLGATNE(nn.Module):
def __init__(self, num_nodes, embedding_size, embedding_u_size, edge_types, edge_type_count, dim_a):
super(DGLGATNE, self).__init__()
self.num_nodes = num_nodes
self.embedding_size = embedding_size
self.embedding_u_size = embedding_u_size
self.edge_types = edge_types
self.edge_type_count = edge_type_count
self.dim_a = dim_a
self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
self.node_type_embeddings = Parameter(torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size))
self.trans_weights = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size))
self.trans_weights_s1 = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, dim_a))
self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))
self.reset_parameters()
def reset_parameters(self):
self.node_embeddings.data.uniform_(-1.0, 1.0)
self.node_type_embeddings.data.uniform_(-1.0, 1.0)
self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
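    # forward: for each edge type, sum the neighbors' edge-type ("u") embeddings via message
    # passing, apply self-attention over the edge types, project the attended edge-type
    # embedding with trans_weights, add it to the base node embedding, and L2-normalize
    # the result separately for each edge type.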
    # returns [batch_size, edge_type_count, embedding_size]; the training loop gathers the row for each pair's edge type
def forward(self, block):
input_nodes = block.srcdata[dgl.NID]
output_nodes = block.dstdata[dgl.NID]
batch_size = block.number_of_dst_nodes()
node_embed = self.node_embeddings
node_type_embed = []
with block.local_scope():
for i in range(self.edge_type_count):
edge_type = self.edge_types[i]
block.srcdata[edge_type] = self.node_type_embeddings[input_nodes, i]
block.dstdata[edge_type] = self.node_type_embeddings[output_nodes, i]
block.update_all(fn.copy_u(edge_type, 'm'), fn.sum('m', edge_type), etype=edge_type)
node_type_embed.append(block.dstdata[edge_type])
node_type_embed = torch.stack(node_type_embed, 1)
tmp_node_type_embed = node_type_embed.unsqueeze(2).view(-1, 1, self.embedding_u_size)
trans_w = self.trans_weights.unsqueeze(0).repeat(batch_size, 1, 1, 1).view(
-1, self.embedding_u_size, self.embedding_size
)
trans_w_s1 = self.trans_weights_s1.unsqueeze(0).repeat(batch_size, 1, 1, 1).view(
-1, self.embedding_u_size, self.dim_a
)
trans_w_s2 = self.trans_weights_s2.unsqueeze(0).repeat(batch_size, 1, 1, 1).view(-1, self.dim_a, 1)
attention = F.softmax(
torch.matmul(
torch.tanh(torch.matmul(tmp_node_type_embed, trans_w_s1)), trans_w_s2
).squeeze(2).view(-1, self.edge_type_count),
dim=1,
).unsqueeze(1).repeat(1, self.edge_type_count, 1)
node_type_embed = torch.matmul(attention, node_type_embed).view(-1, 1, self.embedding_u_size)
node_embed = node_embed[output_nodes].unsqueeze(1).repeat(1, self.edge_type_count, 1) + \
torch.matmul(node_type_embed, trans_w).view(-1, self.edge_type_count, self.embedding_size)
last_node_embed = F.normalize(node_embed, dim=2)
return last_node_embed # [batch_size, edge_type_count, embedding_size]
class NSLoss(nn.Module):
def __init__(self, num_nodes, num_sampled, embedding_size):
super(NSLoss, self).__init__()
self.num_nodes = num_nodes
self.num_sampled = num_sampled
self.embedding_size = embedding_size
self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
# [ (log(i+2) - log(i+1)) / log(num_nodes + 1)]
self.sample_weights = F.normalize(
torch.Tensor(
[
(math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
for k in range(num_nodes)
]
),
dim=0,
)
self.reset_parameters()
def reset_parameters(self):
self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
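    # forward computes the sampled negative-sampling objective for a batch of (center, context)
    # pairs: -(log sigmoid(e . w_pos) + sum_k log sigmoid(-e . w_neg_k)) averaged over the batch,
    # with negatives drawn from `sample_weights`, a log-uniform (Zipf-like) distribution over nodes.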
def forward(self, input, embs, label):
n = input.shape[0]
log_target = torch.log(
torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1))
)
negs = torch.multinomial(
self.sample_weights, self.num_sampled * n, replacement=True
).view(n, self.num_sampled)
noise = torch.neg(self.weights[negs])
sum_log_sampled = torch.sum(
torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1
).squeeze()
loss = log_target + sum_log_sampled
return -loss.sum() / n
def train_model(network_data):
index2word, vocab, type_nodes = generate_vocab(network_data)
edge_types = list(network_data.keys())
num_nodes = len(index2word)
edge_type_count = len(edge_types)
epochs = args.epoch
batch_size = args.batch_size
embedding_size = args.dimensions
embedding_u_size = args.edge_dim
u_num = edge_type_count
num_sampled = args.negative_samples
dim_a = args.att_dim
att_head = 1
neighbor_samples = args.neighbor_samples
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
g = get_graph(network_data, vocab)
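    # For each edge type, run random walks restricted to that relation; the walks are then
    # turned into (center, context, edge_type) training pairs by generate_pairs.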
all_walks = []
for i in range(edge_type_count):
nodes = torch.LongTensor(type_nodes[i] * args.num_walks)
traces, types = dgl.sampling.random_walk(g, nodes, metapath=[edge_types[i]]*(neighbor_samples-1))
all_walks.append(traces)
train_pairs = generate_pairs(all_walks, args.window_size)
neighbor_sampler = NeighborSampler(g, [neighbor_samples])
train_dataloader = torch.utils.data.DataLoader(
train_pairs, batch_size=batch_size, collate_fn=neighbor_sampler.sample, shuffle=True
)
model = DGLGATNE(num_nodes, embedding_size, embedding_u_size, edge_types, edge_type_count, dim_a)
nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
model.to(device)
nsloss.to(device)
optimizer = torch.optim.Adam([{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-4)
best_score = 0
patience = 0
for epoch in range(epochs):
model.train()
random.shuffle(train_pairs)
data_iter = tqdm.tqdm(
train_dataloader,
desc="epoch %d" % (epoch),
total=(len(train_pairs) + (batch_size - 1)) // batch_size,
bar_format="{l_bar}{r_bar}",
)
avg_loss = 0.0
for i, (block, head_invmap, tails, block_types) in enumerate(data_iter):
optimizer.zero_grad()
# embs: [batch_size, edge_type_count, embedding_size]
block_types = block_types.to(device)
embs = model(block[0].to(device))[head_invmap]
embs = embs.gather(1, block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2]))[:, 0]
loss = nsloss(block[0].dstdata[dgl.NID][head_invmap].to(device), embs, tails.to(device))
loss.backward()
optimizer.step()
avg_loss += loss.item()
if i % 5000 == 0:
post_fix = {
"epoch": epoch,
"iter": i,
"avg_loss": avg_loss / (i + 1),
"loss": loss.item(),
}
data_iter.write(str(post_fix))
model.eval()
# {'1': {}, '2': {}}
final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)]))
for i in range(num_nodes):
train_inputs = torch.tensor([i for _ in range(edge_type_count)]).unsqueeze(1).to(device) # [i, i]
train_types = torch.tensor(list(range(edge_type_count))).unsqueeze(1).to(device) # [0, 1]
pairs = torch.cat((train_inputs, train_inputs, train_types), dim=1) # (2, 3)
train_blocks, train_invmap, fake_tails, train_types = neighbor_sampler.sample(pairs)
node_emb = model(train_blocks[0].to(device))[train_invmap]
node_emb = node_emb.gather(1, train_types.to(device).view(-1, 1, 1
).expand(node_emb.shape[0], 1, node_emb.shape[2])
)[:, 0]
for j in range(edge_type_count):
final_model[edge_types[j]][index2word[i]] = (
node_emb[j].cpu().detach().numpy()
)
valid_aucs, valid_f1s, valid_prs = [], [], []
test_aucs, test_f1s, test_prs = [], [], []
for i in range(edge_type_count):
if args.eval_type == "all" or edge_types[i] in args.eval_type.split(","):
tmp_auc, tmp_f1, tmp_pr = evaluate(
final_model[edge_types[i]],
valid_true_data_by_edge[edge_types[i]],
valid_false_data_by_edge[edge_types[i]],
)
valid_aucs.append(tmp_auc)
valid_f1s.append(tmp_f1)
valid_prs.append(tmp_pr)
tmp_auc, tmp_f1, tmp_pr = evaluate(
final_model[edge_types[i]],
testing_true_data_by_edge[edge_types[i]],
testing_false_data_by_edge[edge_types[i]],
)
test_aucs.append(tmp_auc)
test_f1s.append(tmp_f1)
test_prs.append(tmp_pr)
print("valid auc:", np.mean(valid_aucs))
print("valid pr:", np.mean(valid_prs))
print("valid f1:", np.mean(valid_f1s))
average_auc = np.mean(test_aucs)
average_f1 = np.mean(test_f1s)
average_pr = np.mean(test_prs)
cur_score = np.mean(valid_aucs)
if cur_score > best_score:
best_score = cur_score
patience = 0
else:
patience += 1
if patience > args.patience:
print("Early Stopping")
break
return average_auc, average_f1, average_pr
if __name__ == "__main__":
args = parse_args()
file_name = args.input
print(args)
training_data_by_type = load_training_data(file_name + "/train.txt")
valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(
file_name + "/valid.txt"
)
testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(
file_name + "/test.txt"
)
start = time.time()
average_auc, average_f1, average_pr = train_model(training_data_by_type)
end = time.time()
print("Overall ROC-AUC:", average_auc)
print("Overall PR-AUC", average_pr)
print("Overall F1:", average_f1)
print("Training Time", end-start)
import argparse
from collections import defaultdict
import networkx as nx
import numpy as np
from gensim.models.keyedvectors import Vocab
from six import iteritems
from sklearn.metrics import (auc, f1_score, precision_recall_curve,
roc_auc_score)
import torch
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, default='data/amazon',
help='Input dataset path')
parser.add_argument('--features', type=str, default=None,
help='Input node features')
parser.add_argument('--epoch', type=int, default=100,
help='Number of epoch. Default is 100.')
parser.add_argument('--batch-size', type=int, default=64,
help='Number of batch_size. Default is 64.')
parser.add_argument('--eval-type', type=str, default='all',
help='The edge type(s) for evaluation.')
parser.add_argument('--schema', type=str, default=None,
help='The metapath schema (e.g., U-I-U,I-U-I).')
parser.add_argument('--dimensions', type=int, default=200,
help='Number of dimensions. Default is 200.')
parser.add_argument('--edge-dim', type=int, default=10,
help='Number of edge embedding dimensions. Default is 10.')
parser.add_argument('--att-dim', type=int, default=20,
help='Number of attention dimensions. Default is 20.')
parser.add_argument('--walk-length', type=int, default=10,
help='Length of walk per source. Default is 10.')
parser.add_argument('--num-walks', type=int, default=20,
help='Number of walks per source. Default is 20.')
parser.add_argument('--window-size', type=int, default=5,
help='Context size for optimization. Default is 5.')
parser.add_argument('--negative-samples', type=int, default=5,
help='Negative samples for optimization. Default is 5.')
parser.add_argument('--neighbor-samples', type=int, default=10,
help='Neighbor samples for aggregation. Default is 10.')
parser.add_argument('--patience', type=int, default=5,
help='Early stopping patience. Default is 5.')
return parser.parse_args()
# for each line, the data is [edge_type, node, node]
def load_training_data(f_name):
print('We are loading data from:', f_name)
edge_data_by_type = dict()
all_nodes = list()
with open(f_name, 'r') as f:
for line in f:
words = line[:-1].split(' ') # line[-1] == '\n'
if words[0] not in edge_data_by_type:
edge_data_by_type[words[0]] = list()
x, y = words[1], words[2]
edge_data_by_type[words[0]].append((x, y))
all_nodes.append(x)
all_nodes.append(y)
all_nodes = list(set(all_nodes))
print('Total training nodes: ' + str(len(all_nodes)))
return edge_data_by_type
# for each line, the data is [edge_type, node, node, true_or_false]
def load_testing_data(f_name):
print('We are loading data from:', f_name)
true_edge_data_by_type = dict()
false_edge_data_by_type = dict()
all_edges = list()
all_nodes = list()
with open(f_name, 'r') as f:
for line in f:
words = line[:-1].split(' ')
x, y = words[1], words[2]
if int(words[3]) == 1:
if words[0] not in true_edge_data_by_type:
true_edge_data_by_type[words[0]] = list()
true_edge_data_by_type[words[0]].append((x, y))
else:
if words[0] not in false_edge_data_by_type:
false_edge_data_by_type[words[0]] = list()
false_edge_data_by_type[words[0]].append((x, y))
all_nodes.append(x)
all_nodes.append(y)
all_nodes = list(set(all_nodes))
return true_edge_data_by_type, false_edge_data_by_type
def load_node_type(f_name):
print('We are loading node type from:', f_name)
node_type = {}
with open(f_name, 'r') as f:
for line in f:
items = line.strip().split()
node_type[items[0]] = items[1]
return node_type
def generate_pairs(all_walks, window_size):
    # pair every node in a walk with the context nodes up to skip_window positions away,
    # emitting (center, context, layer_id) triples
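    # Illustrative example (not executed): a walk [a, b, c] in layer 0 with window_size 5
    # (skip_window 2) yields (a,b,0), (a,c,0), (b,a,0), (b,c,0), (c,b,0), (c,a,0).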
pairs = []
skip_window = window_size // 2
for layer_id, walks in enumerate(all_walks):
for walk in walks:
for i in range(len(walk)):
for j in range(1, skip_window + 1):
if i - j >= 0:
pairs.append((walk[i], walk[i-j], layer_id))
if i + j < len(walk):
pairs.append((walk[i], walk[i+j], layer_id))
return pairs
def generate_vocab(network_data):
nodes, index2word = [], []
for edge_type in network_data:
node1, node2 = zip(*network_data[edge_type])
index2word = index2word + list(node1) + list(node2)
index2word = list(set(index2word))
vocab = {}
i = 0
for word in index2word:
vocab[word] = i
i = i + 1
for edge_type in network_data:
node1, node2 = zip(*network_data[edge_type])
tmp_nodes = list(set(list(node1) + list(node2)))
tmp_nodes = [vocab[word] for word in tmp_nodes]
nodes.append(tmp_nodes)
return index2word, vocab, nodes
def get_score(local_model, node1, node2):
try:
vector1 = local_model[node1]
vector2 = local_model[node2]
return np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))
    except Exception:
        # one of the nodes has no embedding (or the lookup failed); treat as no score
        return None
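# evaluate scores each candidate edge with cosine similarity (get_score) and reports ROC-AUC
# and PR-AUC over the raw scores, plus F1 with the threshold set to the true_num-th highest
# score (so roughly as many edges are predicted positive as there are true edges).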
def evaluate(model, true_edges, false_edges):
true_list = list()
prediction_list = list()
true_num = 0
for edge in true_edges:
tmp_score = get_score(model, str(edge[0]), str(edge[1]))
if tmp_score is not None:
true_list.append(1)
prediction_list.append(tmp_score)
true_num += 1
for edge in false_edges:
tmp_score = get_score(model, str(edge[0]), str(edge[1]))
if tmp_score is not None:
true_list.append(0)
prediction_list.append(tmp_score)
sorted_pred = prediction_list[:]
sorted_pred.sort()
threshold = sorted_pred[-true_num]
y_pred = np.zeros(len(prediction_list), dtype=np.int32)
for i in range(len(prediction_list)):
if prediction_list[i] >= threshold:
y_pred[i] = 1
y_true = np.array(true_list)
y_scores = np.array(prediction_list)
ps, rs, _ = precision_recall_curve(y_true, y_scores)
return roc_auc_score(y_true, y_scores), f1_score(y_true, y_pred), auc(rs, ps)