utils.py 8.51 KB
Newer Older
Mufei Li's avatar
Mufei Li committed
1
2
3
4
5
6
import datetime
import errno
import os
import pickle
import random
from pprint import pprint
7

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
8
9
import dgl

10
11
12
import numpy as np
import torch
from dgl.data.utils import _get_dgl_url, download, get_download_dir
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
13
from scipy import io as sio, sparse
14

Mufei Li's avatar
Mufei Li committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28

def set_random_seed(seed=0):
    """Seed every RNG used in training for reproducibility.

    Parameters
    ----------
    seed : int
        Random seed to use
    """
    # Seed the Python, NumPy and (CPU) PyTorch generators alike.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Seed the current CUDA device only when one is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

29

Mufei Li's avatar
Mufei Li committed
30
31
32
33
34
35
36
37
38
39
40
41
def mkdir_p(path, log=True):
    """Create a directory for the specified path.

    Creating an already-existing directory is not an error; all other
    OS failures (permissions, missing parents on some systems) propagate.

    Parameters
    ----------
    path : str
        Path name
    log : bool
        Whether to print result for directory creation
    """
    try:
        os.makedirs(path)
        if log:
            print("Created directory {}".format(path))
    except OSError as exc:
        # Bug fix: the original tested `... and log` in the same condition,
        # so an existing directory with log=False fell through to `raise`.
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            if log:
                print("Directory {} already exists.".format(path))
        else:
            raise

49

Mufei Li's avatar
Mufei Li committed
50
51
52
53
54
55
56
def get_date_postfix():
    """Get a date based postfix for directory name.

    Returns
    -------
    post_fix : str
        ``YYYY-MM-DD_HH-MM-SS`` built from the current local time.
    """
    now = datetime.datetime.now()
    return f"{now.date()}_{now.hour:02d}-{now.minute:02d}-{now.second:02d}"

63

Mufei Li's avatar
Mufei Li committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def setup_log_dir(args, sampling=False):
    """Name and create directory for logging.

    Parameters
    ----------
    args : dict
        Configuration
    sampling : bool
        Whether we are using sampling based training

    Returns
    -------
    log_dir : str
        Path for logging directory
    """
    # Directory name is "<dataset>_<timestamp>", optionally tagged
    # with "_sampling" for the mini-batch training variant.
    stamp = get_date_postfix()
    log_dir = os.path.join(
        args["log_dir"], "{}_{}".format(args["dataset"], stamp)
    )
    if sampling:
        log_dir = log_dir + "_sampling"
    mkdir_p(log_dir)
    return log_dir

88

Mufei Li's avatar
Mufei Li committed
89
90
# Hyperparameters below follow the original HAN paper.
default_configure = {
    "lr": 0.005,  # Learning rate
    "num_heads": [8],  # Number of attention heads for node-level attention
    "hidden_units": 8,  # Hidden size per attention head
    "dropout": 0.6,
    "weight_decay": 0.001,
    "num_epochs": 200,
    "patience": 100,  # Early-stopping patience (epochs)
}

# Extra settings used only for sampling-based (mini-batch) training.
sampling_configure = {
    "batch_size": 20,
}

Mufei Li's avatar
Mufei Li committed
102
103
104

def setup(args):
    """Merge paper defaults into args, seed RNGs, and create the log dir.

    Parameters
    ----------
    args : dict
        Run configuration; must contain "seed" and "hetero".

    Returns
    -------
    dict
        The same dict, updated in place and returned for convenience.
    """
    args.update(default_configure)
    set_random_seed(args["seed"])
    # "hetero" selects the raw heterogeneous ACM graph over the
    # preprocessed homogeneous meta-path version.
    dataset_name = "ACMRaw" if args["hetero"] else "ACM"
    args["dataset"] = dataset_name
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    args["device"] = device_name
    args["log_dir"] = setup_log_dir(args)
    return args

111

Mufei Li's avatar
Mufei Li committed
112
113
114
115
def setup_for_sampling(args):
    """Prepare configuration for sampling based (mini-batch) training.

    Parameters
    ----------
    args : dict
        Run configuration, updated in place.

    Returns
    -------
    dict
        The same dict, updated and returned for convenience.
    """
    # Layer the sampling overrides on top of the paper defaults.
    for conf in (default_configure, sampling_configure):
        args.update(conf)
    set_random_seed()
    args["device"] = "cuda:0" if torch.cuda.is_available() else "cpu"
    args["log_dir"] = setup_log_dir(args, sampling=True)
    return args

120

