"docs/source/api/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f673fc25539ea8f38590ea62ee6fbb06478536b1"
Unverified Commit 0d1dcdcd authored by Hengrui Zhang's avatar Hengrui Zhang Committed by GitHub
Browse files

[bugfix] fix a bug found in v0.7 bug bash in the 'grace' model (#3109)



* Update aug.py

* Update aug.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 32d1f3ac
...@@ -5,29 +5,35 @@ import numpy as np ...@@ -5,29 +5,35 @@ import numpy as np
import dgl import dgl
def aug(graph, x, feat_drop_rate, edge_mask_rate): def aug(graph, x, feat_drop_rate, edge_mask_rate):
ng = drop_edge(graph, edge_mask_rate) n_node = graph.num_nodes()
feat = drop_feat(x, feat_drop_rate)
ng = ng.add_self_loop()
return ng, feat edge_mask = mask_edge(graph, edge_mask_rate)
feat = drop_feature(x, feat_drop_rate)
def drop_edge(graph, drop_prob): src = graph.edges()[0]
E = graph.num_edges() dst = graph.edges()[1]
mask_rates = th.FloatTensor(np.ones(E) * drop_prob) nsrc = src[edge_mask]
masks = th.bernoulli(1 - mask_rates) ndst = dst[edge_mask]
edge_idx = masks.nonzero().squeeze(1)
sg = dgl.edge_subgraph(graph, edge_idx, relabel_nodes=False) ng = dgl.graph((nsrc, ndst), num_nodes=n_node)
ng = ng.add_self_loop()
return sg
def drop_feat(x, drop_prob): return ng, feat
D = x.shape[1]
mask_rates = th.FloatTensor(np.ones(D) * drop_prob)
masks = th.bernoulli(1 - mask_rates)
def drop_feature(x, drop_prob):
drop_mask = th.empty((x.size(1),),
dtype=th.float32,
device=x.device).uniform_(0, 1) < drop_prob
x = x.clone() x = x.clone()
x[:, masks] = 0 x[:, drop_mask] = 0
return x
return x def mask_edge(graph, mask_prob):
\ No newline at end of file E = graph.num_edges()
mask_rates = th.FloatTensor(np.ones(E) * mask_prob)
masks = th.bernoulli(1 - mask_rates)
mask_idx = masks.nonzero().squeeze(1)
return mask_idx
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment