Commit f5bba284 authored by WangYQ, committed by GitHub

[Example] Add EGES example (#3756)



* add eges example

* remove csv files and add data link

* Update README.md

* Update main.py

* Update model.py

* Update sampler.py

* Update utils.py

* Update model.py
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
# DGL & PyTorch implementation of Enhanced Graph Embedding with Side information (EGES)
## Version
dgl==0.6.1, torch==1.9.0
## Paper
Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba:
https://arxiv.org/pdf/1803.02349.pdf
https://arxiv.org/abs/1803.02349
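The central idea, implemented in `EGES.query_node_embed` below, is to aggregate an item's base embedding and its side-information embeddings (brand, shop, category) with learned, softmax-normalized weights:

$$
H_v = \frac{\sum_{j=0}^{n} e^{a_v^j} W_v^j}{\sum_{j=0}^{n} e^{a_v^j}}
$$

where $W_v^0$ is the base item embedding, $W_v^j$ ($j \ge 1$) are the side-information embeddings and $a_v^j$ are per-item trainable weights.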
## How to run
Create a folder named `data` and download the two CSV files from [here](https://github.com/Wang-Yu-Qing/dgl_data/tree/master/eges_data) into it.
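If you prefer to script the download, here is a minimal sketch; it assumes the raw-file URLs follow GitHub's usual `raw.githubusercontent.com` layout and uses the two file names expected by the default arguments in `utils.py` (`action_head.csv`, `jdata_product.csv`). Downloading the files manually from the link above works just as well.
```
import os
import urllib.request

# assumed raw-file location for the eges_data folder linked above
BASE = "https://raw.githubusercontent.com/Wang-Yu-Qing/dgl_data/master/eges_data"

os.makedirs("data", exist_ok=True)
for name in ("action_head.csv", "jdata_product.csv"):
    urllib.request.urlretrieve(BASE + "/" + name, os.path.join("data", name))
```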
Run `python main.py` with the default configuration, and messages like the following will show up:
```
Using backend: pytorch
Num skus: 33344, num brands: 3662, num shops: 4785, num cates: 79
Epoch 00000 | Step 00000 | Step Loss 0.9117 | Epoch Avg Loss: 0.9117
Epoch 00000 | Step 00100 | Step Loss 0.8736 | Epoch Avg Loss: 0.8801
Epoch 00000 | Step 00200 | Step Loss 0.8975 | Epoch Avg Loss: 0.8785
Evaluate link prediction AUC: 0.6864
Epoch 00001 | Step 00000 | Step Loss 0.8695 | Epoch Avg Loss: 0.8695
Epoch 00001 | Step 00100 | Step Loss 0.8290 | Epoch Avg Loss: 0.8643
Epoch 00001 | Step 00200 | Step Loss 0.8012 | Epoch Avg Loss: 0.8604
Evaluate link prediction AUC: 0.6875
...
Epoch 00029 | Step 00000 | Step Loss 0.7095 | Epoch Avg Loss: 0.7095
Epoch 00029 | Step 00100 | Step Loss 0.7248 | Epoch Avg Loss: 0.7139
Epoch 00029 | Step 00200 | Step Loss 0.7123 | Epoch Avg Loss: 0.7134
Evaluate link prediction AUC: 0.7084
```
The link-prediction AUC on the held-out test edges is computed after each epoch.
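For reference, each held-out test edge is scored with the sigmoid of the dot product of its two endpoint embeddings, and the scores are fed to scikit-learn's ROC/AUC utilities. Below is a toy, self-contained sketch of that scoring logic (random embeddings stand in for the `model.query_node_embed` outputs used in `main.py`):
```
import torch as th
from sklearn import metrics

dim = 16
preds, labels = [], []
for label in [1, 0, 1, 0, 1, 0]:
    # stand-in embeddings; the example uses model.query_node_embed instead
    src_embed, dst_embed = th.randn(1, dim), th.randn(1, dim)
    score = th.sigmoid(th.sum(src_embed * dst_embed)).item()
    preds.append(score)
    labels.append(label)

fpr, tpr, _ = metrics.roc_curve(labels, preds, pos_label=1)
print("AUC: {:.4f}".format(metrics.auc(fpr, tpr)))
```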
## Reference
https://github.com/nonva/eges
https://github.com/wangzhegeek/EGES.git
# main.py
import dgl
import torch as th
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics
import utils
from model import EGES
from sampler import Sampler
def train(args, train_g, test_g, sku_info, num_skus, num_brands, num_shops, num_cates):
sampler = Sampler(
train_g,
args.walk_length,
args.num_walks,
args.window_size,
args.num_negative
)
    # For each batch of seed nodes, the sampler generates positive and
    # negative (src, dst, label) triples, which are fed into the model.
    # (the seed nodes are batched by the DataLoader before sampling)
dataloader = DataLoader(
th.arange(train_g.num_nodes()),
# this is the batch_size of input nodes
batch_size=args.batch_size,
shuffle=True,
collate_fn=lambda x: sampler.sample(x, sku_info)
)
model = EGES(args.dim, num_skus, num_brands, num_shops, num_cates)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
for epoch in range(args.epochs):
epoch_total_loss = 0
for step, (srcs, dsts, labels) in enumerate(dataloader):
            # the number of sampled (src, dst, label) triples varies per batch
            # TODO: shuffle the triples?
srcs_embeds, dsts_embeds = model(srcs, dsts)
loss = model.loss(srcs_embeds, dsts_embeds, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_total_loss += loss.item()
if step % args.log_every == 0:
print('Epoch {:05d} | Step {:05d} | Step Loss {:.4f} | Epoch Avg Loss: {:.4f}'.format(
epoch, step, loss.item(), epoch_total_loss / (step + 1)))
        evaluate(model, test_g, sku_info)
return model
def evaluate(model, test_graph, sku_info):
preds, labels = [], []
for edge in test_graph:
src = th.tensor(sku_info[edge.src.numpy()[0]]).view(1, 4)
dst = th.tensor(sku_info[edge.dst.numpy()[0]]).view(1, 4)
# (1, dim)
src = model.query_node_embed(src)
dst = model.query_node_embed(dst)
# (1, dim) -> (1, dim) -> (1, )
logit = th.sigmoid(th.sum(src * dst))
preds.append(logit.detach().numpy().tolist())
labels.append(edge.label)
fpr, tpr, thresholds = metrics.roc_curve(labels, preds, pos_label=1)
print("Evaluate link prediction AUC: {:.4f}".format(metrics.auc(fpr, tpr)))
if __name__ == "__main__":
args = utils.init_args()
valid_sku_raw_ids = utils.get_valid_sku_set(args.item_info_data)
g, sku_encoder, sku_decoder = utils.construct_graph(
args.action_data,
args.session_interval_sec,
valid_sku_raw_ids
)
train_g, test_g = utils.split_train_test_graph(g)
sku_info_encoder, sku_info_decoder, sku_info = \
utils.encode_sku_fields(args.item_info_data, sku_encoder, sku_decoder)
num_skus = len(sku_encoder)
num_brands = len(sku_info_encoder["brand"])
num_shops = len(sku_info_encoder["shop"])
num_cates = len(sku_info_encoder["cate"])
print(
"Num skus: {}, num brands: {}, num shops: {}, num cates: {}".\
format(num_skus, num_brands, num_shops, num_cates)
)
    model = train(args, train_g, test_g, sku_info, num_skus, num_brands, num_shops, num_cates)
# model.py
import torch as th
class EGES(th.nn.Module):
def __init__(self, dim, num_nodes, num_brands, num_shops, num_cates):
super(EGES, self).__init__()
self.dim = dim
# embeddings for nodes
base_embeds = th.nn.Embedding(num_nodes, dim)
brand_embeds = th.nn.Embedding(num_brands, dim)
shop_embeds = th.nn.Embedding(num_shops, dim)
cate_embeds = th.nn.Embedding(num_cates, dim)
        # register the embedding tables in a ModuleList so that their
        # parameters are visible to model.parameters() and get optimized
        self.embeds = th.nn.ModuleList([base_embeds, brand_embeds, shop_embeds, cate_embeds])
# weights for each node's side information
self.side_info_weights = th.nn.Embedding(num_nodes, 4)
def forward(self, srcs, dsts):
# srcs: sku_id, brand_id, shop_id, cate_id
srcs = self.query_node_embed(srcs)
dsts = self.query_node_embed(dsts)
return srcs, dsts
def query_node_embed(self, nodes):
"""
@nodes: tensor of shape (batch_size, num_side_info)
"""
batch_size = nodes.shape[0]
# query side info weights, (batch_size, 4)
side_info_weights = th.exp(self.side_info_weights(nodes[:, 0]))
# merge all embeddings
side_info_weighted_embeds_sum = []
side_info_weights_sum = []
for i in range(4):
# weights for i-th side info, (batch_size, ) -> (batch_size, 1)
i_th_side_info_weights = side_info_weights[:, i].view((batch_size, 1))
# batch of i-th side info embedding * its weight, (batch_size, dim)
side_info_weighted_embeds_sum.append(i_th_side_info_weights * self.embeds[i](nodes[:, i]))
side_info_weights_sum.append(i_th_side_info_weights)
        # stack: (batch_size, 4, dim), sum: (batch_size, dim)
        side_info_weighted_embeds_sum = th.sum(th.stack(side_info_weighted_embeds_sum, dim=1), dim=1)
        # stack: (batch_size, 4, 1), sum: (batch_size, 1)
        side_info_weights_sum = th.sum(th.stack(side_info_weights_sum, dim=1), dim=1)
# (batch_size, dim)
H = side_info_weighted_embeds_sum / side_info_weights_sum
return H
    def loss(self, srcs, dsts, labels):
        # binary cross-entropy on the sigmoid of the pairwise dot products
        dots = th.sigmoid(th.sum(srcs * dsts, dim=1))
        dots = th.clamp(dots, min=1e-7, max=1 - 1e-7)
        return th.mean(- (labels * th.log(dots) + (1 - labels) * th.log(1 - dots)))
# sampler.py
import dgl
import numpy as np
import torch as th
class Sampler:
def __init__(self,
graph,
walk_length,
num_walks,
window_size,
num_negative):
self.graph = graph
self.walk_length = walk_length
self.num_walks = num_walks
self.window_size = window_size
self.num_negative = num_negative
self.node_weights = self.compute_node_sample_weight()
def sample(self, batch, sku_info):
"""
        Given a batch of target nodes, sample positive
        pairs and negative pairs for them from the graph.
"""
batch = np.repeat(batch, self.num_walks)
pos_pairs = self.generate_pos_pairs(batch)
neg_pairs = self.generate_neg_pairs(pos_pairs)
# get sku info with id
srcs, dsts, labels = [], [], []
for pair in pos_pairs + neg_pairs:
src, dst, label = pair
src_info = sku_info[src]
dst_info = sku_info[dst]
srcs.append(src_info)
dsts.append(dst_info)
labels.append(label)
return th.tensor(srcs), th.tensor(dsts), th.tensor(labels)
def filter_padding(self, traces):
for i in range(len(traces)):
traces[i] = [x for x in traces[i] if x != -1]
def generate_pos_pairs(self, nodes):
"""
        For the sequence [1, 2, 3, 4] and center node 2,
        window_size=1 generates the (center, context) pairs
        (2, 1) and (2, 3).
"""
# random walk
traces, types = dgl.sampling.random_walk(
g=self.graph,
nodes=nodes,
length=self.walk_length,
prob="weight"
)
traces = traces.tolist()
self.filter_padding(traces)
# skip-gram
pairs = []
for trace in traces:
for i in range(len(trace)):
center = trace[i]
left = max(0, i - self.window_size)
right = min(len(trace), i + self.window_size + 1)
pairs.extend([[center, x, 1] for x in trace[left:i]])
pairs.extend([[center, x, 1] for x in trace[i+1:right]])
return pairs
def compute_node_sample_weight(self):
"""
        Use node in-degree as the negative sampling weight.
"""
return self.graph.in_degrees().float()
def generate_neg_pairs(self, pos_pairs):
"""
        Sample negative dst nodes in proportion to node in-degree,
        so that popular nodes have a larger chance of being chosen
        as negatives.
"""
# sample `self.num_negative` neg dst node
# for each pos node pair's src node.
negs = th.multinomial(
self.node_weights,
len(pos_pairs) * self.num_negative,
replacement=True
).tolist()
tar = np.repeat([pair[0] for pair in pos_pairs], self.num_negative)
assert(len(tar) == len(negs))
neg_pairs = [[x, y, 0] for x, y in zip(tar, negs)]
return neg_pairs
# utils.py
import dgl
import random
import argparse
import torch as th
import numpy as np
import networkx as nx
from datetime import datetime
def init_args():
# TODO: change args
argparser = argparse.ArgumentParser()
argparser.add_argument('--session_interval_sec', type=int, default=1800)
argparser.add_argument('--action_data', type=str, default="data/action_head.csv")
argparser.add_argument('--item_info_data', type=str,
default="data/jdata_product.csv")
argparser.add_argument('--walk_length', type=int, default=10)
argparser.add_argument('--num_walks', type=int, default=5)
argparser.add_argument('--batch_size', type=int, default=64)
argparser.add_argument('--dim', type=int, default=16)
argparser.add_argument('--epochs', type=int, default=30)
argparser.add_argument('--window_size', type=int, default=2)
argparser.add_argument('--num_negative', type=int, default=5)
argparser.add_argument('--lr', type=float, default=0.001)
argparser.add_argument('--log_every', type=int, default=100)
return argparser.parse_args()
def construct_graph(datapath, session_interval_gap_sec, valid_sku_raw_ids):
user_clicks, sku_encoder, sku_decoder = parse_actions(datapath, valid_sku_raw_ids)
    # maps the edge key "src,dst" to its accumulated co-click weight
graph = {}
for user_id, action_list in user_clicks.items():
# sort by action time
_action_list = sorted(action_list, key=lambda x: x[1])
last_action_time = datetime.strptime(_action_list[0][1], "%Y-%m-%d %H:%M:%S")
session = [_action_list[0][0]]
# cut sessions and add to graph
        for sku_id, action_time in _action_list[1:]:
            action_time = datetime.strptime(action_time, "%Y-%m-%d %H:%M:%S")
            gap = action_time - last_action_time
            # use total_seconds so gaps longer than a day are handled correctly
            if gap.total_seconds() < session_interval_gap_sec:
                session.append(sku_id)
            else:
                # the gap is too large: close the previous session,
                # add it to the graph and start a new one
                add_session(session, graph)
                session = [sku_id]
            last_action_time = action_time
# add last session
add_session(session, graph)
g = convert_to_dgl_graph(graph)
return g, sku_encoder, sku_decoder
def convert_to_dgl_graph(graph):
# directed graph
g = nx.DiGraph()
for edge, weight in graph.items():
nodes = edge.split(",")
src, dst = int(nodes[0]), int(nodes[1])
g.add_edge(src, dst, weight=float(weight))
return dgl.from_networkx(g, edge_attrs=['weight'])
def add_session(session, graph):
"""
For session like:
[sku1, sku2, sku3]
add 1 weight to each of the following edges:
sku1 -> sku2
sku2 -> sku3
    If the session length is < 2, no nodes/edges will be added.
"""
for i in range(len(session)-1):
edge = str(session[i]) + "," + str(session[i+1])
try:
graph[edge] += 1
except KeyError:
graph[edge] = 1
def parse_actions(datapath, valid_sku_raw_ids):
user_clicks = {}
with open(datapath, "r") as f:
f.readline()
# raw_id -> new_id and new_id -> raw_id
sku_encoder, sku_decoder = {}, []
sku_id = -1
for line in f:
line = line.replace("\n", "")
fields = line.split(",")
action_type = fields[-1]
            # in this dataset, all action types are "1"
if action_type == "1":
user_id = fields[0]
sku_raw_id = fields[1]
if sku_raw_id in valid_sku_raw_ids:
action_time = fields[2]
# encode sku_id
sku_id = encode_id(sku_encoder,
sku_decoder,
sku_raw_id,
sku_id)
# add to user clicks
try:
user_clicks[user_id].append((sku_id, action_time))
except KeyError:
user_clicks[user_id] = [(sku_id, action_time)]
return user_clicks, sku_encoder, sku_decoder
def encode_id(encoder, decoder, raw_id, encoded_id):
if raw_id in encoder:
return encoded_id
else:
encoded_id += 1
encoder[raw_id] = encoded_id
decoder.append(raw_id)
return encoded_id
def get_valid_sku_set(datapath):
sku_ids = set()
with open(datapath, "r") as f:
        # skip the csv header line
        f.readline()
        for line in f:
            sku_raw_id = line.strip().split(",")[0]
            sku_ids.add(sku_raw_id)
return sku_ids
def encode_sku_fields(datapath, sku_encoder, sku_decoder):
# sku_id,brand,shop_id,cate,market_time
sku_info_encoder = {"brand": {}, "shop": {}, "cate": {}}
sku_info_decoder = {"brand": [], "shop": [], "cate": []}
sku_info = {}
brand_id, shop_id, cate_id = -1, -1, -1
with open(datapath, "r") as f:
f.readline()
for line in f:
line = line.replace("\n", "")
fields = line.split(",")
sku_raw_id = fields[0]
brand_raw_id = fields[1]
shop_raw_id = fields[2]
cate_raw_id = fields[3]
if sku_raw_id in sku_encoder:
sku_id = sku_encoder[sku_raw_id]
brand_id = encode_id(
sku_info_encoder["brand"],
sku_info_decoder["brand"],
brand_raw_id,
brand_id
)
shop_id = encode_id(
sku_info_encoder["shop"],
sku_info_decoder["shop"],
shop_raw_id,
shop_id
)
cate_id = encode_id(
sku_info_encoder["cate"],
sku_info_decoder["cate"],
cate_raw_id,
cate_id
)
sku_info[sku_id] = [sku_id, brand_id, shop_id, cate_id]
return sku_info_encoder, sku_info_decoder, sku_info
class TestEdge:
def __init__(self, src, dst, label):
self.src = src
self.dst = dst
self.label = label
def split_train_test_graph(graph):
"""
    Randomly choose 1/3 of the edges and remove them from the graph to
    serve as positive test edges; each is paired with one uniformly
    sampled negative edge. The remaining graph is used for training.
"""
test_edges = []
neg_sampler = dgl.dataloading.negative_sampler.Uniform(1)
sampled_edge_ids = random.sample(range(graph.num_edges()), int(graph.num_edges() / 3))
for edge_id in sampled_edge_ids:
src, dst = graph.find_edges(edge_id)
test_edges.append(TestEdge(src, dst, 1))
src, dst = neg_sampler(graph, th.tensor([edge_id]))
test_edges.append(TestEdge(src, dst, 0))
graph.remove_edges(sampled_edge_ids)
    return graph, test_edges