# data_utils.py: one-hot encoding and DiffPool-specific node-feature preprocessing.
import numpy as np
import torch

def one_hotify(labels, pad=-1):
    """
    Cast integer labels to one-hot row vectors.

    Args:
        labels: 1-D sequence (list or array) of non-negative, zero-indexed
            integer labels.
        pad: when > 0, the largest label value the encoding must hold; the
            result then always has ``pad + 1`` columns, so encodings produced
            by separate calls share the same width. When <= 0 (the default),
            the width is ``max(labels) + 1``.

    Returns:
        ``np.ndarray`` of float64 with shape ``(len(labels), width)`` holding
        a single 1.0 per row, in the column given by that row's label.

    Raises:
        ValueError: if any label is negative (NumPy fancy indexing would
            otherwise silently wrap it to a column counted from the end).
        IndexError: if ``pad`` is set and some label is larger than ``pad``.
    """
    labels = np.asarray(labels)
    num_instances = len(labels)
    padded_dim = pad + 1 if pad > 0 else None

    if num_instances == 0:
        # np.max has no identity for empty input; an empty batch is simply
        # zero rows (with the padded width if one was requested).
        return np.zeros((0, padded_dim or 0))

    if labels.min() < 0:
        raise ValueError("one_hotify expects non-negative labels")

    if padded_dim is None:
        dim_embedding = int(labels.max()) + 1  # zero-indexed labels assumed
    else:
        dim_embedding = padded_dim
        if labels.max() >= dim_embedding:
            raise IndexError(
                f"label {labels.max()} does not fit in pad={pad} "
                f"({dim_embedding} columns)"
            )

    embeddings = np.zeros((num_instances, dim_embedding))
    embeddings[np.arange(num_instances), labels] = 1
    return embeddings


def pre_process(dataset, prog_args):
    """
    Overwrite each graph's ``ndata["feat"]`` per DiffPool's preprocess setting.

    ``prog_args.data_mode`` selects the node features:
        "default": leave node features untouched.
        "id":      one-hot encoding of each node's index, padded to
                   ``dataset.max_num_node + 1`` columns.
        "deg-num": each node's in-degree as a single column.
        "deg":     one-hot encoding of each node's in-degree, padded to
                   ``dataset.max_degrees + 1`` columns.

    All written features are float32 torch tensors. Graphs are modified in
    place and nothing is returned; no partitioning or shuffling happens here.

    Args:
        dataset: iterable of ``(graph, label)`` pairs (presumably DGL graphs,
            given ``ndata`` / ``in_degrees`` usage); must expose
            ``max_num_node`` for "id" mode and ``max_degrees`` for "deg" mode.
        prog_args: object with a ``data_mode`` attribute.

    Raises:
        ValueError: if ``data_mode`` is not one of the modes listed above.
    """
    mode = prog_args.data_mode
    if mode == "default":
        return
    if mode not in ("id", "deg-num", "deg"):
        # Fail loudly instead of printing "overwrite ..." and doing nothing.
        raise ValueError(
            f"unknown data_mode {mode!r}; expected one of "
            "'default', 'id', 'deg-num', 'deg'"
        )

    print("overwrite node attributes with DiffPool's preprocess setting")
    # The dataset yields (graph, label) pairs in every mode, so unpack them
    # uniformly; binding the whole pair to `g` breaks `g.in_degrees()`.
    for g, _ in dataset:
        if mode == "id":
            ids = np.arange(g.number_of_nodes())
            feat = torch.from_numpy(one_hotify(ids, pad=dataset.max_num_node))
        elif mode == "deg-num":
            # as_tensor accepts either a tensor or an array and keeps a
            # tensor on its current device.
            feat = torch.as_tensor(g.in_degrees()).unsqueeze(1)
        else:  # mode == "deg"
            degs = torch.as_tensor(g.in_degrees()).tolist()
            feat = torch.from_numpy(one_hotify(degs, pad=dataset.max_degrees))
        # Graph ndata expects framework tensors rather than NumPy arrays;
        # float32 is assumed to match the downstream model's layers.
        g.ndata["feat"] = feat.to(torch.float32)