import os
import ssl

import numpy as np
import pandas as pd
import torch
from six.moves import urllib

import dgl

# === Below data preprocessing code are based on
# https://github.com/twitter-research/tgn

# Preprocess the raw data split each features

def preprocess(data_name):
    """Parse a Jodie-format interaction CSV into an edge table and edge features.

    The expected row layout (after a single header line) is:
    ``user_id,item_id,timestamp,state_label,feat_0,feat_1,...``

    Parameters
    ----------
    data_name : str
        Path to the raw CSV file.

    Returns
    -------
    (pandas.DataFrame, numpy.ndarray)
        DataFrame with columns ``u``, ``i``, ``ts``, ``label``, ``idx``
        (``idx`` is the 0-based row index), and a 2-D array with one
        feature row per interaction.
    """
    u_list, i_list, ts_list, label_list = [], [], [], []
    feat_l = []
    idx_list = []

    with open(data_name) as f:
        # Skip the header row; `next(f, None)` tolerates an empty file
        # instead of raising StopIteration.
        next(f, None)
        for idx, line in enumerate(f):
            e = line.strip().split(",")

            u = int(e[0])
            i = int(e[1])

            ts = float(e[2])
            label = float(e[3])  # kept as float; some labels are non-integer

            # Everything after the 4th column is the edge feature vector.
            feat = np.array([float(x) for x in e[4:]])

            u_list.append(u)
            i_list.append(i)
            ts_list.append(ts)
            label_list.append(label)
            idx_list.append(idx)

            feat_l.append(feat)
    return pd.DataFrame(
        {
            "u": u_list,
            "i": i_list,
            "ts": ts_list,
            "label": label_list,
            "idx": idx_list,
        }
    ), np.array(feat_l)


# Re index nodes for DGL convience
def reindex(df, bipartite=True):
    """Re-index nodes to 1-based ids for DGL convenience.

    For bipartite data, item ids are first shifted past the user-id range
    so the two node sets are disjoint; then every id (``u``, ``i``, ``idx``)
    is incremented by one so id 0 is left free (used for padding rows).

    Parameters
    ----------
    df : pandas.DataFrame
        Edge table with columns ``u``, ``i``, ``idx`` (as from preprocess()).
    bipartite : bool
        Treat ``u`` and ``i`` as two disjoint node sets.

    Returns
    -------
    pandas.DataFrame
        A re-indexed copy; the input is not modified.

    Raises
    ------
    ValueError
        If ``bipartite`` and either id column is not a contiguous range
        (the offset trick below would otherwise leave gaps).
    """
    new_df = df.copy()
    if bipartite:
        # Validate with explicit raises rather than `assert`, which is
        # silently stripped under `python -O`.
        if df.u.max() - df.u.min() + 1 != len(df.u.unique()):
            raise ValueError("user ids must form a contiguous range")
        if df.i.max() - df.i.min() + 1 != len(df.i.unique()):
            raise ValueError("item ids must form a contiguous range")

        # Shift item ids past the user-id range so the two sides are disjoint.
        new_df.i = df.i + (df.u.max() + 1)

    # Reserve id/index 0 for padding (both branches did this identically).
    new_df.u += 1
    new_df.i += 1
    new_df.idx += 1

    return new_df

