"docs/source/api/logging.mdx" did not exist on "2a69c0b7b8c4ab2ced90a1784a3eeef6ddf5ae8c"
utils.py 2.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import copy
import torch
import numpy as np
from dgl.dataloading import GraphDataLoader
from dgl.transforms import Compose, DropEdge, FeatMask, RowFeatNormalizer
from dgl.data import CoauthorCSDataset, CoauthorPhysicsDataset, AmazonCoBuyPhotoDataset, AmazonCoBuyComputerDataset, PPIDataset, WikiCSDataset


class CosineDecayScheduler:
    """Schedule a scalar with linear warm-up followed by cosine decay.

    For ``step < warmup_steps`` the value grows linearly from 0 towards
    ``max_val``. For ``warmup_steps <= step <= total_steps`` it follows half a
    cosine period, starting at ``max_val`` and ending at 0.

    Args:
        max_val: Peak value, reached at ``step == warmup_steps``.
        warmup_steps: Number of linear warm-up steps.
        total_steps: Last valid step (inclusive).
    """

    def __init__(self, max_val, warmup_steps, total_steps):
        self.max_val = max_val
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def get(self, step):
        """Return the scheduled value at ``step``.

        Args:
            step: Current step index.

        Returns:
            The scheduled value for ``step``.

        Raises:
            ValueError: If ``step`` lies outside the warm-up and decay ranges
                (in particular when it exceeds ``total_steps``).
        """
        if step < self.warmup_steps:
            return self.max_val * step / self.warmup_steps
        elif self.warmup_steps <= step <= self.total_steps:
            decay_steps = self.total_steps - self.warmup_steps
            if decay_steps:
                angle = (step - self.warmup_steps) * np.pi / decay_steps
            else:
                # No decay phase (warmup_steps == total_steps): the only step
                # that reaches here is step == total_steps. Treat it as the
                # start of the decay (value == max_val) instead of dividing
                # by zero.
                angle = 0.0
            return self.max_val * (1 + np.cos(angle)) / 2
        else:
            raise ValueError('Step ({}) > total number of steps ({}).'.format(step, self.total_steps))


def get_graph_drop_transform(drop_edge_p, feat_mask_p):
    """Build a composed graph augmentation.

    The pipeline always begins with ``copy.deepcopy`` so the later steps work
    on a copy of the graph. ``DropEdge`` and ``FeatMask`` (applied to the
    ``'feat'`` node feature) are appended only when their probability is
    strictly positive.

    Args:
        drop_edge_p: Probability passed to ``DropEdge``; skipped if not > 0.
        feat_mask_p: Probability passed to ``FeatMask``; skipped if not > 0.

    Returns:
        A ``Compose`` object wrapping the selected steps in order.
    """
    # The graph copy always comes first.
    pipeline = [copy.deepcopy]

    # Optional edge dropping.
    if drop_edge_p > 0.:
        pipeline.append(DropEdge(drop_edge_p))

    # Optional node-feature masking.
    if feat_mask_p > 0.:
        pipeline.append(FeatMask(feat_mask_p, node_feat_names=['feat']))

    return Compose(pipeline)


def get_wiki_cs(transform=RowFeatNormalizer(subtract_min=True)):
    """Load the WikiCS graph and standardize its node features.

    The first graph of ``WikiCSDataset`` is taken and its ``'feat'`` node data
    is shifted and scaled along dim 0 using the mean and the population
    standard deviation (``unbiased=False``).

    Args:
        transform: Transform forwarded to ``WikiCSDataset``.

    Returns:
        A one-element list containing the standardized graph.
    """
    graph = WikiCSDataset(transform=transform)[0]
    features = graph.ndata['feat']
    # NOTE(review): a column with zero std would produce non-finite values
    # here; the data presumably has none -- verify if the transform changes.
    std, mean = torch.std_mean(features, dim=0, unbiased=False)
    graph.ndata['feat'] = (features - mean) / std
    return [graph]


def get_ppi():
    """Load the PPI splits and build a batched train+valid graph list.

    Every graph of the train and valid splits gets a ``'batch'`` node field
    holding its index (as a long tensor) in the concatenated list. The graphs
    are then batched with a shuffled ``GraphDataLoader`` (batch size 22).

    Returns:
        A 4-tuple ``(batches, train_dataset, val_dataset, test_dataset)`` where
        ``batches`` is the list of loader outputs. The train and valid
        datasets are freshly constructed rather than reused -- presumably so
        they are independent of the ``'batch'`` annotation added here; verify
        before merging the loads.
    """
    train_split = PPIDataset(mode='train')
    valid_split = PPIDataset(mode='valid')
    test_split = PPIDataset(mode='test')

    graphs = list(train_split) + list(valid_split)
    for graph_id, graph in enumerate(graphs):
        # Tag each node with the index of the graph it belongs to.
        graph_ids = torch.zeros(graph.number_of_nodes()) + graph_id
        graph.ndata['batch'] = graph_ids
        graph.ndata['batch'] = graph.ndata['batch'].long()

    loader = GraphDataLoader(graphs, batch_size=22, shuffle=True)
    batches = list(loader)

    return batches, PPIDataset(mode='train'), PPIDataset(mode='valid'), test_split


def get_dataset(name, transform=RowFeatNormalizer(subtract_min=True)):
    """Load a dataset by name.

    Args:
        name: One of ``'coauthor_cs'``, ``'coauthor_physics'``,
            ``'amazon_computers'``, ``'amazon_photos'``, ``'wiki_cs'`` or
            ``'ppi'``. Any other value raises ``KeyError``.
        transform: Transform forwarded to the loader; not used for ``'ppi'``.

    Returns:
        A 4-tuple ``(dataset, train_data, val_data, test_data)``. The last
        three entries are ``None`` for every dataset except ``'ppi'``.
    """
    loaders = {
        'coauthor_cs': CoauthorCSDataset,
        'coauthor_physics': CoauthorPhysicsDataset,
        'amazon_computers': AmazonCoBuyComputerDataset,
        'amazon_photos': AmazonCoBuyPhotoDataset,
        'wiki_cs': get_wiki_cs,
        'ppi': get_ppi
    }
    loader = loaders[name]

    # PPI is the only entry that takes no transform and provides its own splits.
    if name == 'ppi':
        dataset, train_data, val_data, test_data = loader()
        return dataset, train_data, val_data, test_data

    return loader(transform=transform), None, None, None