Mufei Li's avatar
Mufei Li committed
121
122
123
124
125
def get_binary_mask(total_size, indices):
    """Return a uint8 tensor of length ``total_size`` with ones at ``indices``."""
    # Allocate as uint8 up front instead of float-then-.byte(); same result.
    mask = torch.zeros(total_size, dtype=torch.uint8)
    mask[indices] = 1
    return mask

126

Mufei Li's avatar
Mufei Li committed
127
def load_acm(remove_self_loop):
    """Download and load the preprocessed ACM3025 dataset.

    Parameters
    ----------
    remove_self_loop : bool
        Whether to strip self loops from the meta-path adjacency matrices.

    Returns
    -------
    tuple
        (gs, features, labels, num_classes, train_idx, val_idx, test_idx,
        train_mask, val_mask, test_mask)
    """
    url = "dataset/ACM3025.pkl"
    data_path = get_download_dir() + "/ACM3025.pkl"
    download(_get_dgl_url(url), path=data_path)

    # NOTE(review): pickle.load on a downloaded file — assumes the DGL
    # download host is trusted.
    with open(data_path, "rb") as f:
        data = pickle.load(f)

    # "label" and "feature" are sparse; densify into torch tensors.
    labels, features = (
        torch.from_numpy(data["label"].todense()).long(),
        torch.from_numpy(data["feature"].todense()).float(),
    )
    # Labels are one-hot rows; convert to class-index form.
    num_classes = labels.shape[1]
    labels = labels.nonzero()[:, 1]

    if remove_self_loop:
        # Subtract the identity to drop self loops, then re-sparsify.
        num_nodes = data["label"].shape[0]
        data["PAP"] = sparse.csr_matrix(data["PAP"] - np.eye(num_nodes))
        data["PLP"] = sparse.csr_matrix(data["PLP"] - np.eye(num_nodes))

    # Adjacency matrices for meta path based neighbors
    # (Mufei): I verified both of them are binary adjacency matrices with self loops
    author_g = dgl.from_scipy(data["PAP"])
    subject_g = dgl.from_scipy(data["PLP"])
    gs = [author_g, subject_g]

    # Index arrays are stored with a leading singleton dimension; squeeze it.
    train_idx = torch.from_numpy(data["train_idx"]).long().squeeze(0)
    val_idx = torch.from_numpy(data["val_idx"]).long().squeeze(0)
    test_idx = torch.from_numpy(data["test_idx"]).long().squeeze(0)

    num_nodes = author_g.num_nodes()
    train_mask = get_binary_mask(num_nodes, train_idx)
    val_mask = get_binary_mask(num_nodes, val_idx)
    test_mask = get_binary_mask(num_nodes, test_idx)

    # Report the split proportions for a quick sanity check.
    print("dataset loaded")
    pprint(
        {
            "dataset": "ACM",
            "train": train_mask.sum().item() / num_nodes,
            "val": val_mask.sum().item() / num_nodes,
            "test": test_mask.sum().item() / num_nodes,
        }
    )

    return (
        gs,
        features,
        labels,
        num_classes,
        train_idx,
        val_idx,
        test_idx,
        train_mask,
        val_mask,
        test_mask,
    )
Mufei Li's avatar
Mufei Li committed
184
185
186
187


