"""
Script that reads from raw Nowplaying-RS data and dumps into a pickle
file a heterogeneous graph with categorical and numeric features.

Usage:
    python process_nowplaying_rs.py <directory> <out_directory>

Reads ``<directory>/context_content_features.csv`` and writes
``<out_directory>/train_g.bin`` (the training graph) and
``<out_directory>/data.pkl`` (validation/test matrices plus metadata).
"""

import argparse
import os
import pickle

import pandas as pd
import scipy.sparse as ssp
import torch

from builder import PandasGraphBuilder
from data_utils import *

import dgl

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("directory", type=str)
    parser.add_argument("out_directory", type=str)
    args = parser.parse_args()
    directory = args.directory
    out_directory = args.out_directory
    os.makedirs(out_directory, exist_ok=True)

    data = pd.read_csv(
        os.path.join(directory, "context_content_features.csv")
    )
    # Columns 1..12 of the CSV hold per-track audio/metadata features
    # (the exact set depends on the raw dump; column 0 is assumed to be
    # an index/identifier column — verify against the CSV schema).
    track_feature_cols = list(data.columns[1:13])
    data = data[
        ["user_id", "track_id", "created_at"] + track_feature_cols
    ].dropna()

    # Entity tables: one row per user and one row per track.
    users = data[["user_id"]].drop_duplicates()
    tracks = data[["track_id"] + track_feature_cols].drop_duplicates()
    # Each track_id must map to exactly one feature row; duplicates with
    # conflicting features would survive drop_duplicates() above.
    assert tracks["track_id"].value_counts().max() == 1
    tracks = tracks.astype(
        {"mode": "int64", "key": "int64", "artist_id": "category"}
    )
    # .copy() so the timestamp conversion below writes into an
    # independent frame instead of a view of `data`, which would raise
    # SettingWithCopyWarning and may silently fail to assign.
    events = data[["user_id", "track_id", "created_at"]].copy()
    # Convert ISO timestamps to integer epoch seconds.
    events["created_at"] = (
        events["created_at"].values.astype("datetime64[s]").astype("int64")
    )

    # Build a bidirectional user<->track heterogeneous graph: one
    # "listened" edge per event plus its reverse "listened-by" edge.
    graph_builder = PandasGraphBuilder()
    graph_builder.add_entities(users, "user_id", "user")
    graph_builder.add_entities(tracks, "track_id", "track")
    graph_builder.add_binary_relations(
        events, "user_id", "track_id", "listened"
    )
    graph_builder.add_binary_relations(
        events, "track_id", "user_id", "listened-by"
    )

    g = graph_builder.build()

    # Attach track features to the graph: categorical artist ids as
    # integer codes, other integer columns as-is, and all float columns
    # stacked into one normalized "song_features" matrix.
    float_cols = []
    for col in tracks.columns:
        if col == "track_id":
            continue
        elif col == "artist_id":
            g.nodes["track"].data[col] = torch.LongTensor(
                tracks[col].cat.codes.values
            )
        elif tracks.dtypes[col] == "float64":
            float_cols.append(col)
        else:
            g.nodes["track"].data[col] = torch.LongTensor(tracks[col].values)
    g.nodes["track"].data["song_features"] = torch.FloatTensor(
        linear_normalize(tracks[float_cols].values)
    )
    # Timestamps go on both edge directions so temporal sampling works
    # from either endpoint.
    g.edges["listened"].data["created_at"] = torch.LongTensor(
        events["created_at"].values
    )
    g.edges["listened-by"].data["created_at"] = torch.LongTensor(
        events["created_at"].values
    )

    # Temporal split: per user, the most recent events become the
    # validation/test sets (see data_utils.train_test_split_by_time).
    train_indices, val_indices, test_indices = train_test_split_by_time(
        events, "created_at", "user_id"
    )
    train_g = build_train_graph(
        g, train_indices, "user", "track", "listened", "listened-by"
    )
    # Every user must keep at least one training interaction, otherwise
    # downstream sampling would fail.
    assert train_g.out_degrees(etype="listened").min() > 0
    val_matrix, test_matrix = build_val_test_matrix(
        g, val_indices, test_indices, "user", "track", "listened"
    )

    dgl.save_graphs(os.path.join(out_directory, "train_g.bin"), train_g)

    # Metadata consumed by the downstream model/training script; the
    # key names form its expected schema — do not rename.
    dataset = {
        "val-matrix": val_matrix,
        "test-matrix": test_matrix,
        "item-texts": {},
        "item-images": None,
        "user-type": "user",
        "item-type": "track",
        "user-to-item-type": "listened",
        "item-to-user-type": "listened-by",
        "timestamp-edge-column": "created_at",
    }
    with open(os.path.join(out_directory, "data.pkl"), "wb") as f:
        pickle.dump(dataset, f)