import copy

import numpy as np
import torch

from dgl.data import (
    AmazonCoBuyComputerDataset,
    AmazonCoBuyPhotoDataset,
    CoauthorCSDataset,
    CoauthorPhysicsDataset,
    PPIDataset,
    WikiCSDataset,
)
from dgl.dataloading import GraphDataLoader
from dgl.transforms import Compose, DropEdge, FeatMask, RowFeatNormalizer


class CosineDecayScheduler:
    """Linear warmup followed by half-cosine decay of a scalar value.

    The value ramps linearly from 0 to ``max_val`` over ``warmup_steps``
    steps, then follows a cosine curve from ``max_val`` down to 0 at
    ``total_steps``.
    """

    def __init__(self, max_val, warmup_steps, total_steps):
        self.max_val = max_val
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def get(self, step):
        """Return the scheduled value at ``step``.

        Raises ValueError when ``step`` exceeds ``total_steps``.
        """
        # Warmup phase: linear ramp from 0 up to max_val.
        if step < self.warmup_steps:
            return self.max_val * step / self.warmup_steps

        # Beyond the schedule is an error, as in the original contract.
        if step > self.total_steps:
            raise ValueError(
                "Step ({}) > total number of steps ({}).".format(
                    step, self.total_steps
                )
            )

        # Decay phase: half-cosine from max_val (at warmup end) down to 0.
        angle = (
            (step - self.warmup_steps)
            * np.pi
            / (self.total_steps - self.warmup_steps)
        )
        return self.max_val * (1 + np.cos(angle)) / 2


def get_graph_drop_transform(drop_edge_p, feat_mask_p):
    """Build a stochastic graph-augmentation pipeline.

    The pipeline deep-copies the input graph, then optionally drops edges
    with probability ``drop_edge_p`` and masks node features (key "feat")
    with probability ``feat_mask_p``. A zero probability disables the
    corresponding step.
    """
    # Deep-copy first so the augmentations never mutate the caller's graph.
    steps = [copy.deepcopy]

    if drop_edge_p > 0.0:
        steps.append(DropEdge(drop_edge_p))

    if feat_mask_p > 0.0:
        steps.append(FeatMask(feat_mask_p, node_feat_names=["feat"]))

    return Compose(steps)


def get_wiki_cs(transform=RowFeatNormalizer(subtract_min=True)):
    """Load WikiCS, standardize its node features, and return ``[graph]``.

    NOTE(review): the default ``transform`` instance is created once at
    import time and shared across calls — presumably stateless; confirm.
    """
    g = WikiCSDataset(transform=transform)[0]

    # Standardize features per dimension: zero mean, unit (population) std.
    std, mean = torch.std_mean(g.ndata["feat"], dim=0, unbiased=False)
    g.ndata["feat"] = (g.ndata["feat"] - mean) / std

    # Wrapped in a list to match the interface of the other dataset loaders.
    return [g]


def get_ppi():
    """Load PPI and return ``(batched train+val graphs, train, val, test)``.

    The first element is a list of batched graphs built from the combined
    train+valid splits; each node carries a ``"batch"`` id identifying the
    graph it came from, so batched graphs can be split apart downstream.
    """
    train_dataset = PPIDataset(mode="train")
    val_dataset = PPIDataset(mode="valid")
    test_dataset = PPIDataset(mode="test")

    # Tag every node with the index of its source graph.
    combined = list(train_dataset) + list(val_dataset)
    for graph_id, graph in enumerate(combined):
        graph.ndata["batch"] = (torch.zeros(graph.num_nodes()) + graph_id).long()

    # Materialize one shuffled pass over the combined splits, 22 graphs/batch.
    g = [batch for batch in GraphDataLoader(combined, batch_size=22, shuffle=True)]

    # NOTE(review): train/valid are re-instantiated here, so the returned
    # datasets do not carry the "batch" field added above — confirm this
    # second (costly) load is intentional.
    return g, PPIDataset(mode="train"), PPIDataset(mode="valid"), test_dataset


def get_dataset(name, transform=RowFeatNormalizer(subtract_min=True)):
    """Fetch a dataset by name.

    Returns a 4-tuple ``(dataset, train_data, val_data, test_data)``; the
    last three entries are None for every dataset except ``"ppi"``.
    Raises KeyError for an unrecognized ``name``.
    """
    loaders = {
        "coauthor_cs": CoauthorCSDataset,
        "coauthor_physics": CoauthorPhysicsDataset,
        "amazon_computers": AmazonCoBuyComputerDataset,
        "amazon_photos": AmazonCoBuyPhotoDataset,
        "wiki_cs": get_wiki_cs,
        "ppi": get_ppi,
    }

    loader = loaders[name]

    if name == "ppi":
        # get_ppi already produces the full (dataset, train, val, test) tuple.
        return loader()

    return loader(transform=transform), None, None, None