def load_acm_raw(remove_self_loop):
    """Download the raw ACM.mat and build a heterogeneous paper graph.

    Parameters
    ----------
    remove_self_loop : bool
        Must be False; self-loop removal is not supported by this loader.

    Returns
    -------
    tuple
        (hg, features, labels, num_classes, train_idx, val_idx, test_idx,
        train_mask, val_mask, test_mask)
    """
    assert not remove_self_loop
    url = "dataset/ACM.mat"
    data_path = get_download_dir() + "/ACM.mat"
    download(_get_dgl_url(url), path=data_path)

    data = sio.loadmat(data_path)
    # Relation matrices from the .mat file (scipy sparse).
    p_vs_l = data["PvsL"]  # paper-field?
    p_vs_a = data["PvsA"]  # paper-author
    p_vs_t = data["PvsT"]  # paper-term, bag of words
    p_vs_c = data["PvsC"]  # paper-conference, labels come from that

    # We assign
    # (1) KDD papers as class 0 (data mining),
    # (2) SIGMOD and VLDB papers as class 1 (database),
    # (3) SIGCOMM and MOBICOMM papers as class 2 (communication)
    conf_ids = [0, 1, 9, 10, 13]
    label_ids = [0, 1, 2, 2, 1]

    # Keep only papers published at one of the selected conferences.
    p_vs_c_filter = p_vs_c[:, conf_ids]
    p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0]
    p_vs_l = p_vs_l[p_selected]
    p_vs_a = p_vs_a[p_selected]
    p_vs_t = p_vs_t[p_selected]
    p_vs_c = p_vs_c[p_selected]

    # Build the heterogeneous graph with both edge directions per relation.
    hg = dgl.heterograph(
        {
            ("paper", "pa", "author"): p_vs_a.nonzero(),
            ("author", "ap", "paper"): p_vs_a.transpose().nonzero(),
            ("paper", "pf", "field"): p_vs_l.nonzero(),
            ("field", "fp", "paper"): p_vs_l.transpose().nonzero(),
        }
    )

    # Bag-of-words term counts as dense paper features.
    features = torch.FloatTensor(p_vs_t.toarray())

    # Derive a class label per paper from its conference.
    pc_p, pc_c = p_vs_c.nonzero()
    labels = np.zeros(len(p_selected), dtype=np.int64)
    for conf_id, label_id in zip(conf_ids, label_ids):
        labels[pc_p[pc_c == conf_id]] = label_id
    labels = torch.LongTensor(labels)

    num_classes = 3

    # Random ~20/10/70 train/val/test split, stratified per conference:
    # each conference's papers get a permuted [0, 1] ramp, then thresholds.
    float_mask = np.zeros(len(pc_p))
    for conf_id in conf_ids:
        pc_c_mask = pc_c == conf_id
        float_mask[pc_c_mask] = np.random.permutation(
            np.linspace(0, 1, pc_c_mask.sum())
        )
    train_idx = np.where(float_mask <= 0.2)[0]
    val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
    test_idx = np.where(float_mask > 0.3)[0]

    num_nodes = hg.num_nodes("paper")
    train_mask = get_binary_mask(num_nodes, train_idx)
    val_mask = get_binary_mask(num_nodes, val_idx)
    test_mask = get_binary_mask(num_nodes, test_idx)

    return (
        hg,
        features,
        labels,
        num_classes,
        train_idx,
        val_idx,
        test_idx,
        train_mask,
        val_mask,
        test_mask,
    )

Mufei Li's avatar
Mufei Li committed
259
260

def load_data(dataset, remove_self_loop=False):
    """Load a dataset by name.

    Parameters
    ----------
    dataset : str
        Either "ACM" (preprocessed) or "ACMRaw" (raw heterogeneous graph).
    remove_self_loop : bool
        Whether to remove self loops (only supported by "ACM").

    Returns
    -------
    tuple
        Whatever the matching loader returns.

    Raises
    ------
    NotImplementedError
        If ``dataset`` is not recognized.
    """
    if dataset == "ACM":
        return load_acm(remove_self_loop)
    elif dataset == "ACMRaw":
        return load_acm_raw(remove_self_loop)
    else:
        # Bug fix: the original *returned* the exception object instead of
        # raising it, so callers silently received a NotImplementedError
        # instance in place of data.
        raise NotImplementedError("Unsupported dataset {}".format(dataset))

Mufei Li's avatar
Mufei Li committed
268
269
270
271

class EarlyStopping(object):
    """Stop training when validation loss and accuracy both stop improving.

    A step counts as "worse" only when loss rose AND accuracy fell relative
    to the best values seen; `patience` consecutive worse steps trigger stop.
    """

    def __init__(self, patience=10):
        # Timestamped checkpoint path so separate runs rarely collide.
        now = datetime.datetime.now()
        self.filename = "early_stop_{}_{:02d}-{:02d}-{:02d}.pth".format(
            now.date(), now.hour, now.minute, now.second
        )
        self.patience = patience
        self.counter = 0
        self.best_acc = None
        self.best_loss = None
        self.early_stop = False

    def step(self, loss, acc, model):
        """Record one validation result; return True when training should stop."""
        if self.best_loss is None:
            # First observation: take it as the best and checkpoint.
            self.best_acc = acc
            self.best_loss = loss
            self.save_checkpoint(model)
            return self.early_stop

        got_worse = (loss > self.best_loss) and (acc < self.best_acc)
        if got_worse:
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
            return self.early_stop

        # At least one metric improved (or tied); checkpoint only when
        # BOTH are at least as good as the best seen so far.
        if (loss <= self.best_loss) and (acc >= self.best_acc):
            self.save_checkpoint(model)
        self.best_loss = np.min((loss, self.best_loss))
        self.best_acc = np.max((acc, self.best_acc))
        self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        """Saves model when validation loss decreases."""
        torch.save(model.state_dict(), self.filename)

    def load_checkpoint(self, model):
        """Load the latest checkpoint."""
        model.load_state_dict(torch.load(self.filename))