# data_preprocess.py
import json
import logging
import os
import sys

import numpy as np
import torch
from dgl.data import LegacyTUDataset
def _load_check_mark(path: str):
13
    if os.path.exists(path):
14
        with open(path, "r") as f:
15
16
17
18
            return json.load(f)
    else:
        return {}

19
20
21

def _save_check_mark(path: str, marks: dict):
    with open(path, "w") as f:
22
23
24
        json.dump(marks, f)


25
def node_label_as_feature(dataset: LegacyTUDataset, mode="concat", save=True):
    """
    Description
    -----------
    Add node labels to graph node features dict

    Parameters
    ----------
    dataset : LegacyTUDataset
        The dataset object
    mode : str, optional
        How to add node label to the graph. Valid options are "add",
        "replace" and "concat".
        - "add": Directly add node_label to graph node feature dict.
        - "concat": Concatenate "feat" and "node_label"
        - "replace": Use "node_label" as "feat"
        Default: :obj:`"concat"`
    save : bool, optional
        Save the result dataset.
        Default: :obj:`True`
    """
    # check if node label is not available
    if (
        not os.path.exists(dataset._file_path("node_labels"))
        or len(dataset) == 0
    ):
        logging.warning("No Node Label Data")
        return dataset

    # check if has cached value: a marker file stored next to the
    # dataset's own saved files says whether this step already ran.
    check_mark_name = "node_label_as_feature"
    check_mark_path = os.path.join(
        dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash)
    )
    check_mark = _load_check_mark(check_mark_path)
    # Respect force_reload: recompute even when the marker is present.
    if (
        check_mark_name in check_mark
        and check_mark[check_mark_name]
        and not dataset._force_reload
    ):
        logging.warning("Using cached value in node_label_as_feature")
        return dataset
    logging.warning(
        "Adding node labels into node features..., mode={}".format(mode)
    )

    # check if graph has "feat"; without an existing feature there is
    # nothing to concatenate onto, so fall back to "replace".
    if "feat" not in dataset[0][0].ndata:
        logging.warning("Dataset has no node feature 'feat'")
        if mode.lower() == "concat":
            mode = "replace"

    # first read node labels (raw TU file) and one-hot encode them
    DS_node_labels = dataset._idx_from_zero(
        np.loadtxt(dataset._file_path("node_labels"), dtype=int)
    )
    one_hot_node_labels = dataset._to_onehot(DS_node_labels)

    # read graph idx: graph_indicator maps each node row to its graph id
    DS_indicator = dataset._idx_from_zero(
        np.genfromtxt(dataset._file_path("graph_indicator"), dtype=int)
    )
    # node_idx_list[g] = indices of all nodes belonging to graph g
    node_idx_list = []
    for idx in range(np.max(DS_indicator) + 1):
        node_idx = np.where(DS_indicator == idx)
        node_idx_list.append(node_idx[0])

    # add to node feature dict
    # NOTE(review): torch.tensor(...) here inherits the numpy dtype of
    # the one-hot array; presumably it matches "feat" for concat —
    # confirm against LegacyTUDataset._to_onehot.
    for idx, g in zip(node_idx_list, dataset.graph_lists):
        node_labels_tensor = torch.tensor(one_hot_node_labels[idx, :])
        if mode.lower() == "concat":
            g.ndata["feat"] = torch.cat(
                (g.ndata["feat"], node_labels_tensor), dim=1
            )
        elif mode.lower() == "add":
            g.ndata["node_label"] = node_labels_tensor
        else:  # replace
            g.ndata["feat"] = node_labels_tensor

    # Persist both the processed graphs and the cache marker.
    if save:
        check_mark[check_mark_name] = True
        _save_check_mark(check_mark_path, check_mark)
        dataset.save()
    return dataset


111
def degree_as_feature(dataset: LegacyTUDataset, save=True):
    """
    Description
    -----------
    Use node degree (in one-hot format) as node feature

    Parameters
    ----------
    dataset : LegacyTUDataset
        The dataset object

    save : bool, optional
        Save the result dataset.
        Default: :obj:`True`
    """
    # first check if already have such feature: a marker file stored
    # next to the dataset's own saved files records completed steps.
    check_mark_name = "degree_as_feat"
    feat_name = "feat"
    check_mark_path = os.path.join(
        dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash)
    )
    check_mark = _load_check_mark(check_mark_path)
    # Respect force_reload: recompute even when the marker is present.
    if (
        check_mark_name in check_mark
        and check_mark[check_mark_name]
        and not dataset._force_reload
    ):
        logging.warning("Using cached value in 'degree_as_feature'")
        return dataset

    logging.warning("Adding node degree into node features...")
    # First pass: global min/max in-degree across all graphs, so every
    # graph's one-hot vectors share one common length.
    min_degree = sys.maxsize
    max_degree = 0
    for i in range(len(dataset)):
        degrees = dataset.graph_lists[i].in_degrees()
        min_degree = min(min_degree, degrees.min().item())
        max_degree = max(max_degree, degrees.max().item())

    # Second pass: one-hot encode each node's in-degree and overwrite
    # any existing "feat" with it.
    vec_len = max_degree - min_degree + 1
    for i in range(len(dataset)):
        num_nodes = dataset.graph_lists[i].num_nodes()
        node_feat = torch.zeros((num_nodes, vec_len))
        degrees = dataset.graph_lists[i].in_degrees()
        # Row j gets a 1 at column (degree_j - min_degree).
        node_feat[torch.arange(num_nodes), degrees - min_degree] = 1.0
        dataset.graph_lists[i].ndata[feat_name] = node_feat

    if save:
        check_mark[check_mark_name] = True
        # NOTE(review): save order differs from node_label_as_feature
        # (dataset.save() runs before the marker write here) — looks
        # harmless, but confirm intent.
        dataset.save()
        _save_check_mark(check_mark_path, check_mark)
    return dataset