import argparse
import random
from datetime import datetime

import networkx as nx
import torch as th

import dgl


def init_args():
    # TODO: change args
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--session_interval_sec", type=int, default=1800)
    argparser.add_argument(
        "--action_data", type=str, default="data/action_head.csv"
    )
    argparser.add_argument(
        "--item_info_data", type=str, default="data/jdata_product.csv"
    )
    argparser.add_argument("--walk_length", type=int, default=10)
    argparser.add_argument("--num_walks", type=int, default=5)
    argparser.add_argument("--batch_size", type=int, default=64)
    argparser.add_argument("--dim", type=int, default=16)
    argparser.add_argument("--epochs", type=int, default=30)
    argparser.add_argument("--window_size", type=int, default=2)
    argparser.add_argument("--num_negative", type=int, default=5)
    argparser.add_argument("--lr", type=float, default=0.001)
    argparser.add_argument("--log_every", type=int, default=100)
    return argparser.parse_args()


def construct_graph(datapath, session_interval_gap_sec, valid_sku_raw_ids):
    user_clicks, sku_encoder, sku_decoder = parse_actions(
        datapath, valid_sku_raw_ids
    )
    # maps edge "src,dst" -> weight
    graph = {}
    for user_id, action_list in user_clicks.items():
        # sort clicks by action time
        _action_list = sorted(action_list, key=lambda x: x[1])
        last_action_time = datetime.strptime(
            _action_list[0][1], "%Y-%m-%d %H:%M:%S"
        )
        session = [_action_list[0][0]]
        # cut the click stream into sessions and add each to the graph
        for sku_id, action_time in _action_list[1:]:
            action_time = datetime.strptime(action_time, "%Y-%m-%d %H:%M:%S")
            gap = action_time - last_action_time
            if gap.total_seconds() < session_interval_gap_sec:
                session.append(sku_id)
            else:
                # the gap exceeds the threshold, so a new session starts here;
                # add the previous session to the graph
                add_session(session, graph)
                # and begin a new session
                session = [sku_id]
            last_action_time = action_time
        # add the last session
        add_session(session, graph)
    g = convert_to_dgl_graph(graph)
    return g, sku_encoder, sku_decoder


def convert_to_dgl_graph(graph):
    # directed graph
    g = nx.DiGraph()
    for edge, weight in graph.items():
        nodes = edge.split(",")
        src, dst = int(nodes[0]), int(nodes[1])
        g.add_edge(src, dst, weight=float(weight))
    return dgl.from_networkx(g, edge_attrs=["weight"])


def add_session(session, graph):
    """
    For a session like [sku1, sku2, sku3], add weight 1 to each of the
    following edges:
        sku1 -> sku2
        sku2 -> sku3
    If the session length is < 2, no nodes/edges are added.
    """
    for i in range(len(session) - 1):
        edge = str(session[i]) + "," + str(session[i + 1])
        try:
            graph[edge] += 1
        except KeyError:
            graph[edge] = 1


def parse_actions(datapath, valid_sku_raw_ids):
    user_clicks = {}
    with open(datapath, "r") as f:
        f.readline()  # skip the header line
        # raw_id -> new_id (encoder) and new_id -> raw_id (decoder)
        sku_encoder, sku_decoder = {}, []
        for line in f:
            fields = line.strip().split(",")
            action_type = fields[-1]
            # in this sample dataset, every action type is "1"
            if action_type == "1":
                user_id = fields[0]
                sku_raw_id = fields[1]
                if sku_raw_id in valid_sku_raw_ids:
                    action_time = fields[2]
                    # encode the raw sku id into a dense integer id
                    sku_id = encode_id(sku_encoder, sku_decoder, sku_raw_id)
                    # add to the user's clicks
                    user_clicks.setdefault(user_id, []).append(
                        (sku_id, action_time)
                    )
    return user_clicks, sku_encoder, sku_decoder


def encode_id(encoder, decoder, raw_id):
    """
    Map raw_id to a dense integer id. New ids are assigned consecutively
    from 0; a previously seen raw_id gets back the id it was first assigned.
    """
    if raw_id in encoder:
        return encoder[raw_id]
    new_id = len(decoder)
    encoder[raw_id] = new_id
    decoder.append(raw_id)
    return new_id


def get_valid_sku_set(datapath):
    sku_ids = set()
    with open(datapath, "r") as f:
        f.readline()  # skip the header line
        for line in f:
            sku_raw_id = line.strip().split(",")[0]
            sku_ids.add(sku_raw_id)
    return sku_ids
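# Illustrative sketch only, not part of the original pipeline: it shows on toy
# inputs how add_session accumulates edge weights and how encode_id assigns
# dense ids. The name _demo_graph_construction is hypothetical.
def _demo_graph_construction():
    graph = {}
    add_session([0, 1, 2], graph)  # adds edges 0->1 and 1->2, weight 1 each
    add_session([0, 1], graph)     # bumps edge 0->1 to weight 2
    assert graph == {"0,1": 2, "1,2": 1}

    encoder, decoder = {}, []
    assert encode_id(encoder, decoder, "sku_a") == 0
    assert encode_id(encoder, decoder, "sku_b") == 1
    # a repeated raw id keeps the dense id it was first assigned
    assert encode_id(encoder, decoder, "sku_a") == 0
    assert decoder == ["sku_a", "sku_b"]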
def encode_sku_fields(datapath, sku_encoder, sku_decoder):
    # header: sku_id,brand,shop_id,cate,market_time
    sku_info_encoder = {"brand": {}, "shop": {}, "cate": {}}
    sku_info_decoder = {"brand": [], "shop": [], "cate": []}
    sku_info = {}
    with open(datapath, "r") as f:
        f.readline()  # skip the header line
        for line in f:
            fields = line.strip().split(",")
            sku_raw_id = fields[0]
            brand_raw_id = fields[1]
            shop_raw_id = fields[2]
            cate_raw_id = fields[3]
            # only keep side information for skus that made it into the graph
            if sku_raw_id in sku_encoder:
                sku_id = sku_encoder[sku_raw_id]
                brand_id = encode_id(
                    sku_info_encoder["brand"],
                    sku_info_decoder["brand"],
                    brand_raw_id,
                )
                shop_id = encode_id(
                    sku_info_encoder["shop"],
                    sku_info_decoder["shop"],
                    shop_raw_id,
                )
                cate_id = encode_id(
                    sku_info_encoder["cate"],
                    sku_info_decoder["cate"],
                    cate_raw_id,
                )
                sku_info[sku_id] = [sku_id, brand_id, shop_id, cate_id]
    return sku_info_encoder, sku_info_decoder, sku_info


class TestEdge:
    def __init__(self, src, dst, label):
        self.src = src
        self.dst = dst
        self.label = label


def split_train_test_graph(graph):
    """
    Randomly choose 1/3 of the edges and remove them from the graph: the
    removed edges become positive test edges (label 1), and one uniformly
    sampled negative edge (label 0) is added per positive edge. The
    remaining graph is taken as the training set.
    """
    test_edges = []
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(1)
    sampled_edge_ids = random.sample(
        range(graph.num_edges()), graph.num_edges() // 3
    )
    for edge_id in sampled_edge_ids:
        src, dst = graph.find_edges(th.tensor([edge_id]))
        test_edges.append(TestEdge(src, dst, 1))
        src, dst = neg_sampler(graph, th.tensor([edge_id]))
        test_edges.append(TestEdge(src, dst, 0))
    graph.remove_edges(sampled_edge_ids)
    return graph, test_edges
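# A minimal end-to-end sketch, assuming the CSV paths from init_args exist
# locally. The original repo presumably drives these helpers from a separate
# training script, so this __main__ block is illustrative only.
if __name__ == "__main__":
    args = init_args()
    valid_sku_raw_ids = get_valid_sku_set(args.item_info_data)
    g, sku_encoder, sku_decoder = construct_graph(
        args.action_data, args.session_interval_sec, valid_sku_raw_ids
    )
    sku_info_encoder, sku_info_decoder, sku_info = encode_sku_fields(
        args.item_info_data, sku_encoder, sku_decoder
    )
    train_g, test_edges = split_train_test_graph(g)
    print(
        f"train graph: {train_g.num_nodes()} nodes, "
        f"{train_g.num_edges()} edges; "
        f"{len(test_edges)} test edges; "
        f"{len(sku_info)} skus with side info"
    )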