# aug.py
# Data augmentation on graphs via edge dropping and feature masking

import torch as th
import numpy as np
import dgl

def aug(graph, x: th.Tensor, feat_drop_rate: float, edge_mask_rate: float):
    """Build one randomly augmented view of a graph and its node features.

    A subset of the graph's edges is kept (chosen by ``mask_edge``), a new
    graph with the same number of nodes is built from those edges, and a
    self-loop is added to it.  Feature columns are zeroed by ``drop_feature``.

    Args:
        graph: Graph to augment; it must provide ``num_nodes()``,
            ``num_edges()`` and ``edges()`` (assumed to be a DGL graph).
        x: Node feature tensor passed to ``drop_feature``.
        feat_drop_rate: Probability passed to ``drop_feature``.
        edge_mask_rate: Probability passed to ``mask_edge``.

    Returns:
        A ``(new_graph, new_features)`` tuple.
    """
    n_node = graph.num_nodes()

    # Draw the edge mask before the feature mask so the sequence of random
    # draws (and hence seeded results) stays the same as before.
    edge_mask = mask_edge(graph, edge_mask_rate)
    feat = drop_feature(x, feat_drop_rate)

    # Fetch both endpoint tensors with a single call, then keep only the
    # edges whose positions were selected by the mask.
    src, dst = graph.edges()[:2]
    nsrc = src[edge_mask]
    ndst = dst[edge_mask]

    # Pass num_nodes explicitly so nodes that lost all their edges are kept.
    ng = dgl.graph((nsrc, ndst), num_nodes=n_node)
    ng = ng.add_self_loop()

    return ng, feat
23

24
25
26
27
def drop_feature(x: th.Tensor, drop_prob: float) -> th.Tensor:
    """Return a copy of ``x`` with randomly selected feature columns zeroed.

    One uniform sample in ``[0, 1)`` is drawn per column (dimension 1) of
    ``x``; every column whose sample is below ``drop_prob`` is set to zero
    across all rows.  The input tensor itself is not modified.

    Args:
        x: Feature tensor with at least two dimensions; columns are indexed
            along dimension 1.
        drop_prob: Per-column probability of being zeroed.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    # One draw per feature column, created on the same device as ``x``.
    drop_mask = th.empty(
        (x.size(1),),
        dtype=th.float32,
        device=x.device,
    ).uniform_(0, 1) < drop_prob

    # Work on a clone so the caller's tensor is left untouched.
    x = x.clone()
    x[:, drop_mask] = 0

    return x
32

33
34
35
36
37
38
39
def mask_edge(graph, mask_prob):
    """Sample the positions of the edges that survive random masking.

    Each of the graph's ``num_edges()`` edges is kept independently with
    probability ``1 - mask_prob``.

    Args:
        graph: Object exposing ``num_edges()``.
        mask_prob: Probability with which each edge is masked out.

    Returns:
        1-D integer tensor holding the positions of the kept edges.
    """
    num_edges = graph.num_edges()

    # Per-edge keep probability as a float32 vector.
    keep_prob = 1 - th.full((num_edges,), mask_prob, dtype=th.float32)

    # A draw of 1 means the edge is kept, 0 means it is masked out.
    draws = th.bernoulli(keep_prob)
    return th.nonzero(draws).squeeze(1)