# Save edge list, features in different file for data easy process data
def run(data_name, bipartite=True, node_feat_dim=172):
    """Preprocess one raw Jodie CSV and save edge list / features to disk.

    Reads ``./data/<data_name>.csv`` and writes:
    - ``./data/ml_<data_name>.csv``  : re-indexed edge table
    - ``./data/ml_<data_name>.npy``  : edge features (row 0 is a zero pad)
    - ``./data/ml_<data_name>_node.npy`` : zero node-feature placeholders

    Parameters
    ----------
    data_name : str
        Dataset name (file stem under ./data).
    bipartite : bool
        Passed through to reindex().
    node_feat_dim : int
        Width of the zero node-feature matrix (172 matches the original
        TGN preprocessing default).
    """
    src_path = "./data/{}.csv".format(data_name)
    out_df = "./data/ml_{}.csv".format(data_name)
    out_feat = "./data/ml_{}.npy".format(data_name)
    out_node_feat = "./data/ml_{}_node.npy".format(data_name)

    df, feat = preprocess(src_path)
    new_df = reindex(df, bipartite)

    # Prepend an all-zero feature row so edge features line up with the
    # 1-based edge ids produced by reindex() (row 0 = padding).
    empty = np.zeros(feat.shape[1])[np.newaxis, :]
    feat = np.vstack([empty, feat])

    # The raw data carries no node features; emit zero placeholders with
    # one row per node id (id 0 again reserved for padding).
    max_idx = max(new_df.u.max(), new_df.i.max())
    rand_feat = np.zeros((max_idx + 1, node_feat_dim))

    new_df.to_csv(out_df)
    np.save(out_feat, feat)
    np.save(out_node_feat, rand_feat)

# === code from twitter-research-tgn end ===

# If you have new dataset follow by same format in Jodie,
# you can directly use name to retrieve dataset

def TemporalDataset(dataset):
    """Load a Jodie temporal dataset as a DGL graph, downloading on demand.

    The processed graph is cached as ``./data/<dataset>.bin``; when the
    cache exists it is loaded directly, otherwise the raw CSV is fetched
    from SNAP (if missing), preprocessed via run(), and the graph is built
    and cached.

    Parameters
    ----------
    dataset : str
        Jodie dataset name (e.g. "wikipedia", "reddit").

    Returns
    -------
    dgl.DGLGraph
        Directed graph with edge data "timestamp", "label" and "feats".
    """
    bin_path = "./data/{}.bin".format(dataset)
    if os.path.exists(bin_path):
        print("Data is exist directly loaded.")
        gs, _ = dgl.load_graphs(bin_path)
        return gs[0]

    csv_path = "./data/{}.csv".format(dataset)
    if not os.path.exists(csv_path):
        os.makedirs("./data", exist_ok=True)

        url = "https://snap.stanford.edu/jodie/{}.csv".format(dataset)
        print("Start Downloading File....")
        # SECURITY: certificate verification is disabled for this download.
        # Consider a default (verified) SSL context or certifi instead.
        context = ssl._create_unverified_context()
        data = urllib.request.urlopen(url, context=context)
        with open(csv_path, "wb") as handle:
            handle.write(data.read())

    print("Start Process Data ...")
    run(dataset)

    raw_connection = pd.read_csv("./data/ml_{}.csv".format(dataset))
    raw_feature = np.load("./data/ml_{}.npy".format(dataset))
    # -1 undoes the 1-based re-indexing so node ids start at 0 for DGL.
    src = raw_connection["u"].to_numpy() - 1
    dst = raw_connection["i"].to_numpy() - 1

    # Create directed graph
    g = dgl.graph((src, dst))
    g.edata["timestamp"] = torch.from_numpy(raw_connection["ts"].to_numpy())
    g.edata["label"] = torch.from_numpy(raw_connection["label"].to_numpy())
    # Row 0 of the feature matrix is the zero padding row added by run();
    # drop it so feature rows align with edge ids.
    g.edata["feats"] = torch.from_numpy(raw_feature[1:, :]).float()
    dgl.save_graphs(bin_path, [g])
    return g

def TemporalWikipediaDataset():
    """Return the Jodie Wikipedia dataset as a DGL graph.

    Downloads and preprocesses the raw data on first use; later calls
    load the cached binary graph.
    """
    return TemporalDataset("wikipedia")

def TemporalRedditDataset():
    """Return the Jodie Reddit dataset as a DGL graph.

    Downloads and preprocesses the raw data on first use; later calls
    load the cached binary graph.
    """
    return TemporalDataset("reddit")