"...python/git@developer.sourcefind.cn:change/sglang.git" did not exist on "f9f0138f80a32ecba8a4da619cb51dce2bb3381c"
Unverified Commit cba5af22 authored by Yu Sun's avatar Yu Sun Committed by GitHub
Browse files

[bugfix] Fix bugs in vgae (#2727)



* [Example]Variational Graph Auto-Encoders

* change dgl dataset to single directional graph

* clean code

* refresh

* fix bug

* fix bug

* fix bug

* add gpu
Co-authored-by: default avatarTianjun Xiao <xiaotj1990327@gmail.com>
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent b2e35e6a
...@@ -3,6 +3,8 @@ import torch ...@@ -3,6 +3,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from train import device
class VGAEModel(nn.Module): class VGAEModel(nn.Module):
def __init__(self, in_dim, hidden1_dim, hidden2_dim): def __init__(self, in_dim, hidden1_dim, hidden2_dim):
...@@ -20,8 +22,8 @@ class VGAEModel(nn.Module): ...@@ -20,8 +22,8 @@ class VGAEModel(nn.Module):
h = self.layers[0](g, features) h = self.layers[0](g, features)
self.mean = self.layers[1](g, h) self.mean = self.layers[1](g, h)
self.log_std = self.layers[2](g, h) self.log_std = self.layers[2](g, h)
gaussian_noise = torch.randn(features.size(0), self.hidden2_dim) gaussian_noise = torch.randn(features.size(0), self.hidden2_dim).to(device)
sampled_z = self.mean + gaussian_noise * torch.exp(self.log_std) sampled_z = self.mean + gaussian_noise * torch.exp(self.log_std).to(device)
return sampled_z return sampled_z
def decoder(self, z): def decoder(self, z):
......
...@@ -88,7 +88,7 @@ def mask_test_edges(adj): ...@@ -88,7 +88,7 @@ def mask_test_edges(adj):
def mask_test_edges_dgl(graph, adj): def mask_test_edges_dgl(graph, adj):
src, dst = graph.edges() src, dst = graph.edges()
edges_all = torch.stack([src, dst], dim=0) edges_all = torch.stack([src, dst], dim=0)
edges_all = edges_all.t().numpy() edges_all = edges_all.t().cpu().numpy()
num_test = int(np.floor(edges_all.shape[0] / 10.)) num_test = int(np.floor(edges_all.shape[0] / 10.))
num_val = int(np.floor(edges_all.shape[0] / 20.)) num_val = int(np.floor(edges_all.shape[0] / 20.))
......
...@@ -23,11 +23,15 @@ parser.add_argument('--hidden1', '-h1', type=int, default=32, help='Number of un ...@@ -23,11 +23,15 @@ parser.add_argument('--hidden1', '-h1', type=int, default=32, help='Number of un
parser.add_argument('--hidden2', '-h2', type=int, default=16, help='Number of units in hidden layer 2.') parser.add_argument('--hidden2', '-h2', type=int, default=16, help='Number of units in hidden layer 2.')
parser.add_argument('--datasrc', '-s', type=str, default='dgl', parser.add_argument('--datasrc', '-s', type=str, default='dgl',
help='Dataset download from dgl Dataset or website.') help='Dataset download from dgl Dataset or website.')
parser.add_argument('--dataset', '-d', type=str, default='pubmed', help='Dataset string.') parser.add_argument('--dataset', '-d', type=str, default='cora', help='Dataset string.')
parser.add_argument('--gpu_id', type=int, default=0, help='GPU id to use.') parser.add_argument('--gpu_id', type=int, default=0, help='GPU id to use.')
args = parser.parse_args() args = parser.parse_args()
# check device
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu")
# device = "cpu"
# roc_means = [] # roc_means = []
# ap_means = [] # ap_means = []
...@@ -35,7 +39,7 @@ def compute_loss_para(adj): ...@@ -35,7 +39,7 @@ def compute_loss_para(adj):
pos_weight = ((adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()) pos_weight = ((adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum())
norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2) norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
weight_mask = adj.view(-1) == 1 weight_mask = adj.view(-1) == 1
weight_tensor = torch.ones(weight_mask.size(0)) weight_tensor = torch.ones(weight_mask.size(0)).to(device)
weight_tensor[weight_mask] = pos_weight weight_tensor[weight_mask] = pos_weight
return weight_tensor, norm return weight_tensor, norm
...@@ -51,6 +55,7 @@ def get_scores(edges_pos, edges_neg, adj_rec): ...@@ -51,6 +55,7 @@ def get_scores(edges_pos, edges_neg, adj_rec):
def sigmoid(x): def sigmoid(x):
return 1 / (1 + np.exp(-x)) return 1 / (1 + np.exp(-x))
adj_rec = adj_rec.cpu()
# Predict on test set of edges # Predict on test set of edges
preds = [] preds = []
for e in edges_pos: for e in edges_pos:
...@@ -80,21 +85,18 @@ def dgl_main(): ...@@ -80,21 +85,18 @@ def dgl_main():
raise NotImplementedError raise NotImplementedError
graph = dataset[0] graph = dataset[0]
# check device
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu")
# Extract node features # Extract node features
feats = graph.ndata.pop('feat').to(device) feats = graph.ndata.pop('feat').to(device)
in_dim = feats.shape[-1] in_dim = feats.shape[-1]
graph = graph.to(device)
# generate input # generate input
adj_orig = graph.adjacency_matrix().to_dense().to(device) adj_orig = graph.adjacency_matrix().to_dense()
# build test set with 10% positive links # build test set with 10% positive links
train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges_dgl(graph, adj_orig) train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges_dgl(graph, adj_orig)
graph = graph.to(device)
# create train graph # create train graph
train_edge_idx = torch.tensor(train_edge_idx).to(device) train_edge_idx = torch.tensor(train_edge_idx).to(device)
train_graph = dgl.edge_subgraph(graph, train_edge_idx, preserve_nodes=True) train_graph = dgl.edge_subgraph(graph, train_edge_idx, preserve_nodes=True)
...@@ -119,7 +121,7 @@ def dgl_main(): ...@@ -119,7 +121,7 @@ def dgl_main():
# Training and validation using a full graph # Training and validation using a full graph
vgae_model.train() vgae_model.train()
logits = vgae_model.forward(train_graph, feats) logits = vgae_model.forward(graph, feats)
# compute loss # compute loss
loss = norm * F.binary_cross_entropy(logits.view(-1), adj.view(-1), weight=weight_tensor) loss = norm * F.binary_cross_entropy(logits.view(-1), adj.view(-1), weight=weight_tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment