# Data augmentation on graphs via edge dropping and feature masking

import numpy as np
4
5
import torch as th

6
7
import dgl

8

9
def aug(graph, x, feat_drop_rate, edge_mask_rate):
    """Build one randomly augmented view of a graph and its node features.

    Edges are randomly removed with ``mask_edge`` and whole feature columns
    are randomly zeroed with ``drop_feature``. The augmented graph keeps the
    original number of nodes and gets a self-loop added on every node.

    Args:
        graph: source DGL graph.
        x: node feature tensor; dimension 1 indexes the features that
            ``drop_feature`` may zero out.
        feat_drop_rate: probability of zeroing each feature column.
        edge_mask_rate: probability of removing each edge.

    Returns:
        A ``(new_graph, new_features)`` tuple.
    """
    num_nodes = graph.num_nodes()

    # Keep the sampling order (edges first, then features) so seeded runs
    # consume random numbers in the same sequence.
    kept_edge_ids = mask_edge(graph, edge_mask_rate)
    masked_x = drop_feature(x, feat_drop_rate)

    src, dst = graph.edges()
    # num_nodes is passed explicitly so nodes left without any edge after
    # masking are still present in the new graph.
    aug_graph = dgl.graph(
        (src[kept_edge_ids], dst[kept_edge_ids]),
        num_nodes=num_nodes,
    )
    return aug_graph.add_self_loop(), masked_x
def drop_feature(x, drop_prob):
    """Return a copy of ``x`` with randomly chosen feature columns zeroed.

    One uniform draw in [0, 1) is made per column (dimension 1 of ``x``).
    Every column whose draw falls below ``drop_prob`` is set to zero for all
    rows. The input tensor itself is not modified.

    Args:
        x: tensor with at least two dimensions; dimension 1 indexes features.
        drop_prob: per-column drop probability.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    num_cols = x.size(1)
    scores = th.empty((num_cols,), dtype=th.float32, device=x.device)
    scores.uniform_(0, 1)
    zero_cols = scores < drop_prob

    out = x.clone()
    out[:, zero_cols] = 0
    return out
def mask_edge(graph, mask_prob):
    """Randomly drop edges and return the IDs of the edges that survive.

    Each edge is kept independently with probability ``1 - mask_prob``.
    Despite the function name, the result is an index tensor, not a boolean
    mask, and is meant for indexing the tensors returned by
    ``graph.edges()``.

    Args:
        graph: DGL graph whose edges are sampled.
        mask_prob: probability, in [0, 1], of dropping each edge. Values
            outside that range make ``torch.bernoulli`` raise.

    Returns:
        1-D int64 tensor of kept edge IDs in ascending order, on the same
        device as ``graph``.
    """
    num_edges = graph.num_edges()
    # Sample directly on the graph's device so the returned indices can index
    # graph.edges() without a host/device copy (building the probabilities
    # via NumPy + th.FloatTensor always produced CPU tensors).
    keep_prob = th.full((num_edges,), 1.0 - mask_prob, device=graph.device)
    keep = th.bernoulli(keep_prob)
    return keep.nonzero().squeeze(1)