"tests/nn/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "b6dc98cfe3365b9fbc7e386c2c246765a5b5710f"
Unverified commit 23d09057 authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4642)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a9f2acf3
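The changes below are purely mechanical: Black reformats the Python sources and isort (the "* sort" step) reorders the imports; no behaviour changes. The tools are normally run from the command line (isort . followed by black .), but the same transformation can be sketched through their Python APIs. The 79-character line length below is an assumption inferred from where the reformatted lines wrap; the repository's actual configuration is not shown in this diff.

# Minimal sketch of the "sort + Black auto fix" steps via the tools' Python APIs.
# Assumption: line_length=79 is inferred from how lines wrap in this diff,
# not read from the repository's real configuration.
import black
import isort

source = (
    "import torch\n"
    "import math\n"
    "x = {  'a':37,'b':42}\n"
)

sorted_source = isort.code(source)  # the "* sort" step
formatted_source = black.format_str(
    sorted_source, mode=black.Mode(line_length=79)  # the Black auto-fix step
)
print(formatted_source)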
import math
import os
import sys
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import random
from torch.nn.parameter import Parameter
from tqdm.auto import tqdm
from utils import *

import dgl
import dgl.function as fn
def get_graph(network_data, vocab):
    """Build graph, treat all nodes as the same type

    Parameters
    ----------
...@@ -57,7 +57,9 @@ class NeighborSampler(object): ...@@ -57,7 +57,9 @@ class NeighborSampler(object):
    def sample(self, pairs):
        heads, tails, types = zip(*pairs)
        seeds, head_invmap = torch.unique(
            torch.LongTensor(heads), return_inverse=True
        )
blocks = [] blocks = []
for fanout in reversed(self.num_fanouts): for fanout in reversed(self.num_fanouts):
sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout) sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
...@@ -90,7 +92,9 @@ class DGLGATNE(nn.Module): ...@@ -90,7 +92,9 @@ class DGLGATNE(nn.Module):
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a
        self.node_embeddings = Parameter(
            torch.FloatTensor(num_nodes, embedding_size)
        )
self.node_type_embeddings = Parameter( self.node_type_embeddings = Parameter(
torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size) torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)
) )
...@@ -100,16 +104,24 @@ class DGLGATNE(nn.Module): ...@@ -100,16 +104,24 @@ class DGLGATNE(nn.Module):
        self.trans_weights_s1 = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
        )
        self.trans_weights_s2 = Parameter(
            torch.FloatTensor(edge_type_count, dim_a, 1)
        )
self.reset_parameters() self.reset_parameters()
    def reset_parameters(self):
        self.node_embeddings.data.uniform_(-1.0, 1.0)
        self.node_type_embeddings.data.uniform_(-1.0, 1.0)
        self.trans_weights.data.normal_(
            std=1.0 / math.sqrt(self.embedding_size)
        )
        self.trans_weights_s1.data.normal_(
            std=1.0 / math.sqrt(self.embedding_size)
        )
        self.trans_weights_s2.data.normal_(
            std=1.0 / math.sqrt(self.embedding_size)
        )
# embs: [batch_size, embedding_size] # embs: [batch_size, embedding_size]
def forward(self, block): def forward(self, block):
...@@ -122,10 +134,16 @@ class DGLGATNE(nn.Module): ...@@ -122,10 +134,16 @@ class DGLGATNE(nn.Module):
        with block.local_scope():
            for i in range(self.edge_type_count):
                edge_type = self.edge_types[i]
                block.srcdata[edge_type] = self.node_type_embeddings[
                    input_nodes, i
                ]
                block.dstdata[edge_type] = self.node_type_embeddings[
                    output_nodes, i
                ]
                block.update_all(
                    fn.copy_u(edge_type, "m"),
                    fn.sum("m", edge_type),
                    etype=edge_type,
                )
                node_type_embed.append(block.dstdata[edge_type])
...@@ -152,7 +170,9 @@ class DGLGATNE(nn.Module): ...@@ -152,7 +170,9 @@ class DGLGATNE(nn.Module):
            attention = (
                F.softmax(
                    torch.matmul(
                        torch.tanh(
                            torch.matmul(tmp_node_type_embed, trans_w_s1)
                        ),
                        trans_w_s2,
                    )
                    .squeeze(2)
...@@ -173,7 +193,9 @@ class DGLGATNE(nn.Module): ...@@ -173,7 +193,9 @@ class DGLGATNE(nn.Module):
) )
        last_node_embed = F.normalize(node_embed, dim=2)
        return (
            last_node_embed  # [batch_size, edge_type_count, embedding_size]
        )
class NSLoss(nn.Module): class NSLoss(nn.Module):
...@@ -187,7 +209,8 @@ class NSLoss(nn.Module): ...@@ -187,7 +209,8 @@ class NSLoss(nn.Module):
        self.sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1))
                    / math.log(num_nodes + 1)
                    for k in range(num_nodes)
                ]
            ),
...@@ -257,14 +280,20 @@ def train_model(network_data): ...@@ -257,14 +280,20 @@ def train_model(network_data):
pin_memory=True, pin_memory=True,
) )
    model = DGLGATNE(
        num_nodes,
        embedding_size,
        embedding_u_size,
        edge_types,
        edge_type_count,
        dim_a,
    )
nsloss = NSLoss(num_nodes, num_sampled, embedding_size) nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
model.to(device) model.to(device)
nsloss.to(device) nsloss.to(device)
    optimizer = torch.optim.Adam(
        [{"params": model.parameters()}, {"params": nsloss.parameters()}],
        lr=1e-3,
    )
best_score = 0 best_score = 0
...@@ -286,7 +315,10 @@ def train_model(network_data): ...@@ -286,7 +315,10 @@ def train_model(network_data):
block_types = block_types.to(device) block_types = block_types.to(device)
embs = model(block[0].to(device))[head_invmap] embs = model(block[0].to(device))[head_invmap]
            embs = embs.gather(
                1,
                block_types.view(-1, 1, 1).expand(
                    embs.shape[0], 1, embs.shape[2]
                ),
            )[:, 0]
loss = nsloss( loss = nsloss(
block[0].dstdata[dgl.NID][head_invmap].to(device), block[0].dstdata[dgl.NID][head_invmap].to(device),
...@@ -307,7 +339,9 @@ def train_model(network_data): ...@@ -307,7 +339,9 @@ def train_model(network_data):
model.eval() model.eval()
# {'1': {}, '2': {}} # {'1': {}, '2': {}}
        final_model = dict(
            zip(edge_types, [dict() for _ in range(edge_type_count)])
        )
for i in range(num_nodes): for i in range(num_nodes):
train_inputs = ( train_inputs = (
torch.tensor([i for _ in range(edge_type_count)]) torch.tensor([i for _ in range(edge_type_count)])
...@@ -315,7 +349,9 @@ def train_model(network_data): ...@@ -315,7 +349,9 @@ def train_model(network_data):
.to(device) .to(device)
) # [i, i] ) # [i, i]
            train_types = (
                torch.tensor(list(range(edge_type_count)))
                .unsqueeze(1)
                .to(device)
            )  # [0, 1]
pairs = torch.cat( pairs = torch.cat(
(train_inputs, train_inputs, train_types), dim=1 (train_inputs, train_inputs, train_types), dim=1
...@@ -343,7 +379,9 @@ def train_model(network_data): ...@@ -343,7 +379,9 @@ def train_model(network_data):
valid_aucs, valid_f1s, valid_prs = [], [], [] valid_aucs, valid_f1s, valid_prs = [], [], []
test_aucs, test_f1s, test_prs = [], [], [] test_aucs, test_f1s, test_prs = [], [], []
for i in range(edge_type_count): for i in range(edge_type_count):
            if args.eval_type == "all" or edge_types[i] in args.eval_type.split(
                ","
            ):
tmp_auc, tmp_f1, tmp_pr = evaluate( tmp_auc, tmp_f1, tmp_pr = evaluate(
final_model[edge_types[i]], final_model[edge_types[i]],
valid_true_data_by_edge[edge_types[i]], valid_true_data_by_edge[edge_types[i]],
......
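A note on the NSLoss block in the file above: the per-rank weight (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1) is the log-uniform (Zipf-like) candidate-sampling distribution, which assumes node ids are ordered by decreasing frequency so that frequent nodes are drawn as negatives more often. The raw weights already telescope to a sum of 1; the surrounding F.normalize call then applies an additional (by default L2) rescaling. A small standalone check, not part of the commit:

# Standalone check of the log-uniform negative-sampling weights used by NSLoss.
import math

num_nodes = 10_000
weights = [
    (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
    for k in range(num_nodes)
]
# The sum telescopes to log(num_nodes + 1) / log(num_nodes + 1) == 1.
print(sum(weights))  # ~1.0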
import math
import os
import sys
import time
from collections import defaultdict

import numpy as np
import torch
...@@ -11,14 +11,14 @@ import torch.nn.functional as F
import tqdm
from numpy import random
from torch.nn.parameter import Parameter
from utils import *

import dgl
import dgl.function as fn
def get_graph(network_data, vocab):
    """Build graph, treat all nodes as the same type

    Parameters
    ----------
...@@ -58,7 +58,9 @@ class NeighborSampler(object): ...@@ -58,7 +58,9 @@ class NeighborSampler(object):
def sample(self, pairs): def sample(self, pairs):
pairs = np.stack(pairs) pairs = np.stack(pairs)
heads, tails, types = pairs[:, 0], pairs[:, 1], pairs[:, 2] heads, tails, types = pairs[:, 0], pairs[:, 1], pairs[:, 2]
seeds, head_invmap = torch.unique(torch.LongTensor(heads), return_inverse=True) seeds, head_invmap = torch.unique(
torch.LongTensor(heads), return_inverse=True
)
blocks = [] blocks = []
for fanout in reversed(self.num_fanouts): for fanout in reversed(self.num_fanouts):
sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout) sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
...@@ -91,7 +93,9 @@ class DGLGATNE(nn.Module): ...@@ -91,7 +93,9 @@ class DGLGATNE(nn.Module):
self.edge_type_count = edge_type_count self.edge_type_count = edge_type_count
self.dim_a = dim_a self.dim_a = dim_a
self.node_embeddings = nn.Embedding(num_nodes, embedding_size, sparse=True) self.node_embeddings = nn.Embedding(
num_nodes, embedding_size, sparse=True
)
self.node_type_embeddings = nn.Embedding( self.node_type_embeddings = nn.Embedding(
num_nodes * edge_type_count, embedding_u_size, sparse=True num_nodes * edge_type_count, embedding_u_size, sparse=True
) )
...@@ -101,16 +105,24 @@ class DGLGATNE(nn.Module): ...@@ -101,16 +105,24 @@ class DGLGATNE(nn.Module):
self.trans_weights_s1 = Parameter( self.trans_weights_s1 = Parameter(
torch.FloatTensor(edge_type_count, embedding_u_size, dim_a) torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
) )
self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1)) self.trans_weights_s2 = Parameter(
torch.FloatTensor(edge_type_count, dim_a, 1)
)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
self.node_embeddings.weight.data.uniform_(-1.0, 1.0) self.node_embeddings.weight.data.uniform_(-1.0, 1.0)
self.node_type_embeddings.weight.data.uniform_(-1.0, 1.0) self.node_type_embeddings.weight.data.uniform_(-1.0, 1.0)
self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size)) self.trans_weights.data.normal_(
self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size)) std=1.0 / math.sqrt(self.embedding_size)
self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size)) )
self.trans_weights_s1.data.normal_(
std=1.0 / math.sqrt(self.embedding_size)
)
self.trans_weights_s2.data.normal_(
std=1.0 / math.sqrt(self.embedding_size)
)
# embs: [batch_size, embedding_size] # embs: [batch_size, embedding_size]
def forward(self, block): def forward(self, block):
...@@ -129,7 +141,9 @@ class DGLGATNE(nn.Module): ...@@ -129,7 +141,9 @@ class DGLGATNE(nn.Module):
output_nodes * self.edge_type_count + i output_nodes * self.edge_type_count + i
) )
block.update_all( block.update_all(
fn.copy_u(edge_type, "m"), fn.sum("m", edge_type), etype=edge_type fn.copy_u(edge_type, "m"),
fn.sum("m", edge_type),
etype=edge_type,
) )
node_type_embed.append(block.dstdata[edge_type]) node_type_embed.append(block.dstdata[edge_type])
...@@ -156,7 +170,9 @@ class DGLGATNE(nn.Module): ...@@ -156,7 +170,9 @@ class DGLGATNE(nn.Module):
attention = ( attention = (
F.softmax( F.softmax(
torch.matmul( torch.matmul(
torch.tanh(torch.matmul(tmp_node_type_embed, trans_w_s1)), torch.tanh(
torch.matmul(tmp_node_type_embed, trans_w_s1)
),
trans_w_s2, trans_w_s2,
) )
.squeeze(2) .squeeze(2)
...@@ -177,7 +193,9 @@ class DGLGATNE(nn.Module): ...@@ -177,7 +193,9 @@ class DGLGATNE(nn.Module):
) )
last_node_embed = F.normalize(node_embed, dim=2) last_node_embed = F.normalize(node_embed, dim=2)
return last_node_embed # [batch_size, edge_type_count, embedding_size] return (
last_node_embed # [batch_size, edge_type_count, embedding_size]
)
class NSLoss(nn.Module): class NSLoss(nn.Module):
...@@ -191,7 +209,8 @@ class NSLoss(nn.Module): ...@@ -191,7 +209,8 @@ class NSLoss(nn.Module):
self.sample_weights = F.normalize( self.sample_weights = F.normalize(
torch.Tensor( torch.Tensor(
[ [
(math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1) (math.log(k + 2) - math.log(k + 1))
/ math.log(num_nodes + 1)
for k in range(num_nodes) for k in range(num_nodes)
] ]
), ),
...@@ -201,7 +220,9 @@ class NSLoss(nn.Module): ...@@ -201,7 +220,9 @@ class NSLoss(nn.Module):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
self.weights.weight.data.normal_(std=1.0 / math.sqrt(self.embedding_size)) self.weights.weight.data.normal_(
std=1.0 / math.sqrt(self.embedding_size)
)
def forward(self, input, embs, label): def forward(self, input, embs, label):
n = input.shape[0] n = input.shape[0]
...@@ -266,7 +287,12 @@ def train_model(network_data): ...@@ -266,7 +287,12 @@ def train_model(network_data):
) )
model = DGLGATNE( model = DGLGATNE(
num_nodes, embedding_size, embedding_u_size, edge_types, edge_type_count, dim_a, num_nodes,
embedding_size,
embedding_u_size,
edge_types,
edge_type_count,
dim_a,
) )
nsloss = NSLoss(num_nodes, num_sampled, embedding_size) nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
...@@ -274,21 +300,23 @@ def train_model(network_data): ...@@ -274,21 +300,23 @@ def train_model(network_data):
model.to(device) model.to(device)
nsloss.to(device) nsloss.to(device)
    embeddings_params = list(
        map(id, model.node_embeddings.parameters())
    ) + list(map(id, model.node_type_embeddings.parameters()))
    weights_params = list(map(id, nsloss.weights.parameters()))
    optimizer = torch.optim.Adam(
        [
            {
                "params": filter(
                    lambda p: id(p) not in embeddings_params,
                    model.parameters(),
                )
            },
            {
                "params": filter(
                    lambda p: id(p) not in weights_params,
                    nsloss.parameters(),
                )
            },
        ],
...@@ -325,7 +353,10 @@ def train_model(network_data): ...@@ -325,7 +353,10 @@ def train_model(network_data):
block_types = block_types.to(device) block_types = block_types.to(device)
embs = model(block[0].to(device))[head_invmap] embs = model(block[0].to(device))[head_invmap]
embs = embs.gather( embs = embs.gather(
1, block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2]), 1,
block_types.view(-1, 1, 1).expand(
embs.shape[0], 1, embs.shape[2]
),
)[:, 0] )[:, 0]
loss = nsloss( loss = nsloss(
block[0].dstdata[dgl.NID][head_invmap].to(device), block[0].dstdata[dgl.NID][head_invmap].to(device),
...@@ -347,7 +378,9 @@ def train_model(network_data): ...@@ -347,7 +378,9 @@ def train_model(network_data):
model.eval() model.eval()
# {'1': {}, '2': {}} # {'1': {}, '2': {}}
final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)])) final_model = dict(
zip(edge_types, [dict() for _ in range(edge_type_count)])
)
for i in range(num_nodes): for i in range(num_nodes):
train_inputs = ( train_inputs = (
torch.tensor([i for _ in range(edge_type_count)]) torch.tensor([i for _ in range(edge_type_count)])
...@@ -355,7 +388,9 @@ def train_model(network_data): ...@@ -355,7 +388,9 @@ def train_model(network_data):
.to(device) .to(device)
) # [i, i] ) # [i, i]
train_types = ( train_types = (
torch.tensor(list(range(edge_type_count))).unsqueeze(1).to(device) torch.tensor(list(range(edge_type_count)))
.unsqueeze(1)
.to(device)
) # [0, 1] ) # [0, 1]
pairs = torch.cat( pairs = torch.cat(
(train_inputs, train_inputs, train_types), dim=1 (train_inputs, train_inputs, train_types), dim=1
...@@ -383,7 +418,9 @@ def train_model(network_data): ...@@ -383,7 +418,9 @@ def train_model(network_data):
valid_aucs, valid_f1s, valid_prs = [], [], [] valid_aucs, valid_f1s, valid_prs = [], [], []
test_aucs, test_f1s, test_prs = [], [], [] test_aucs, test_f1s, test_prs = [], [], []
for i in range(edge_type_count): for i in range(edge_type_count):
if args.eval_type == "all" or edge_types[i] in args.eval_type.split(","): if args.eval_type == "all" or edge_types[i] in args.eval_type.split(
","
):
tmp_auc, tmp_f1, tmp_pr = evaluate( tmp_auc, tmp_f1, tmp_pr = evaluate(
final_model[edge_types[i]], final_model[edge_types[i]],
valid_true_data_by_edge[edge_types[i]], valid_true_data_by_edge[edge_types[i]],
......
import datetime
import math
import os
import sys
import time
from collections import defaultdict

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from numpy import random
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parameter import Parameter
from tqdm.auto import tqdm
from utils import *

import dgl
import dgl.function as fn
def setup_seed(seed):
...@@ -29,7 +29,7 @@ def setup_seed(seed):
def get_graph(network_data, vocab):
    """Build graph, treat all nodes as the same type

    Parameters
    ----------
...@@ -69,7 +69,9 @@ class NeighborSampler(object): ...@@ -69,7 +69,9 @@ class NeighborSampler(object):
def sample(self, pairs): def sample(self, pairs):
pairs = np.stack(pairs) pairs = np.stack(pairs)
heads, tails, types = pairs[:, 0], pairs[:, 1], pairs[:, 2] heads, tails, types = pairs[:, 0], pairs[:, 1], pairs[:, 2]
seeds, head_invmap = torch.unique(torch.LongTensor(heads), return_inverse=True) seeds, head_invmap = torch.unique(
torch.LongTensor(heads), return_inverse=True
)
blocks = [] blocks = []
for fanout in reversed(self.num_fanouts): for fanout in reversed(self.num_fanouts):
sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout) sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
...@@ -102,7 +104,9 @@ class DGLGATNE(nn.Module): ...@@ -102,7 +104,9 @@ class DGLGATNE(nn.Module):
self.edge_type_count = edge_type_count self.edge_type_count = edge_type_count
self.dim_a = dim_a self.dim_a = dim_a
self.node_embeddings = nn.Embedding(num_nodes, embedding_size, sparse=True) self.node_embeddings = nn.Embedding(
num_nodes, embedding_size, sparse=True
)
self.node_type_embeddings = nn.Embedding( self.node_type_embeddings = nn.Embedding(
num_nodes * edge_type_count, embedding_u_size, sparse=True num_nodes * edge_type_count, embedding_u_size, sparse=True
) )
...@@ -112,16 +116,24 @@ class DGLGATNE(nn.Module): ...@@ -112,16 +116,24 @@ class DGLGATNE(nn.Module):
self.trans_weights_s1 = Parameter( self.trans_weights_s1 = Parameter(
torch.FloatTensor(edge_type_count, embedding_u_size, dim_a) torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
) )
self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1)) self.trans_weights_s2 = Parameter(
torch.FloatTensor(edge_type_count, dim_a, 1)
)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
self.node_embeddings.weight.data.uniform_(-1.0, 1.0) self.node_embeddings.weight.data.uniform_(-1.0, 1.0)
self.node_type_embeddings.weight.data.uniform_(-1.0, 1.0) self.node_type_embeddings.weight.data.uniform_(-1.0, 1.0)
self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size)) self.trans_weights.data.normal_(
self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size)) std=1.0 / math.sqrt(self.embedding_size)
self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size)) )
self.trans_weights_s1.data.normal_(
std=1.0 / math.sqrt(self.embedding_size)
)
self.trans_weights_s2.data.normal_(
std=1.0 / math.sqrt(self.embedding_size)
)
# embs: [batch_size, embedding_size] # embs: [batch_size, embedding_size]
def forward(self, block): def forward(self, block):
...@@ -140,7 +152,9 @@ class DGLGATNE(nn.Module): ...@@ -140,7 +152,9 @@ class DGLGATNE(nn.Module):
output_nodes * self.edge_type_count + i output_nodes * self.edge_type_count + i
) )
block.update_all( block.update_all(
fn.copy_u(edge_type, "m"), fn.sum("m", edge_type), etype=edge_type fn.copy_u(edge_type, "m"),
fn.sum("m", edge_type),
etype=edge_type,
) )
node_type_embed.append(block.dstdata[edge_type]) node_type_embed.append(block.dstdata[edge_type])
...@@ -167,7 +181,9 @@ class DGLGATNE(nn.Module): ...@@ -167,7 +181,9 @@ class DGLGATNE(nn.Module):
attention = ( attention = (
F.softmax( F.softmax(
torch.matmul( torch.matmul(
torch.tanh(torch.matmul(tmp_node_type_embed, trans_w_s1)), torch.tanh(
torch.matmul(tmp_node_type_embed, trans_w_s1)
),
trans_w_s2, trans_w_s2,
) )
.squeeze(2) .squeeze(2)
...@@ -188,7 +204,9 @@ class DGLGATNE(nn.Module): ...@@ -188,7 +204,9 @@ class DGLGATNE(nn.Module):
) )
last_node_embed = F.normalize(node_embed, dim=2) last_node_embed = F.normalize(node_embed, dim=2)
return last_node_embed # [batch_size, edge_type_count, embedding_size] return (
last_node_embed # [batch_size, edge_type_count, embedding_size]
)
class NSLoss(nn.Module): class NSLoss(nn.Module):
...@@ -202,7 +220,8 @@ class NSLoss(nn.Module): ...@@ -202,7 +220,8 @@ class NSLoss(nn.Module):
self.sample_weights = F.normalize( self.sample_weights = F.normalize(
torch.Tensor( torch.Tensor(
[ [
(math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1) (math.log(k + 2) - math.log(k + 1))
/ math.log(num_nodes + 1)
for k in range(num_nodes) for k in range(num_nodes)
] ]
), ),
...@@ -212,7 +231,9 @@ class NSLoss(nn.Module): ...@@ -212,7 +231,9 @@ class NSLoss(nn.Module):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
self.weights.weight.data.normal_(std=1.0 / math.sqrt(self.embedding_size)) self.weights.weight.data.normal_(
std=1.0 / math.sqrt(self.embedding_size)
)
def forward(self, input, embs, label): def forward(self, input, embs, label):
n = input.shape[0] n = input.shape[0]
...@@ -267,7 +288,12 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -267,7 +288,12 @@ def run(proc_id, n_gpus, args, devices, data):
neighbor_sampler = NeighborSampler(g, [neighbor_samples]) neighbor_sampler = NeighborSampler(g, [neighbor_samples])
if n_gpus > 1: if n_gpus > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_pairs,
            num_replicas=world_size,
            rank=proc_id,
            shuffle=True,
            drop_last=False,
        )
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_pairs, train_pairs,
batch_size=batch_size, batch_size=batch_size,
...@@ -288,7 +314,12 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -288,7 +314,12 @@ def run(proc_id, n_gpus, args, devices, data):
) )
model = DGLGATNE( model = DGLGATNE(
num_nodes, embedding_size, embedding_u_size, edge_types, edge_type_count, dim_a, num_nodes,
embedding_size,
embedding_u_size,
edge_types,
edge_type_count,
dim_a,
) )
nsloss = NSLoss(num_nodes, num_sampled, embedding_size) nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
...@@ -306,21 +337,23 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -306,21 +337,23 @@ def run(proc_id, n_gpus, args, devices, data):
else: else:
mmodel = model mmodel = model
embeddings_params = list(map(id, mmodel.node_embeddings.parameters())) + list( embeddings_params = list(
map(id, mmodel.node_type_embeddings.parameters()) map(id, mmodel.node_embeddings.parameters())
) ) + list(map(id, mmodel.node_type_embeddings.parameters()))
weights_params = list(map(id, nsloss.weights.parameters())) weights_params = list(map(id, nsloss.weights.parameters()))
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
[ [
{ {
"params": filter( "params": filter(
lambda p: id(p) not in embeddings_params, model.parameters(), lambda p: id(p) not in embeddings_params,
model.parameters(),
) )
}, },
{ {
"params": filter( "params": filter(
lambda p: id(p) not in weights_params, nsloss.parameters(), lambda p: id(p) not in weights_params,
nsloss.parameters(),
) )
}, },
], ],
...@@ -363,7 +396,10 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -363,7 +396,10 @@ def run(proc_id, n_gpus, args, devices, data):
block_types = block_types.to(dev_id) block_types = block_types.to(dev_id)
embs = model(block[0].to(dev_id))[head_invmap] embs = model(block[0].to(dev_id))[head_invmap]
embs = embs.gather( embs = embs.gather(
1, block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2]), 1,
block_types.view(-1, 1, 1).expand(
embs.shape[0], 1, embs.shape[2]
),
)[:, 0] )[:, 0]
loss = nsloss( loss = nsloss(
block[0].dstdata[dgl.NID][head_invmap].to(dev_id), block[0].dstdata[dgl.NID][head_invmap].to(dev_id),
...@@ -399,7 +435,9 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -399,7 +435,9 @@ def run(proc_id, n_gpus, args, devices, data):
.to(dev_id) .to(dev_id)
) # [i, i] ) # [i, i]
train_types = ( train_types = (
torch.tensor(list(range(edge_type_count))).unsqueeze(1).to(dev_id) torch.tensor(list(range(edge_type_count)))
.unsqueeze(1)
.to(dev_id)
) # [0, 1] ) # [0, 1]
pairs = torch.cat( pairs = torch.cat(
(train_inputs, train_inputs, train_types), dim=1 (train_inputs, train_inputs, train_types), dim=1
...@@ -427,9 +465,9 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -427,9 +465,9 @@ def run(proc_id, n_gpus, args, devices, data):
valid_aucs, valid_f1s, valid_prs = [], [], [] valid_aucs, valid_f1s, valid_prs = [], [], []
test_aucs, test_f1s, test_prs = [], [], [] test_aucs, test_f1s, test_prs = [], [], []
        for i in range(edge_type_count):
            if args.eval_type == "all" or edge_types[
                i
            ] in args.eval_type.split(","):
tmp_auc, tmp_f1, tmp_pr = evaluate( tmp_auc, tmp_f1, tmp_pr = evaluate(
final_model[edge_types[i]], final_model[edge_types[i]],
valid_true_data_by_edge[edge_types[i]], valid_true_data_by_edge[edge_types[i]],
......
import argparse
import multiprocessing
import time
from collections import defaultdict
from functools import partial, reduce, wraps

import networkx as nx
import numpy as np
import torch
from gensim.models.keyedvectors import Vocab
from six import iteritems
from sklearn.metrics import (auc, f1_score, precision_recall_curve,
                             roc_auc_score)
def parse_args(): def parse_args():
...@@ -25,7 +25,10 @@ def parse_args(): ...@@ -25,7 +25,10 @@ def parse_args():
) )
    parser.add_argument(
        "--epoch",
        type=int,
        default=100,
        help="Number of epoch. Default is 100.",
    )
parser.add_argument( parser.add_argument(
...@@ -36,7 +39,10 @@ def parse_args(): ...@@ -36,7 +39,10 @@ def parse_args():
) )
    parser.add_argument(
        "--eval-type",
        type=str,
        default="all",
        help="The edge type(s) for evaluation.",
    )
parser.add_argument( parser.add_argument(
...@@ -103,15 +109,24 @@ def parse_args(): ...@@ -103,15 +109,24 @@ def parse_args():
) )
    parser.add_argument(
        "--patience",
        type=int,
        default=5,
        help="Early stopping patience. Default is 5.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default=None,
        help="Comma separated list of GPU device IDs.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of workers.",
    )
return parser.parse_args() return parser.parse_args()
...@@ -205,7 +220,9 @@ def generate_pairs(all_walks, window_size, num_workers): ...@@ -205,7 +220,9 @@ def generate_pairs(all_walks, window_size, num_workers):
walks_list = [walks] walks_list = [walks]
tmp_result = pool.map( tmp_result = pool.map(
partial( partial(
generate_pairs_parallel, skip_window=skip_window, layer_id=layer_id generate_pairs_parallel,
skip_window=skip_window,
layer_id=layer_id,
), ),
walks_list, walks_list,
) )
...@@ -285,4 +302,8 @@ def evaluate(model, true_edges, false_edges, num_workers): ...@@ -285,4 +302,8 @@ def evaluate(model, true_edges, false_edges, num_workers):
    y_true = np.array(true_list)
    y_scores = np.array(prediction_list)
    ps, rs, _ = precision_recall_curve(y_true, y_scores)
    return (
        roc_auc_score(y_true, y_scores),
        f1_score(y_true, y_pred),
        auc(rs, ps),
    )
import itertools
import time
from collections import defaultdict as ddict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from catboost import CatBoostClassifier, CatBoostRegressor, Pool, sum_models
from sklearn import preprocessing
from sklearn.metrics import r2_score
from tqdm import tqdm
class BGNNPredictor:
    """
    Description
    -----------
    Boost GNN predictor for semi-supervised node classification or regression problems.
...@@ -50,22 +51,26 @@ class BGNNPredictor:
    gnn_model = GAT(10, 20, num_heads=5),
    bgnn = BGNNPredictor(gnn_model)
    metrics = bgnn.fit(graph, X, y, train_mask, val_mask, test_mask, cat_features)
    """

    def __init__(
        self,
        gnn_model,
        task="regression",
        loss_fn=None,
        trees_per_epoch=10,
        backprop_per_epoch=10,
        lr=0.01,
        append_gbdt_pred=True,
        train_input_features=False,
        gbdt_depth=6,
        gbdt_lr=0.1,
        gbdt_alpha=1,
        random_seed=0,
    ):
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )
self.model = gnn_model.to(self.device) self.model = gnn_model.to(self.device)
self.task = task self.task = task
...@@ -83,23 +88,25 @@ class BGNNPredictor: ...@@ -83,23 +88,25 @@ class BGNNPredictor:
np.random.seed(random_seed) np.random.seed(random_seed)
    def init_gbdt_model(self, num_epochs, epoch):
        if self.task == "regression":
            catboost_model_obj = CatBoostRegressor
            catboost_loss_fn = "RMSE"
        else:
            if epoch == 0:  # we predict multiclass probs at first epoch
                catboost_model_obj = CatBoostClassifier
                catboost_loss_fn = "MultiClass"
            else:  # we predict the gradients for each class at epochs > 0
                catboost_model_obj = CatBoostRegressor
                catboost_loss_fn = "MultiRMSE"
        return catboost_model_obj(
            iterations=num_epochs,
            depth=self.gbdt_depth,
            learning_rate=self.gbdt_lr,
            loss_function=catboost_loss_fn,
            random_seed=self.random_seed,
            nan_mode="Min",
        )
def fit_gbdt(self, pool, trees_per_epoch, epoch): def fit_gbdt(self, pool, trees_per_epoch, epoch):
gbdt_model = self.init_gbdt_model(trees_per_epoch, epoch) gbdt_model = self.init_gbdt_model(trees_per_epoch, epoch)
...@@ -111,19 +118,30 @@ class BGNNPredictor: ...@@ -111,19 +118,30 @@ class BGNNPredictor:
return new_gbdt_model return new_gbdt_model
return sum_models([self.gbdt_model, new_gbdt_model], weights=weights) return sum_models([self.gbdt_model, new_gbdt_model], weights=weights)
def train_gbdt(self, gbdt_X_train, gbdt_y_train, cat_features, epoch, def train_gbdt(
gbdt_trees_per_epoch, gbdt_alpha): self,
gbdt_X_train,
gbdt_y_train,
cat_features,
epoch,
gbdt_trees_per_epoch,
gbdt_alpha,
):
pool = Pool(gbdt_X_train, gbdt_y_train, cat_features=cat_features) pool = Pool(gbdt_X_train, gbdt_y_train, cat_features=cat_features)
epoch_gbdt_model = self.fit_gbdt(pool, gbdt_trees_per_epoch, epoch) epoch_gbdt_model = self.fit_gbdt(pool, gbdt_trees_per_epoch, epoch)
if epoch == 0 and self.task=='classification': if epoch == 0 and self.task == "classification":
self.base_gbdt = epoch_gbdt_model self.base_gbdt = epoch_gbdt_model
else: else:
self.gbdt_model = self.append_gbdt_model(epoch_gbdt_model, weights=[1, gbdt_alpha]) self.gbdt_model = self.append_gbdt_model(
epoch_gbdt_model, weights=[1, gbdt_alpha]
)
def update_node_features(self, node_features, X, original_X): def update_node_features(self, node_features, X, original_X):
# get predictions from gbdt model # get predictions from gbdt model
if self.task == 'regression': if self.task == "regression":
predictions = np.expand_dims(self.gbdt_model.predict(original_X), axis=1) predictions = np.expand_dims(
self.gbdt_model.predict(original_X), axis=1
)
else: else:
predictions = self.base_gbdt.predict_proba(original_X) predictions = self.base_gbdt.predict_proba(original_X)
if self.gbdt_model is not None: if self.gbdt_model is not None:
...@@ -133,26 +151,43 @@ class BGNNPredictor: ...@@ -133,26 +151,43 @@ class BGNNPredictor:
# update node features with predictions # update node features with predictions
if self.append_gbdt_pred: if self.append_gbdt_pred:
if self.train_input_features: if self.train_input_features:
predictions = np.append(node_features.detach().cpu().data[:, :-self.out_dim], predictions = np.append(
predictions, node_features.detach().cpu().data[:, : -self.out_dim],
axis=1) # replace old predictions with new predictions predictions,
axis=1,
) # replace old predictions with new predictions
else: else:
predictions = np.append(X, predictions, axis=1) # append original features with new predictions predictions = np.append(
X, predictions, axis=1
) # append original features with new predictions
predictions = torch.from_numpy(predictions).to(self.device) predictions = torch.from_numpy(predictions).to(self.device)
node_features.data = predictions.float().data node_features.data = predictions.float().data
def update_gbdt_targets(self, node_features, node_features_before, train_mask): def update_gbdt_targets(
return (node_features - node_features_before).detach().cpu().numpy()[train_mask, -self.out_dim:] self, node_features, node_features_before, train_mask
):
return (
(node_features - node_features_before)
.detach()
.cpu()
.numpy()[train_mask, -self.out_dim :]
)
def init_node_features(self, X): def init_node_features(self, X):
node_features = torch.empty(X.shape[0], self.in_dim, requires_grad=True, device=self.device) node_features = torch.empty(
X.shape[0], self.in_dim, requires_grad=True, device=self.device
)
if self.append_gbdt_pred: if self.append_gbdt_pred:
node_features.data[:, :-self.out_dim] = torch.from_numpy(X.to_numpy(copy=True)) node_features.data[:, : -self.out_dim] = torch.from_numpy(
X.to_numpy(copy=True)
)
return node_features return node_features
def init_optimizer(self, node_features, optimize_node_features, learning_rate): def init_optimizer(
self, node_features, optimize_node_features, learning_rate
):
params = [self.model.parameters()] params = [self.model.parameters()]
if optimize_node_features: if optimize_node_features:
...@@ -170,12 +205,14 @@ class BGNNPredictor: ...@@ -170,12 +205,14 @@ class BGNNPredictor:
if self.loss_fn is not None: if self.loss_fn is not None:
loss = self.loss_fn(pred, y) loss = self.loss_fn(pred, y)
else: else:
if self.task == 'regression': if self.task == "regression":
loss = torch.sqrt(F.mse_loss(pred, y)) loss = torch.sqrt(F.mse_loss(pred, y))
elif self.task == 'classification': elif self.task == "classification":
loss = F.cross_entropy(pred, y.long()) loss = F.cross_entropy(pred, y.long())
else: else:
raise NotImplemented("Unknown task. Supported tasks: classification, regression.") raise NotImplemented(
"Unknown task. Supported tasks: classification, regression."
)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
...@@ -187,24 +224,43 @@ class BGNNPredictor: ...@@ -187,24 +224,43 @@ class BGNNPredictor:
        y = target_labels[mask]
        with torch.no_grad():
            pred = logits[mask]
            if self.task == "regression":
                metrics["loss"] = torch.sqrt(
                    F.mse_loss(pred, y).squeeze() + 1e-8
                )
                metrics["rmsle"] = torch.sqrt(
                    F.mse_loss(torch.log(pred + 1), torch.log(y + 1)).squeeze()
                    + 1e-8
                )
                metrics["mae"] = F.l1_loss(pred, y)
                metrics["r2"] = torch.Tensor(
                    [r2_score(y.cpu().numpy(), pred.cpu().numpy())]
                )
            elif self.task == "classification":
                metrics["loss"] = F.cross_entropy(pred, y.long())
                metrics["accuracy"] = torch.Tensor(
                    [(y == pred.max(1)[1]).sum().item() / y.shape[0]]
                )
            return metrics
    def train_and_evaluate(
        self,
        model_in,
        target_labels,
        train_mask,
        val_mask,
        test_mask,
        optimizer,
        metrics,
        gnn_passes_per_epoch,
    ):
        loss = None
        for _ in range(gnn_passes_per_epoch):
            loss = self.train_model(
                model_in, target_labels, train_mask, optimizer
            )
self.model.eval() self.model.eval()
logits = self.model(*model_in).squeeze() logits = self.model(*model_in).squeeze()
...@@ -212,16 +268,29 @@ class BGNNPredictor: ...@@ -212,16 +268,29 @@ class BGNNPredictor:
val_results = self.evaluate_model(logits, target_labels, val_mask) val_results = self.evaluate_model(logits, target_labels, val_mask)
test_results = self.evaluate_model(logits, target_labels, test_mask) test_results = self.evaluate_model(logits, target_labels, test_mask)
for metric_name in train_results: for metric_name in train_results:
metrics[metric_name].append((train_results[metric_name].detach().item(), metrics[metric_name].append(
val_results[metric_name].detach().item(), (
test_results[metric_name].detach().item() train_results[metric_name].detach().item(),
)) val_results[metric_name].detach().item(),
test_results[metric_name].detach().item(),
)
)
return loss return loss
def update_early_stopping(self, metrics, epoch, best_metric, best_val_epoch, epochs_since_last_best_metric, metric_name, def update_early_stopping(
lower_better=False): self,
metrics,
epoch,
best_metric,
best_val_epoch,
epochs_since_last_best_metric,
metric_name,
lower_better=False,
):
train_metric, val_metric, test_metric = metrics[metric_name][-1] train_metric, val_metric, test_metric = metrics[metric_name][-1]
if (lower_better and val_metric < best_metric[1]) or (not lower_better and val_metric > best_metric[1]): if (lower_better and val_metric < best_metric[1]) or (
not lower_better and val_metric > best_metric[1]
):
best_metric = metrics[metric_name][-1] best_metric = metrics[metric_name][-1]
best_val_epoch = epoch best_val_epoch = epoch
epochs_since_last_best_metric = 0 epochs_since_last_best_metric = 0
...@@ -229,26 +298,45 @@ class BGNNPredictor: ...@@ -229,26 +298,45 @@ class BGNNPredictor:
epochs_since_last_best_metric += 1 epochs_since_last_best_metric += 1
return best_metric, best_val_epoch, epochs_since_last_best_metric return best_metric, best_val_epoch, epochs_since_last_best_metric
    def log_epoch(
        self,
        pbar,
        metrics,
        epoch,
        loss,
        epoch_time,
        logging_epochs,
        metric_name="loss",
    ):
        train_metric, val_metric, test_metric = metrics[metric_name][-1]
        if epoch and epoch % logging_epochs == 0:
            pbar.set_description(
                "Epoch {:05d} | Loss {:.3f} | Loss {:.3f}/{:.3f}/{:.3f} | Time {:.4f}".format(
                    epoch,
                    loss,
                    train_metric,
                    val_metric,
                    test_metric,
                    epoch_time,
                )
            )

    def fit(
        self,
        graph,
        X,
        y,
        train_mask,
        val_mask,
        test_mask,
        original_X=None,
        cat_features=None,
        num_epochs=100,
        patience=10,
        logging_epochs=1,
        metric_name="loss",
    ):
        """
:param graph : dgl.DGLGraph :param graph : dgl.DGLGraph
Input graph Input graph
...@@ -283,13 +371,13 @@ class BGNNPredictor: ...@@ -283,13 +371,13 @@ class BGNNPredictor:
:param replace_na: bool :param replace_na: bool
If to replace missing values (None) in X. If to replace missing values (None) in X.
:return: metrics evaluated during training :return: metrics evaluated during training
''' """
# initialize for early stopping and metrics # initialize for early stopping and metrics
if metric_name in ['r2', 'accuracy']: if metric_name in ["r2", "accuracy"]:
best_metric = [np.float('-inf')] * 3 # for train/val/test best_metric = [np.float("-inf")] * 3 # for train/val/test
else: else:
best_metric = [np.float('inf')] * 3 # for train/val/test best_metric = [np.float("inf")] * 3 # for train/val/test
best_val_epoch = 0 best_val_epoch = 0
epochs_since_last_best_metric = 0 epochs_since_last_best_metric = 0
...@@ -297,11 +385,13 @@ class BGNNPredictor: ...@@ -297,11 +385,13 @@ class BGNNPredictor:
if cat_features is None: if cat_features is None:
cat_features = [] cat_features = []
if self.task == 'regression': if self.task == "regression":
self.out_dim = y.shape[1] self.out_dim = y.shape[1]
elif self.task == 'classification': elif self.task == "classification":
self.out_dim = len(set(y.iloc[test_mask, 0])) self.out_dim = len(set(y.iloc[test_mask, 0]))
self.in_dim = self.out_dim + X.shape[1] if self.append_gbdt_pred else self.out_dim self.in_dim = (
self.out_dim + X.shape[1] if self.append_gbdt_pred else self.out_dim
)
if original_X is None: if original_X is None:
original_X = X.copy() original_X = X.copy()
...@@ -313,9 +403,16 @@ class BGNNPredictor: ...@@ -313,9 +403,16 @@ class BGNNPredictor:
self.gbdt_model = None self.gbdt_model = None
node_features = self.init_node_features(X) node_features = self.init_node_features(X)
optimizer = self.init_optimizer(node_features, optimize_node_features=True, learning_rate=self.lr) optimizer = self.init_optimizer(
node_features, optimize_node_features=True, learning_rate=self.lr
)
y = torch.from_numpy(y.to_numpy(copy=True)).float().squeeze().to(self.device) y = (
torch.from_numpy(y.to_numpy(copy=True))
.float()
.squeeze()
.to(self.device)
)
graph = graph.to(self.device) graph = graph.to(self.device)
pbar = tqdm(range(num_epochs)) pbar = tqdm(range(num_epochs))
...@@ -323,31 +420,68 @@ class BGNNPredictor: ...@@ -323,31 +420,68 @@ class BGNNPredictor:
start2epoch = time.time() start2epoch = time.time()
# gbdt part # gbdt part
self.train_gbdt(gbdt_X_train, gbdt_y_train, cat_features, epoch, self.train_gbdt(
self.trees_per_epoch, gbdt_alpha) gbdt_X_train,
gbdt_y_train,
cat_features,
epoch,
self.trees_per_epoch,
gbdt_alpha,
)
self.update_node_features(node_features, X, original_X) self.update_node_features(node_features, X, original_X)
node_features_before = node_features.clone() node_features_before = node_features.clone()
model_in=(graph, node_features) model_in = (graph, node_features)
loss = self.train_and_evaluate(model_in, y, train_mask, val_mask, test_mask, loss = self.train_and_evaluate(
optimizer, metrics, self.backprop_per_epoch) model_in,
gbdt_y_train = self.update_gbdt_targets(node_features, node_features_before, train_mask) y,
train_mask,
self.log_epoch(pbar, metrics, epoch, loss, time.time() - start2epoch, logging_epochs, val_mask,
metric_name=metric_name) test_mask,
optimizer,
metrics,
self.backprop_per_epoch,
)
gbdt_y_train = self.update_gbdt_targets(
node_features, node_features_before, train_mask
)
self.log_epoch(
pbar,
metrics,
epoch,
loss,
time.time() - start2epoch,
logging_epochs,
metric_name=metric_name,
)
# check early stopping # check early stopping
best_metric, best_val_epoch, epochs_since_last_best_metric = \ (
self.update_early_stopping(metrics, epoch, best_metric, best_val_epoch, epochs_since_last_best_metric, best_metric,
metric_name, lower_better=(metric_name not in ['r2', 'accuracy'])) best_val_epoch,
epochs_since_last_best_metric,
) = self.update_early_stopping(
metrics,
epoch,
best_metric,
best_val_epoch,
epochs_since_last_best_metric,
metric_name,
lower_better=(metric_name not in ["r2", "accuracy"]),
)
if patience and epochs_since_last_best_metric > patience: if patience and epochs_since_last_best_metric > patience:
break break
if np.isclose(gbdt_y_train.sum(), 0.): if np.isclose(gbdt_y_train.sum(), 0.0):
print('Node embeddings do not change anymore. Stopping...') print("Node embeddings do not change anymore. Stopping...")
break break
print('Best {} at iteration {}: {:.3f}/{:.3f}/{:.3f}'.format(metric_name, best_val_epoch, *best_metric)) print(
"Best {} at iteration {}: {:.3f}/{:.3f}/{:.3f}".format(
metric_name, best_val_epoch, *best_metric
)
)
return metrics return metrics
def predict(self, graph, X, test_mask): def predict(self, graph, X, test_mask):
...@@ -355,27 +489,42 @@ class BGNNPredictor: ...@@ -355,27 +489,42 @@ class BGNNPredictor:
node_features = torch.empty(X.shape[0], self.in_dim).to(self.device) node_features = torch.empty(X.shape[0], self.in_dim).to(self.device)
self.update_node_features(node_features, X, X) self.update_node_features(node_features, X, X)
logits = self.model(graph, node_features).squeeze() logits = self.model(graph, node_features).squeeze()
if self.task == 'regression': if self.task == "regression":
return logits[test_mask] return logits[test_mask]
else: else:
return logits[test_mask].max(1)[1] return logits[test_mask].max(1)[1]
def plot_interactive(self, metrics, legend, title, logx=False, logy=False, metric_name='loss', start_from=0): def plot_interactive(
self,
metrics,
legend,
title,
logx=False,
logy=False,
metric_name="loss",
start_from=0,
):
import plotly.graph_objects as go import plotly.graph_objects as go
metric_results = metrics[metric_name] metric_results = metrics[metric_name]
xs = [list(range(len(metric_results)))] * len(metric_results[0]) xs = [list(range(len(metric_results)))] * len(metric_results[0])
ys = list(zip(*metric_results)) ys = list(zip(*metric_results))
fig = go.Figure() fig = go.Figure()
for i in range(len(ys)): for i in range(len(ys)):
fig.add_trace(go.Scatter(x=xs[i][start_from:], y=ys[i][start_from:], fig.add_trace(
mode='lines+markers', go.Scatter(
name=legend[i])) x=xs[i][start_from:],
y=ys[i][start_from:],
mode="lines+markers",
name=legend[i],
)
)
fig.update_layout( fig.update_layout(
title=title, title=title,
title_x=0.5, title_x=0.5,
xaxis_title='Epoch', xaxis_title="Epoch",
yaxis_title=metric_name, yaxis_title=metric_name,
font=dict( font=dict(
size=40, size=40,
...@@ -388,4 +537,4 @@ class BGNNPredictor: ...@@ -388,4 +537,4 @@ class BGNNPredictor:
if logy: if logy:
fig.update_layout(yaxis_type="log") fig.update_layout(yaxis_type="log")
fig.show() fig.show()
\ No newline at end of file
import json
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from BGNN import BGNNPredictor
from category_encoders import CatBoostEncoder
from sklearn import preprocessing
from torch.nn import ELU, Dropout, Linear, ReLU, Sequential

from dgl.data.utils import load_graphs
from dgl.nn.pytorch import AGNNConv as AGNNConvDGL
from dgl.nn.pytorch import APPNPConv
from dgl.nn.pytorch import ChebConv as ChebConvDGL
from dgl.nn.pytorch import GATConv as GATConvDGL
from dgl.nn.pytorch import GraphConv
class GNNModelDGL(torch.nn.Module):
    def __init__(
        self,
        in_dim,
        hidden_dim,
        out_dim,
        dropout=0.0,
        name="gat",
        residual=True,
        use_mlp=False,
        join_with_mlp=False,
    ):
super(GNNModelDGL, self).__init__() super(GNNModelDGL, self).__init__()
self.name = name self.name = name
self.use_mlp = use_mlp self.use_mlp = use_mlp
self.join_with_mlp = join_with_mlp self.join_with_mlp = join_with_mlp
self.normalize_input_columns = True self.normalize_input_columns = True
if name == 'gat': if name == "gat":
self.l1 = GATConvDGL(in_dim, hidden_dim//8, 8, feat_drop=dropout, attn_drop=dropout, residual=False, self.l1 = GATConvDGL(
activation=F.elu) in_dim,
self.l2 = GATConvDGL(hidden_dim, out_dim, 1, feat_drop=dropout, attn_drop=dropout, residual=residual, activation=None) hidden_dim // 8,
elif name == 'gcn': 8,
feat_drop=dropout,
attn_drop=dropout,
residual=False,
activation=F.elu,
)
self.l2 = GATConvDGL(
hidden_dim,
out_dim,
1,
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=None,
)
elif name == "gcn":
self.l1 = GraphConv(in_dim, hidden_dim, activation=F.elu) self.l1 = GraphConv(in_dim, hidden_dim, activation=F.elu)
self.l2 = GraphConv(hidden_dim, out_dim, activation=F.elu) self.l2 = GraphConv(hidden_dim, out_dim, activation=F.elu)
self.drop = Dropout(p=dropout) self.drop = Dropout(p=dropout)
elif name == 'cheb': elif name == "cheb":
self.l1 = ChebConvDGL(in_dim, hidden_dim, k = 3) self.l1 = ChebConvDGL(in_dim, hidden_dim, k=3)
self.l2 = ChebConvDGL(hidden_dim, out_dim, k = 3) self.l2 = ChebConvDGL(hidden_dim, out_dim, k=3)
self.drop = Dropout(p=dropout) self.drop = Dropout(p=dropout)
elif name == 'agnn': elif name == "agnn":
self.lin1 = Sequential(Dropout(p=dropout), Linear(in_dim, hidden_dim), ELU()) self.lin1 = Sequential(
Dropout(p=dropout), Linear(in_dim, hidden_dim), ELU()
)
self.l1 = AGNNConvDGL(learn_beta=False) self.l1 = AGNNConvDGL(learn_beta=False)
self.l2 = AGNNConvDGL(learn_beta=True) self.l2 = AGNNConvDGL(learn_beta=True)
self.lin2 = Sequential(Dropout(p=dropout), Linear(hidden_dim, out_dim), ELU()) self.lin2 = Sequential(
elif name == 'appnp': Dropout(p=dropout), Linear(hidden_dim, out_dim), ELU()
self.lin1 = Sequential(Dropout(p=dropout), Linear(in_dim, hidden_dim), )
ReLU(), Dropout(p=dropout), Linear(hidden_dim, out_dim)) elif name == "appnp":
self.l1 = APPNPConv(k=10, alpha=0.1, edge_drop=0.) self.lin1 = Sequential(
Dropout(p=dropout),
Linear(in_dim, hidden_dim),
ReLU(),
Dropout(p=dropout),
Linear(hidden_dim, out_dim),
)
self.l1 = APPNPConv(k=10, alpha=0.1, edge_drop=0.0)
def forward(self, graph, features): def forward(self, graph, features):
h = features h = features
...@@ -50,36 +88,37 @@ class GNNModelDGL(torch.nn.Module): ...@@ -50,36 +88,37 @@ class GNNModelDGL(torch.nn.Module):
h = torch.cat((h, self.mlp(features)), 1) h = torch.cat((h, self.mlp(features)), 1)
else: else:
h = self.mlp(features) h = self.mlp(features)
if self.name == 'gat': if self.name == "gat":
h = self.l1(graph, h).flatten(1) h = self.l1(graph, h).flatten(1)
logits = self.l2(graph, h).mean(1) logits = self.l2(graph, h).mean(1)
elif self.name in ['appnp']: elif self.name in ["appnp"]:
h = self.lin1(h) h = self.lin1(h)
logits = self.l1(graph, h) logits = self.l1(graph, h)
elif self.name == 'agnn': elif self.name == "agnn":
h = self.lin1(h) h = self.lin1(h)
h = self.l1(graph, h) h = self.l1(graph, h)
h = self.l2(graph, h) h = self.l2(graph, h)
logits = self.lin2(h) logits = self.lin2(h)
elif self.name == 'che3b': elif self.name == "che3b":
lambda_max = dgl.laplacian_lambda_max(graph) lambda_max = dgl.laplacian_lambda_max(graph)
h = self.drop(h) h = self.drop(h)
h = self.l1(graph, h, lambda_max) h = self.l1(graph, h, lambda_max)
logits = self.l2(graph, h, lambda_max) logits = self.l2(graph, h, lambda_max)
elif self.name == 'gcn': elif self.name == "gcn":
h = self.drop(h) h = self.drop(h)
h = self.l1(graph, h) h = self.l1(graph, h)
logits = self.l2(graph, h) logits = self.l2(graph, h)
return logits return logits
def read_input(input_folder): def read_input(input_folder):
X = pd.read_csv(f'{input_folder}/X.csv') X = pd.read_csv(f"{input_folder}/X.csv")
y = pd.read_csv(f'{input_folder}/y.csv') y = pd.read_csv(f"{input_folder}/y.csv")
categorical_columns = [] categorical_columns = []
if os.path.exists(f'{input_folder}/cat_features.txt'): if os.path.exists(f"{input_folder}/cat_features.txt"):
with open(f'{input_folder}/cat_features.txt') as f: with open(f"{input_folder}/cat_features.txt") as f:
for line in f: for line in f:
if line.strip(): if line.strip():
categorical_columns.append(line.strip()) categorical_columns.append(line.strip())
...@@ -92,14 +131,15 @@ def read_input(input_folder): ...@@ -92,14 +131,15 @@ def read_input(input_folder):
for col in list(columns[cat_features]): for col in list(columns[cat_features]):
X[col] = X[col].astype(str) X[col] = X[col].astype(str)
gs, _ = load_graphs(f'{input_folder}/graph.dgl') gs, _ = load_graphs(f"{input_folder}/graph.dgl")
graph = gs[0] graph = gs[0]
with open(f'{input_folder}/masks.json') as f: with open(f"{input_folder}/masks.json") as f:
masks = json.load(f) masks = json.load(f)
return graph, X, y, cat_features, masks return graph, X, y, cat_features, masks
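For orientation, the folder layout that read_input above expects (file names taken from the code; the masks layout is inferred from how masks["0"]["train"] is indexed further down):

# datasets/<name>/
#   X.csv              node features, one row per node
#   y.csv              targets, one row per node
#   cat_features.txt   one categorical column name per line (optional)
#   graph.dgl          graph saved with dgl.save_graphs
#   masks.json         {"0": {"train": [...], "val": [...], "test": [...]}, ...} with node indices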
def normalize_features(X, train_mask, val_mask, test_mask): def normalize_features(X, train_mask, val_mask, test_mask):
min_max_scaler = preprocessing.MinMaxScaler() min_max_scaler = preprocessing.MinMaxScaler()
A = X.to_numpy(copy=True) A = X.to_numpy(copy=True)
...@@ -107,72 +147,106 @@ def normalize_features(X, train_mask, val_mask, test_mask): ...@@ -107,72 +147,106 @@ def normalize_features(X, train_mask, val_mask, test_mask):
A[val_mask + test_mask] = min_max_scaler.transform(A[val_mask + test_mask]) A[val_mask + test_mask] = min_max_scaler.transform(A[val_mask + test_mask])
return pd.DataFrame(A, columns=X.columns).astype(float) return pd.DataFrame(A, columns=X.columns).astype(float)
def replace_na(X, train_mask): def replace_na(X, train_mask):
if X.isna().any().any(): if X.isna().any().any():
return X.fillna(X.iloc[train_mask].min() - 1) return X.fillna(X.iloc[train_mask].min() - 1)
return X return X
def encode_cat_features(X, y, cat_features, train_mask, val_mask, test_mask): def encode_cat_features(X, y, cat_features, train_mask, val_mask, test_mask):
enc = CatBoostEncoder() enc = CatBoostEncoder()
A = X.to_numpy(copy=True) A = X.to_numpy(copy=True)
b = y.to_numpy(copy=True) b = y.to_numpy(copy=True)
A[np.ix_(train_mask, cat_features)] = enc.fit_transform(A[np.ix_(train_mask, cat_features)], b[train_mask]) A[np.ix_(train_mask, cat_features)] = enc.fit_transform(
A[np.ix_(val_mask + test_mask, cat_features)] = enc.transform(A[np.ix_(val_mask + test_mask, cat_features)]) A[np.ix_(train_mask, cat_features)], b[train_mask]
)
A[np.ix_(val_mask + test_mask, cat_features)] = enc.transform(
A[np.ix_(val_mask + test_mask, cat_features)]
)
A = A.astype(float) A = A.astype(float)
return pd.DataFrame(A, columns=X.columns) return pd.DataFrame(A, columns=X.columns)
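The encoding step above follows the usual leak-free recipe: fit the target encoder on training rows only, then transform the held-out rows. A hedged, self-contained sketch (toy data; assumes CatBoostEncoder comes from the category_encoders package, as in this example):

import pandas as pd
from category_encoders import CatBoostEncoder  # assumed import, as in this example

X_toy = pd.DataFrame({"city": ["a", "b", "a", "b", "a", "b"]})
y_toy = pd.Series([1.0, 0.0, 1.0, 0.0, 1.0, 0.0])
train_rows, heldout_rows = [0, 1, 2, 3], [4, 5]

enc = CatBoostEncoder()
train_enc = enc.fit_transform(X_toy.iloc[train_rows], y_toy.iloc[train_rows])  # fit on train only
heldout_enc = enc.transform(X_toy.iloc[heldout_rows])                          # reuse fitted statistics
print(train_enc["city"].tolist(), heldout_enc["city"].tolist())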
if __name__ == '__main__':
if __name__ == "__main__":
# datasets can be found here: https://www.dropbox.com/s/verx1evkykzli88/datasets.zip # datasets can be found here: https://www.dropbox.com/s/verx1evkykzli88/datasets.zip
# Read dataset # Read dataset
input_folder = 'datasets/avazu' input_folder = "datasets/avazu"
graph, X, y, cat_features, masks = read_input(input_folder) graph, X, y, cat_features, masks = read_input(input_folder)
train_mask, val_mask, test_mask = masks['0']['train'], masks['0']['val'], masks['0']['test'] train_mask, val_mask, test_mask = (
masks["0"]["train"],
masks["0"]["val"],
masks["0"]["test"],
)
encoded_X = X.copy() encoded_X = X.copy()
normalizeFeatures = False normalizeFeatures = False
replaceNa = True replaceNa = True
if len(cat_features): if len(cat_features):
encoded_X = encode_cat_features(encoded_X, y, cat_features, train_mask, val_mask, test_mask) encoded_X = encode_cat_features(
encoded_X, y, cat_features, train_mask, val_mask, test_mask
)
if normalizeFeatures: if normalizeFeatures:
encoded_X = normalize_features(encoded_X, train_mask, val_mask, test_mask) encoded_X = normalize_features(
encoded_X, train_mask, val_mask, test_mask
)
if replaceNa: if replaceNa:
encoded_X = replace_na(encoded_X, train_mask) encoded_X = replace_na(encoded_X, train_mask)
# specify parameters # specify parameters
task = 'regression' task = "regression"
hidden_dim = 128 hidden_dim = 128
trees_per_epoch = 5 # 5-10 are good values to try trees_per_epoch = 5 # 5-10 are good values to try
backprop_per_epoch = 5 # 5-10 are good values to try backprop_per_epoch = 5 # 5-10 are good values to try
lr = 0.1 # 0.01-0.1 are good values to try lr = 0.1 # 0.01-0.1 are good values to try
append_gbdt_pred = False # this can be important for performance (try True and False) append_gbdt_pred = (
False # this can be important for performance (try True and False)
)
train_input_features = False train_input_features = False
gbdt_depth = 6 gbdt_depth = 6
gbdt_lr = 0.1 gbdt_lr = 0.1
out_dim = y.shape[1] if task == 'regression' else len(set(y.iloc[test_mask, 0])) out_dim = (
y.shape[1] if task == "regression" else len(set(y.iloc[test_mask, 0]))
)
in_dim = out_dim + X.shape[1] if append_gbdt_pred else out_dim in_dim = out_dim + X.shape[1] if append_gbdt_pred else out_dim
# specify GNN model # specify GNN model
gnn_model = GNNModelDGL(in_dim, hidden_dim, out_dim) gnn_model = GNNModelDGL(in_dim, hidden_dim, out_dim)
# initialize BGNN model # initialize BGNN model
bgnn = BGNNPredictor(gnn_model, task=task, bgnn = BGNNPredictor(
loss_fn=None, gnn_model,
trees_per_epoch=trees_per_epoch, task=task,
backprop_per_epoch=backprop_per_epoch, loss_fn=None,
lr=lr, trees_per_epoch=trees_per_epoch,
append_gbdt_pred=append_gbdt_pred, backprop_per_epoch=backprop_per_epoch,
train_input_features=train_input_features, lr=lr,
gbdt_depth=gbdt_depth, append_gbdt_pred=append_gbdt_pred,
gbdt_lr=gbdt_lr) train_input_features=train_input_features,
gbdt_depth=gbdt_depth,
gbdt_lr=gbdt_lr,
)
# train # train
metrics = bgnn.fit(graph, encoded_X, y, train_mask, val_mask, test_mask, metrics = bgnn.fit(
original_X = X, cat_features=cat_features, graph,
num_epochs=100, patience=10, metric_name='loss') encoded_X,
y,
bgnn.plot_interactive(metrics, legend=['train', 'valid', 'test'], title='Avazu', metric_name='loss') train_mask,
val_mask,
test_mask,
original_X=X,
cat_features=cat_features,
num_epochs=100,
patience=10,
metric_name="loss",
)
bgnn.plot_interactive(
metrics,
legend=["train", "valid", "test"],
title="Avazu",
metric_name="loss",
)
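As a quick illustration of the dimension bookkeeping above (made-up numbers): the GNN input size grows by the number of raw feature columns only when the GBDT predictions are appended.

n_raw_features, out_dim = 10, 1            # e.g. 10 columns in X, 1-dim regression target
for append_gbdt_pred in (False, True):
    in_dim = out_dim + n_raw_features if append_gbdt_pred else out_dim
    print(append_gbdt_pred, in_dim)        # False -> 1, True -> 11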
import torch
import numpy as np import numpy as np
import torch
from sklearn import metrics from sklearn import metrics
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, ShuffleSplit, train_test_split from sklearn.model_selection import (GridSearchCV, ShuffleSplit,
train_test_split)
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import OneHotEncoder, normalize from sklearn.preprocessing import OneHotEncoder, normalize
def fit_logistic_regression(X, y, data_random_seed=1, repeat=1): def fit_logistic_regression(X, y, data_random_seed=1, repeat=1):
# transform targets to one-hot vector # transform targets to one-hot vector
one_hot_encoder = OneHotEncoder(categories='auto', sparse=False) one_hot_encoder = OneHotEncoder(categories="auto", sparse=False)
    y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(bool) y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(bool)
# normalize x # normalize x
X = normalize(X, norm='l2') X = normalize(X, norm="l2")
    # set the random state; this ensures the dataset is split the same way throughout training # set the random state; this ensures the dataset is split the same way throughout training
rng = np.random.RandomState(data_random_seed) rng = np.random.RandomState(data_random_seed)
...@@ -22,37 +23,51 @@ def fit_logistic_regression(X, y, data_random_seed=1, repeat=1): ...@@ -22,37 +23,51 @@ def fit_logistic_regression(X, y, data_random_seed=1, repeat=1):
accuracies = [] accuracies = []
for _ in range(repeat): for _ in range(repeat):
# different random split after each repeat # different random split after each repeat
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=rng) X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.8, random_state=rng
)
# grid search with one-vs-rest classifiers # grid search with one-vs-rest classifiers
logreg = LogisticRegression(solver='liblinear') logreg = LogisticRegression(solver="liblinear")
c = 2.0 ** np.arange(-10, 11) c = 2.0 ** np.arange(-10, 11)
cv = ShuffleSplit(n_splits=5, test_size=0.5) cv = ShuffleSplit(n_splits=5, test_size=0.5)
clf = GridSearchCV(estimator=OneVsRestClassifier(logreg), param_grid=dict(estimator__C=c), clf = GridSearchCV(
n_jobs=5, cv=cv, verbose=0) estimator=OneVsRestClassifier(logreg),
param_grid=dict(estimator__C=c),
n_jobs=5,
cv=cv,
verbose=0,
)
clf.fit(X_train, y_train) clf.fit(X_train, y_train)
y_pred = clf.predict_proba(X_test) y_pred = clf.predict_proba(X_test)
y_pred = np.argmax(y_pred, axis=1) y_pred = np.argmax(y_pred, axis=1)
        y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(bool) y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(
            bool
        )
test_acc = metrics.accuracy_score(y_test, y_pred) test_acc = metrics.accuracy_score(y_test, y_pred)
accuracies.append(test_acc) accuracies.append(test_acc)
return accuracies return accuracies
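An illustrative call of the probe defined above, with hypothetical embedding shapes (not part of this change):

import numpy as np

emb = np.random.randn(200, 32)             # frozen node representations
lab = np.random.randint(0, 4, size=200)    # integer class labels
accs = fit_logistic_regression(emb, lab, data_random_seed=1, repeat=2)
print(np.mean(accs), np.std(accs))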
def fit_logistic_regression_preset_splits(X, y, train_mask, val_mask, test_mask): def fit_logistic_regression_preset_splits(
X, y, train_mask, val_mask, test_mask
):
# transform targets to one-hot vector # transform targets to one-hot vector
one_hot_encoder = OneHotEncoder(categories='auto', sparse=False) one_hot_encoder = OneHotEncoder(categories="auto", sparse=False)
    y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(bool) y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(bool)
# normalize x # normalize x
X = normalize(X, norm='l2') X = normalize(X, norm="l2")
accuracies = [] accuracies = []
for split_id in range(train_mask.shape[1]): for split_id in range(train_mask.shape[1]):
# get train/val/test masks # get train/val/test masks
tmp_train_mask, tmp_val_mask = train_mask[:, split_id], val_mask[:, split_id] tmp_train_mask, tmp_val_mask = (
train_mask[:, split_id],
val_mask[:, split_id],
)
# make custom cv # make custom cv
X_train, y_train = X[tmp_train_mask], y[tmp_train_mask] X_train, y_train = X[tmp_train_mask], y[tmp_train_mask]
...@@ -62,29 +77,37 @@ def fit_logistic_regression_preset_splits(X, y, train_mask, val_mask, test_mask) ...@@ -62,29 +77,37 @@ def fit_logistic_regression_preset_splits(X, y, train_mask, val_mask, test_mask)
# grid search with one-vs-rest classifiers # grid search with one-vs-rest classifiers
best_test_acc, best_acc = 0, 0 best_test_acc, best_acc = 0, 0
for c in 2.0 ** np.arange(-10, 11): for c in 2.0 ** np.arange(-10, 11):
clf = OneVsRestClassifier(LogisticRegression(solver='liblinear', C=c)) clf = OneVsRestClassifier(
LogisticRegression(solver="liblinear", C=c)
)
clf.fit(X_train, y_train) clf.fit(X_train, y_train)
y_pred = clf.predict_proba(X_val) y_pred = clf.predict_proba(X_val)
y_pred = np.argmax(y_pred, axis=1) y_pred = np.argmax(y_pred, axis=1)
            y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(bool) y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(
                bool
            )
val_acc = metrics.accuracy_score(y_val, y_pred) val_acc = metrics.accuracy_score(y_val, y_pred)
if val_acc > best_acc: if val_acc > best_acc:
best_acc = val_acc best_acc = val_acc
y_pred = clf.predict_proba(X_test) y_pred = clf.predict_proba(X_test)
y_pred = np.argmax(y_pred, axis=1) y_pred = np.argmax(y_pred, axis=1)
                y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(bool) y_pred = one_hot_encoder.transform(
                    y_pred.reshape(-1, 1)
                ).astype(bool)
best_test_acc = metrics.accuracy_score(y_test, y_pred) best_test_acc = metrics.accuracy_score(y_test, y_pred)
accuracies.append(best_test_acc) accuracies.append(best_test_acc)
return accuracies return accuracies
def fit_ppi_linear(num_classes, train_data, val_data, test_data, device, repeat=1): def fit_ppi_linear(
num_classes, train_data, val_data, test_data, device, repeat=1
):
r""" r"""
Trains a linear layer on top of the representations. This function is specific to the PPI dataset, Trains a linear layer on top of the representations. This function is specific to the PPI dataset,
which has multiple labels. which has multiple labels.
""" """
def train(classifier, train_data, optimizer): def train(classifier, train_data, optimizer):
classifier.train() classifier.train()
...@@ -111,13 +134,19 @@ def fit_ppi_linear(num_classes, train_data, val_data, test_data, device, repeat= ...@@ -111,13 +134,19 @@ def fit_ppi_linear(num_classes, train_data, val_data, test_data, device, repeat=
pred_logits = classifier(x.to(device)) pred_logits = classifier(x.to(device))
pred_class = (pred_logits > 0).float().cpu().numpy() pred_class = (pred_logits > 0).float().cpu().numpy()
return metrics.f1_score(label, pred_class, average='micro') if pred_class.sum() > 0 else 0 return (
metrics.f1_score(label, pred_class, average="micro")
if pred_class.sum() > 0
else 0
)
num_feats = train_data[0].size(1) num_feats = train_data[0].size(1)
criterion = torch.nn.BCEWithLogitsLoss() criterion = torch.nn.BCEWithLogitsLoss()
# normalization # normalization
mean, std = train_data[0].mean(0, keepdim=True), train_data[0].std(0, unbiased=False, keepdim=True) mean, std = train_data[0].mean(0, keepdim=True), train_data[0].std(
0, unbiased=False, keepdim=True
)
train_data[0] = (train_data[0] - mean) / std train_data[0] = (train_data[0] - mean) / std
val_data[0] = (val_data[0] - mean) / std val_data[0] = (val_data[0] - mean) / std
test_data[0] = (test_data[0] - mean) / std test_data[0] = (test_data[0] - mean) / std
...@@ -129,7 +158,11 @@ def fit_ppi_linear(num_classes, train_data, val_data, test_data, device, repeat= ...@@ -129,7 +158,11 @@ def fit_ppi_linear(num_classes, train_data, val_data, test_data, device, repeat=
tmp_test_f1 = 0 tmp_test_f1 = 0
for weight_decay in 2.0 ** np.arange(-10, 11, 2): for weight_decay in 2.0 ** np.arange(-10, 11, 2):
classifier = torch.nn.Linear(num_feats, num_classes).to(device) classifier = torch.nn.Linear(num_feats, num_classes).to(device)
optimizer = torch.optim.AdamW(params=classifier.parameters(), lr=0.01, weight_decay=weight_decay) optimizer = torch.optim.AdamW(
params=classifier.parameters(),
lr=0.01,
weight_decay=weight_decay,
)
train(classifier, train_data, optimizer) train(classifier, train_data, optimizer)
val_f1 = test(classifier, val_data) val_f1 = test(classifier, val_data)
......
import os
import dgl
import copy import copy
import torch import os
import warnings
import numpy as np import numpy as np
from tqdm import tqdm import torch
from torch.optim import AdamW from eval_function import (fit_logistic_regression,
fit_logistic_regression_preset_splits,
fit_ppi_linear)
from model import (BGRL, GCN, GraphSAGE_GCN, MLP_Predictor,
compute_representations)
from torch.nn.functional import cosine_similarity from torch.nn.functional import cosine_similarity
from utils import get_graph_drop_transform, CosineDecayScheduler, get_dataset from torch.optim import AdamW
from model import GCN, GraphSAGE_GCN, MLP_Predictor, BGRL, compute_representations from tqdm import tqdm
from eval_function import fit_logistic_regression, fit_logistic_regression_preset_splits, fit_ppi_linear from utils import CosineDecayScheduler, get_dataset, get_graph_drop_transform
import dgl
import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
def train(step, model, optimizer, lr_scheduler, mm_scheduler, transform_1, transform_2, data, args): def train(
step,
model,
optimizer,
lr_scheduler,
mm_scheduler,
transform_1,
transform_2,
data,
args,
):
model.train() model.train()
# update learning rate # update learning rate
lr = lr_scheduler.get(step) lr = lr_scheduler.get(step)
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group['lr'] = lr param_group["lr"] = lr
# update momentum # update momentum
mm = 1 - mm_scheduler.get(step) mm = 1 - mm_scheduler.get(step)
...@@ -31,13 +45,17 @@ def train(step, model, optimizer, lr_scheduler, mm_scheduler, transform_1, trans ...@@ -31,13 +45,17 @@ def train(step, model, optimizer, lr_scheduler, mm_scheduler, transform_1, trans
x1, x2 = transform_1(data), transform_2(data) x1, x2 = transform_1(data), transform_2(data)
if args.dataset != 'ppi': if args.dataset != "ppi":
x1, x2 = dgl.add_self_loop(x1), dgl.add_self_loop(x2) x1, x2 = dgl.add_self_loop(x1), dgl.add_self_loop(x2)
q1, y2 = model(x1, x2) q1, y2 = model(x1, x2)
q2, y1 = model(x2, x1) q2, y1 = model(x2, x1)
loss = 2 - cosine_similarity(q1, y2.detach(), dim=-1).mean() - cosine_similarity(q2, y1.detach(), dim=-1).mean() loss = (
2
- cosine_similarity(q1, y2.detach(), dim=-1).mean()
- cosine_similarity(q2, y1.detach(), dim=-1).mean()
)
loss.backward() loss.backward()
# update online network # update online network
...@@ -53,113 +71,187 @@ def eval(model, dataset, device, args, train_data, val_data, test_data): ...@@ -53,113 +71,187 @@ def eval(model, dataset, device, args, train_data, val_data, test_data):
tmp_encoder = copy.deepcopy(model.online_encoder).eval() tmp_encoder = copy.deepcopy(model.online_encoder).eval()
val_scores = None val_scores = None
if args.dataset == 'ppi': if args.dataset == "ppi":
train_data = compute_representations(tmp_encoder, train_data, device) train_data = compute_representations(tmp_encoder, train_data, device)
val_data = compute_representations(tmp_encoder, val_data, device) val_data = compute_representations(tmp_encoder, val_data, device)
test_data = compute_representations(tmp_encoder, test_data, device) test_data = compute_representations(tmp_encoder, test_data, device)
num_classes = train_data[1].shape[1] num_classes = train_data[1].shape[1]
val_scores, test_scores = fit_ppi_linear(num_classes, train_data, val_data, test_data, device, val_scores, test_scores = fit_ppi_linear(
args.num_eval_splits) num_classes,
elif args.dataset != 'wiki_cs': train_data,
representations, labels = compute_representations(tmp_encoder, dataset, device) val_data,
test_scores = fit_logistic_regression(representations.cpu().numpy(), labels.cpu().numpy(), test_data,
data_random_seed=args.data_seed, repeat=args.num_eval_splits) device,
args.num_eval_splits,
)
elif args.dataset != "wiki_cs":
representations, labels = compute_representations(
tmp_encoder, dataset, device
)
test_scores = fit_logistic_regression(
representations.cpu().numpy(),
labels.cpu().numpy(),
data_random_seed=args.data_seed,
repeat=args.num_eval_splits,
)
else: else:
g = dataset[0] g = dataset[0]
train_mask = g.ndata['train_mask'] train_mask = g.ndata["train_mask"]
val_mask = g.ndata['val_mask'] val_mask = g.ndata["val_mask"]
test_mask = g.ndata['test_mask'] test_mask = g.ndata["test_mask"]
representations, labels = compute_representations(tmp_encoder, dataset, device) representations, labels = compute_representations(
test_scores = fit_logistic_regression_preset_splits(representations.cpu().numpy(), labels.cpu().numpy(), tmp_encoder, dataset, device
train_mask, val_mask, test_mask) )
test_scores = fit_logistic_regression_preset_splits(
representations.cpu().numpy(),
labels.cpu().numpy(),
train_mask,
val_mask,
test_mask,
)
return val_scores, test_scores return val_scores, test_scores
def main(args): def main(args):
# use CUDA_VISIBLE_DEVICES to select gpu # use CUDA_VISIBLE_DEVICES to select gpu
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') device = (
print('Using device:', device) torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
print("Using device:", device)
dataset, train_data, val_data, test_data = get_dataset(args.dataset) dataset, train_data, val_data, test_data = get_dataset(args.dataset)
g = dataset[0] g = dataset[0]
g = g.to(device) g = g.to(device)
input_size, representation_size = g.ndata['feat'].size(1), args.graph_encoder_layer[-1] input_size, representation_size = (
g.ndata["feat"].size(1),
args.graph_encoder_layer[-1],
)
# prepare transforms # prepare transforms
transform_1 = get_graph_drop_transform(drop_edge_p=args.drop_edge_p[0], feat_mask_p=args.feat_mask_p[0]) transform_1 = get_graph_drop_transform(
transform_2 = get_graph_drop_transform(drop_edge_p=args.drop_edge_p[1], feat_mask_p=args.feat_mask_p[1]) drop_edge_p=args.drop_edge_p[0], feat_mask_p=args.feat_mask_p[0]
)
transform_2 = get_graph_drop_transform(
drop_edge_p=args.drop_edge_p[1], feat_mask_p=args.feat_mask_p[1]
)
# scheduler # scheduler
lr_scheduler = CosineDecayScheduler(args.lr, args.lr_warmup_epochs, args.epochs) lr_scheduler = CosineDecayScheduler(
args.lr, args.lr_warmup_epochs, args.epochs
)
mm_scheduler = CosineDecayScheduler(1 - args.mm, 0, args.epochs) mm_scheduler = CosineDecayScheduler(1 - args.mm, 0, args.epochs)
# build networks # build networks
if args.dataset == 'ppi': if args.dataset == "ppi":
encoder = GraphSAGE_GCN([input_size] + args.graph_encoder_layer) encoder = GraphSAGE_GCN([input_size] + args.graph_encoder_layer)
else: else:
encoder = GCN([input_size] + args.graph_encoder_layer) encoder = GCN([input_size] + args.graph_encoder_layer)
predictor = MLP_Predictor(representation_size, representation_size, hidden_size=args.predictor_hidden_size) predictor = MLP_Predictor(
representation_size,
representation_size,
hidden_size=args.predictor_hidden_size,
)
model = BGRL(encoder, predictor).to(device) model = BGRL(encoder, predictor).to(device)
# optimizer # optimizer
optimizer = AdamW(model.trainable_parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = AdamW(
model.trainable_parameters(), lr=args.lr, weight_decay=args.weight_decay
)
# train # train
for epoch in tqdm(range(1, args.epochs + 1), desc=' - (Training) '): for epoch in tqdm(range(1, args.epochs + 1), desc=" - (Training) "):
train(epoch - 1, model, optimizer, lr_scheduler, mm_scheduler, transform_1, transform_2, g, args) train(
epoch - 1,
model,
optimizer,
lr_scheduler,
mm_scheduler,
transform_1,
transform_2,
g,
args,
)
if epoch % args.eval_epochs == 0: if epoch % args.eval_epochs == 0:
val_scores, test_scores = eval(model, dataset, device, args, train_data, val_data, test_data) val_scores, test_scores = eval(
if args.dataset == 'ppi': model, dataset, device, args, train_data, val_data, test_data
print('Epoch: {:04d} | Best Val F1: {:.4f} | Test F1: {:.4f}'.format(epoch, np.mean(val_scores), )
np.mean(test_scores))) if args.dataset == "ppi":
print(
"Epoch: {:04d} | Best Val F1: {:.4f} | Test F1: {:.4f}".format(
epoch, np.mean(val_scores), np.mean(test_scores)
)
)
else: else:
print('Epoch: {:04d} | Test Accuracy: {:.4f}'.format(epoch, np.mean(test_scores))) print(
"Epoch: {:04d} | Test Accuracy: {:.4f}".format(
epoch, np.mean(test_scores)
)
)
# save encoder weights # save encoder weights
if not os.path.isdir(args.weights_dir): if not os.path.isdir(args.weights_dir):
os.mkdir(args.weights_dir) os.mkdir(args.weights_dir)
torch.save({'model': model.online_encoder.state_dict()}, torch.save(
os.path.join(args.weights_dir, 'bgrl-{}.pt'.format(args.dataset))) {"model": model.online_encoder.state_dict()},
os.path.join(args.weights_dir, "bgrl-{}.pt".format(args.dataset)),
)
if __name__ == '__main__': if __name__ == "__main__":
from argparse import ArgumentParser from argparse import ArgumentParser
parser = ArgumentParser() parser = ArgumentParser()
# Dataset options. # Dataset options.
parser.add_argument('--dataset', type=str, default='amazon_photos', choices=['coauthor_cs', 'coauthor_physics', parser.add_argument(
'amazon_photos', 'amazon_computers', "--dataset",
'wiki_cs', 'ppi']) type=str,
default="amazon_photos",
choices=[
"coauthor_cs",
"coauthor_physics",
"amazon_photos",
"amazon_computers",
"wiki_cs",
"ppi",
],
)
# Model options. # Model options.
parser.add_argument('--graph_encoder_layer', type=int, nargs='+', default=[256, 128]) parser.add_argument(
parser.add_argument('--predictor_hidden_size', type=int, default=512) "--graph_encoder_layer", type=int, nargs="+", default=[256, 128]
)
parser.add_argument("--predictor_hidden_size", type=int, default=512)
# Training options. # Training options.
parser.add_argument('--epochs', type=int, default=10000) parser.add_argument("--epochs", type=int, default=10000)
parser.add_argument('--lr', type=float, default=1e-5) parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument('--weight_decay', type=float, default=1e-5) parser.add_argument("--weight_decay", type=float, default=1e-5)
parser.add_argument('--mm', type=float, default=0.99) parser.add_argument("--mm", type=float, default=0.99)
parser.add_argument('--lr_warmup_epochs', type=int, default=1000) parser.add_argument("--lr_warmup_epochs", type=int, default=1000)
parser.add_argument('--weights_dir', type=str, default='../weights') parser.add_argument("--weights_dir", type=str, default="../weights")
# Augmentations options. # Augmentations options.
parser.add_argument('--drop_edge_p', type=float, nargs='+', default=[0., 0.]) parser.add_argument(
parser.add_argument('--feat_mask_p', type=float, nargs='+', default=[0., 0.]) "--drop_edge_p", type=float, nargs="+", default=[0.0, 0.0]
)
parser.add_argument(
"--feat_mask_p", type=float, nargs="+", default=[0.0, 0.0]
)
# Evaluation options. # Evaluation options.
parser.add_argument('--eval_epochs', type=int, default=250) parser.add_argument("--eval_epochs", type=int, default=250)
parser.add_argument('--num_eval_splits', type=int, default=20) parser.add_argument("--num_eval_splits", type=int, default=20)
parser.add_argument('--data_seed', type=int, default=1) parser.add_argument("--data_seed", type=int, default=1)
# Experiment options. # Experiment options.
parser.add_argument('--num_experiments', type=int, default=20) parser.add_argument("--num_experiments", type=int, default=20)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
import dgl
import copy import copy
import torch import torch
from torch import nn from torch import nn
from torch.nn.init import ones_, zeros_
from torch.nn import BatchNorm1d, Parameter from torch.nn import BatchNorm1d, Parameter
from torch.nn.init import ones_, zeros_
import dgl
from dgl.nn.pytorch.conv import GraphConv, SAGEConv from dgl.nn.pytorch.conv import GraphConv, SAGEConv
...@@ -17,8 +19,8 @@ class LayerNorm(nn.Module): ...@@ -17,8 +19,8 @@ class LayerNorm(nn.Module):
self.weight = Parameter(torch.Tensor(in_channels)) self.weight = Parameter(torch.Tensor(in_channels))
self.bias = Parameter(torch.Tensor(in_channels)) self.bias = Parameter(torch.Tensor(in_channels))
else: else:
self.register_parameter('weight', None) self.register_parameter("weight", None)
self.register_parameter('bias', None) self.register_parameter("bias", None)
self.reset_parameters() self.reset_parameters()
...@@ -34,13 +36,27 @@ class LayerNorm(nn.Module): ...@@ -34,13 +36,27 @@ class LayerNorm(nn.Module):
else: else:
batch_size = int(batch.max()) + 1 batch_size = int(batch.max()) + 1
batch_idx = [batch == i for i in range(batch_size)] batch_idx = [batch == i for i in range(batch_size)]
norm = torch.tensor([i.sum() for i in batch_idx], dtype=x.dtype).clamp_(min=1).to(device) norm = (
torch.tensor([i.sum() for i in batch_idx], dtype=x.dtype)
.clamp_(min=1)
.to(device)
)
norm = norm.mul_(x.size(-1)).view(-1, 1) norm = norm.mul_(x.size(-1)).view(-1, 1)
tmp_list = [x[i] for i in batch_idx] tmp_list = [x[i] for i in batch_idx]
mean = torch.concat([i.sum(0).unsqueeze(0) for i in tmp_list], dim=0).sum(dim=-1, keepdim=True).to(device) mean = (
torch.concat([i.sum(0).unsqueeze(0) for i in tmp_list], dim=0)
.sum(dim=-1, keepdim=True)
.to(device)
)
mean = mean / norm mean = mean / norm
x = x - mean.index_select(0, batch.long()) x = x - mean.index_select(0, batch.long())
var = torch.concat([(i * i).sum(0).unsqueeze(0) for i in tmp_list], dim=0).sum(dim=-1, keepdim=True).to(device) var = (
torch.concat(
[(i * i).sum(0).unsqueeze(0) for i in tmp_list], dim=0
)
.sum(dim=-1, keepdim=True)
.to(device)
)
var = var / norm var = var / norm
out = x / (var + self.eps).sqrt().index_select(0, batch.long()) out = x / (var + self.eps).sqrt().index_select(0, batch.long())
...@@ -50,7 +66,7 @@ class LayerNorm(nn.Module): ...@@ -50,7 +66,7 @@ class LayerNorm(nn.Module):
return out return out
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}({self.in_channels})' return f"{self.__class__.__name__}({self.in_channels})"
class MLP_Predictor(nn.Module): class MLP_Predictor(nn.Module):
...@@ -60,13 +76,14 @@ class MLP_Predictor(nn.Module): ...@@ -60,13 +76,14 @@ class MLP_Predictor(nn.Module):
output_size (int): Size of output features. output_size (int): Size of output features.
hidden_size (int, optional): Size of hidden layer. (default: :obj:`4096`). hidden_size (int, optional): Size of hidden layer. (default: :obj:`4096`).
""" """
def __init__(self, input_size, output_size, hidden_size=512): def __init__(self, input_size, output_size, hidden_size=512):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Linear(input_size, hidden_size, bias=True), nn.Linear(input_size, hidden_size, bias=True),
nn.PReLU(1), nn.PReLU(1),
nn.Linear(hidden_size, output_size, bias=True) nn.Linear(hidden_size, output_size, bias=True),
) )
self.reset_parameters() self.reset_parameters()
...@@ -91,7 +108,7 @@ class GCN(nn.Module): ...@@ -91,7 +108,7 @@ class GCN(nn.Module):
self.layers.append(nn.PReLU()) self.layers.append(nn.PReLU())
def forward(self, g): def forward(self, g):
x = g.ndata['feat'] x = g.ndata["feat"]
for layer in self.layers: for layer in self.layers:
if isinstance(layer, GraphConv): if isinstance(layer, GraphConv):
x = layer(g, x) x = layer(g, x)
...@@ -101,7 +118,7 @@ class GCN(nn.Module): ...@@ -101,7 +118,7 @@ class GCN(nn.Module):
def reset_parameters(self): def reset_parameters(self):
for layer in self.layers: for layer in self.layers:
if hasattr(layer, 'reset_parameters'): if hasattr(layer, "reset_parameters"):
layer.reset_parameters() layer.reset_parameters()
...@@ -111,33 +128,41 @@ class GraphSAGE_GCN(nn.Module): ...@@ -111,33 +128,41 @@ class GraphSAGE_GCN(nn.Module):
input_size, hidden_size, embedding_size = layer_sizes input_size, hidden_size, embedding_size = layer_sizes
self.convs = nn.ModuleList([ self.convs = nn.ModuleList(
SAGEConv(input_size, hidden_size, 'mean'), [
SAGEConv(hidden_size, hidden_size, 'mean'), SAGEConv(input_size, hidden_size, "mean"),
SAGEConv(hidden_size, embedding_size, 'mean') SAGEConv(hidden_size, hidden_size, "mean"),
]) SAGEConv(hidden_size, embedding_size, "mean"),
]
self.skip_lins = nn.ModuleList([ )
nn.Linear(input_size, hidden_size, bias=False),
nn.Linear(input_size, hidden_size, bias=False), self.skip_lins = nn.ModuleList(
]) [
nn.Linear(input_size, hidden_size, bias=False),
self.layer_norms = nn.ModuleList([ nn.Linear(input_size, hidden_size, bias=False),
LayerNorm(hidden_size), ]
LayerNorm(hidden_size), )
LayerNorm(embedding_size),
]) self.layer_norms = nn.ModuleList(
[
self.activations = nn.ModuleList([ LayerNorm(hidden_size),
nn.PReLU(), LayerNorm(hidden_size),
nn.PReLU(), LayerNorm(embedding_size),
nn.PReLU(), ]
]) )
self.activations = nn.ModuleList(
[
nn.PReLU(),
nn.PReLU(),
nn.PReLU(),
]
)
def forward(self, g): def forward(self, g):
x = g.ndata['feat'] x = g.ndata["feat"]
if 'batch' in g.ndata.keys(): if "batch" in g.ndata.keys():
batch = g.ndata['batch'] batch = g.ndata["batch"]
else: else:
batch = None batch = None
...@@ -176,6 +201,7 @@ class BGRL(nn.Module): ...@@ -176,6 +201,7 @@ class BGRL(nn.Module):
`encoder` must have a `reset_parameters` method, as the weights of the target network will be initialized `encoder` must have a `reset_parameters` method, as the weights of the target network will be initialized
differently from the online network. differently from the online network.
""" """
def __init__(self, encoder, predictor): def __init__(self, encoder, predictor):
super(BGRL, self).__init__() super(BGRL, self).__init__()
# online network # online network
...@@ -194,7 +220,9 @@ class BGRL(nn.Module): ...@@ -194,7 +220,9 @@ class BGRL(nn.Module):
def trainable_parameters(self): def trainable_parameters(self):
r"""Returns the parameters that will be updated via an optimizer.""" r"""Returns the parameters that will be updated via an optimizer."""
return list(self.online_encoder.parameters()) + list(self.predictor.parameters()) return list(self.online_encoder.parameters()) + list(
self.predictor.parameters()
)
@torch.no_grad() @torch.no_grad()
def update_target_network(self, mm): def update_target_network(self, mm):
...@@ -202,8 +230,10 @@ class BGRL(nn.Module): ...@@ -202,8 +230,10 @@ class BGRL(nn.Module):
Args: Args:
mm (float): Momentum used in moving average update. mm (float): Momentum used in moving average update.
""" """
for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): for param_q, param_k in zip(
param_k.data.mul_(mm).add_(param_q.data, alpha=1. - mm) self.online_encoder.parameters(), self.target_encoder.parameters()
):
param_k.data.mul_(mm).add_(param_q.data, alpha=1.0 - mm)
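The momentum update above is a plain exponential moving average; a tiny numeric illustration with stand-in tensors (not part of this change):

import torch
online = torch.ones(3)          # stand-in for an online parameter
target = torch.zeros(3)         # stand-in for the matching target parameter
mm = 0.99
target.mul_(mm).add_(online, alpha=1.0 - mm)   # target <- mm * target + (1 - mm) * online
print(target)                                  # tensor([0.0100, 0.0100, 0.0100])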
def forward(self, online_x, target_x): def forward(self, online_x, target_x):
# forward online network # forward online network
...@@ -233,16 +263,15 @@ def compute_representations(net, dataset, device): ...@@ -233,16 +263,15 @@ def compute_representations(net, dataset, device):
g = g.to(device) g = g.to(device)
with torch.no_grad(): with torch.no_grad():
reps.append(net(g)) reps.append(net(g))
labels.append(g.ndata['label']) labels.append(g.ndata["label"])
else: else:
for g in dataset: for g in dataset:
# forward # forward
g = g.to(device) g = g.to(device)
with torch.no_grad(): with torch.no_grad():
reps.append(net(g)) reps.append(net(g))
labels.append(g.ndata['label']) labels.append(g.ndata["label"])
reps = torch.cat(reps, dim=0) reps = torch.cat(reps, dim=0)
labels = torch.cat(labels, dim=0) labels = torch.cat(labels, dim=0)
return [reps, labels] return [reps, labels]
import copy import copy
import torch
import numpy as np import numpy as np
import torch
from dgl.data import (AmazonCoBuyComputerDataset, AmazonCoBuyPhotoDataset,
CoauthorCSDataset, CoauthorPhysicsDataset, PPIDataset,
WikiCSDataset)
from dgl.dataloading import GraphDataLoader from dgl.dataloading import GraphDataLoader
from dgl.transforms import Compose, DropEdge, FeatMask, RowFeatNormalizer from dgl.transforms import Compose, DropEdge, FeatMask, RowFeatNormalizer
from dgl.data import CoauthorCSDataset, CoauthorPhysicsDataset, AmazonCoBuyPhotoDataset, AmazonCoBuyComputerDataset, PPIDataset, WikiCSDataset
class CosineDecayScheduler: class CosineDecayScheduler:
...@@ -16,10 +20,24 @@ class CosineDecayScheduler: ...@@ -16,10 +20,24 @@ class CosineDecayScheduler:
if step < self.warmup_steps: if step < self.warmup_steps:
return self.max_val * step / self.warmup_steps return self.max_val * step / self.warmup_steps
elif self.warmup_steps <= step <= self.total_steps: elif self.warmup_steps <= step <= self.total_steps:
return self.max_val * (1 + np.cos((step - self.warmup_steps) * np.pi / return (
(self.total_steps - self.warmup_steps))) / 2 self.max_val
* (
1
+ np.cos(
(step - self.warmup_steps)
* np.pi
/ (self.total_steps - self.warmup_steps)
)
)
/ 2
)
else: else:
raise ValueError('Step ({}) > total number of steps ({}).'.format(step, self.total_steps)) raise ValueError(
"Step ({}) > total number of steps ({}).".format(
step, self.total_steps
)
)
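A self-contained numeric sketch of the schedule implemented above, with made-up max_val, warmup, and total steps: linear warmup to the peak, then a half-cosine decay to zero.

import numpy as np

def cosine_decay(step, max_val=1e-5, warmup_steps=1000, total_steps=10000):
    if step < warmup_steps:
        return max_val * step / warmup_steps                  # linear warmup
    return max_val * (1 + np.cos((step - warmup_steps) * np.pi
                                 / (total_steps - warmup_steps))) / 2  # cosine decay

for s in (0, 500, 1000, 5500, 10000):
    print(s, cosine_decay(s))   # 0, max/2, max (peak), ~max/2, ~0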
def get_graph_drop_transform(drop_edge_p, feat_mask_p): def get_graph_drop_transform(drop_edge_p, feat_mask_p):
...@@ -29,12 +47,12 @@ def get_graph_drop_transform(drop_edge_p, feat_mask_p): ...@@ -29,12 +47,12 @@ def get_graph_drop_transform(drop_edge_p, feat_mask_p):
transforms.append(copy.deepcopy) transforms.append(copy.deepcopy)
# drop edges # drop edges
if drop_edge_p > 0.: if drop_edge_p > 0.0:
transforms.append(DropEdge(drop_edge_p)) transforms.append(DropEdge(drop_edge_p))
# drop features # drop features
if feat_mask_p > 0.: if feat_mask_p > 0.0:
transforms.append(FeatMask(feat_mask_p, node_feat_names=['feat'])) transforms.append(FeatMask(feat_mask_p, node_feat_names=["feat"]))
return Compose(transforms) return Compose(transforms)
...@@ -42,41 +60,41 @@ def get_graph_drop_transform(drop_edge_p, feat_mask_p): ...@@ -42,41 +60,41 @@ def get_graph_drop_transform(drop_edge_p, feat_mask_p):
def get_wiki_cs(transform=RowFeatNormalizer(subtract_min=True)): def get_wiki_cs(transform=RowFeatNormalizer(subtract_min=True)):
dataset = WikiCSDataset(transform=transform) dataset = WikiCSDataset(transform=transform)
g = dataset[0] g = dataset[0]
std, mean = torch.std_mean(g.ndata['feat'], dim=0, unbiased=False) std, mean = torch.std_mean(g.ndata["feat"], dim=0, unbiased=False)
g.ndata['feat'] = (g.ndata['feat'] - mean) / std g.ndata["feat"] = (g.ndata["feat"] - mean) / std
return [g] return [g]
def get_ppi(): def get_ppi():
train_dataset = PPIDataset(mode='train') train_dataset = PPIDataset(mode="train")
val_dataset = PPIDataset(mode='valid') val_dataset = PPIDataset(mode="valid")
test_dataset = PPIDataset(mode='test') test_dataset = PPIDataset(mode="test")
train_val_dataset = [i for i in train_dataset] + [i for i in val_dataset] train_val_dataset = [i for i in train_dataset] + [i for i in val_dataset]
for idx, data in enumerate(train_val_dataset): for idx, data in enumerate(train_val_dataset):
data.ndata['batch'] = torch.zeros(data.number_of_nodes()) + idx data.ndata["batch"] = torch.zeros(data.number_of_nodes()) + idx
data.ndata['batch'] = data.ndata['batch'].long() data.ndata["batch"] = data.ndata["batch"].long()
g = list(GraphDataLoader(train_val_dataset, batch_size=22, shuffle=True)) g = list(GraphDataLoader(train_val_dataset, batch_size=22, shuffle=True))
return g, PPIDataset(mode='train'), PPIDataset(mode='valid'), test_dataset return g, PPIDataset(mode="train"), PPIDataset(mode="valid"), test_dataset
def get_dataset(name, transform=RowFeatNormalizer(subtract_min=True)): def get_dataset(name, transform=RowFeatNormalizer(subtract_min=True)):
dgl_dataset_dict = { dgl_dataset_dict = {
'coauthor_cs': CoauthorCSDataset, "coauthor_cs": CoauthorCSDataset,
'coauthor_physics': CoauthorPhysicsDataset, "coauthor_physics": CoauthorPhysicsDataset,
'amazon_computers': AmazonCoBuyComputerDataset, "amazon_computers": AmazonCoBuyComputerDataset,
'amazon_photos': AmazonCoBuyPhotoDataset, "amazon_photos": AmazonCoBuyPhotoDataset,
'wiki_cs': get_wiki_cs, "wiki_cs": get_wiki_cs,
'ppi': get_ppi "ppi": get_ppi,
} }
dataset_class = dgl_dataset_dict[name] dataset_class = dgl_dataset_dict[name]
train_data, val_data, test_data = None, None, None train_data, val_data, test_data = None, None, None
if name != 'ppi': if name != "ppi":
dataset = dataset_class(transform=transform) dataset = dataset_class(transform=transform)
else: else:
dataset, train_data, val_data, test_data = dataset_class() dataset, train_data, val_data, test_data = dataset_class()
return dataset, train_data, val_data, test_data return dataset, train_data, val_data, test_data
\ No newline at end of file
import dgl
import torch import torch
from DGLRoutingLayer import DGLRoutingLayer
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
import dgl.function as fn
from DGLRoutingLayer import DGLRoutingLayer import dgl
import dgl.function as fn
class DGLDigitCapsuleLayer(nn.Module): class DGLDigitCapsuleLayer(nn.Module):
def __init__(self, in_nodes_dim=8, in_nodes=1152, out_nodes=10, out_nodes_dim=16, device='cpu'): def __init__(
self,
in_nodes_dim=8,
in_nodes=1152,
out_nodes=10,
out_nodes_dim=16,
device="cpu",
):
super(DGLDigitCapsuleLayer, self).__init__() super(DGLDigitCapsuleLayer, self).__init__()
self.device = device self.device = device
self.in_nodes_dim, self.out_nodes_dim = in_nodes_dim, out_nodes_dim self.in_nodes_dim, self.out_nodes_dim = in_nodes_dim, out_nodes_dim
self.in_nodes, self.out_nodes = in_nodes, out_nodes self.in_nodes, self.out_nodes = in_nodes, out_nodes
self.weight = nn.Parameter(torch.randn(in_nodes, out_nodes, out_nodes_dim, in_nodes_dim)) self.weight = nn.Parameter(
torch.randn(in_nodes, out_nodes, out_nodes_dim, in_nodes_dim)
)
def forward(self, x): def forward(self, x):
self.batch_size = x.size(0) self.batch_size = x.size(0)
u_hat = self.compute_uhat(x) u_hat = self.compute_uhat(x)
routing = DGLRoutingLayer(self.in_nodes, self.out_nodes, self.out_nodes_dim, batch_size=self.batch_size, routing = DGLRoutingLayer(
device=self.device) self.in_nodes,
self.out_nodes,
self.out_nodes_dim,
batch_size=self.batch_size,
device=self.device,
)
routing(u_hat, routing_num=3) routing(u_hat, routing_num=3)
out_nodes_feature = routing.g.nodes[routing.out_indx].data['v'] out_nodes_feature = routing.g.nodes[routing.out_indx].data["v"]
# shape transformation is for further classification # shape transformation is for further classification
return out_nodes_feature.transpose(0, 1).unsqueeze(1).unsqueeze(4).squeeze(1) return (
out_nodes_feature.transpose(0, 1)
.unsqueeze(1)
.unsqueeze(4)
.squeeze(1)
)
def compute_uhat(self, x): def compute_uhat(self, x):
        # x is the input vector with shape [batch_size, in_nodes_dim, in_nodes] # x is the input vector with shape [batch_size, in_nodes_dim, in_nodes]
......
import torch.nn as nn
import torch as th import torch as th
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl import dgl
class DGLRoutingLayer(nn.Module): class DGLRoutingLayer(nn.Module):
def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device='cpu'): def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device="cpu"):
super(DGLRoutingLayer, self).__init__() super(DGLRoutingLayer, self).__init__()
self.batch_size = batch_size self.batch_size = batch_size
self.g = init_graph(in_nodes, out_nodes, f_size, device=device) self.g = init_graph(in_nodes, out_nodes, f_size, device=device)
...@@ -16,49 +17,59 @@ class DGLRoutingLayer(nn.Module): ...@@ -16,49 +17,59 @@ class DGLRoutingLayer(nn.Module):
self.device = device self.device = device
def forward(self, u_hat, routing_num=1): def forward(self, u_hat, routing_num=1):
self.g.edata['u_hat'] = u_hat self.g.edata["u_hat"] = u_hat
batch_size = self.batch_size batch_size = self.batch_size
# step 2 (line 5) # step 2 (line 5)
def cap_message(edges): def cap_message(edges):
if batch_size: if batch_size:
return {'m': edges.data['c'].unsqueeze(1) * edges.data['u_hat']} return {"m": edges.data["c"].unsqueeze(1) * edges.data["u_hat"]}
else: else:
return {'m': edges.data['c'] * edges.data['u_hat']} return {"m": edges.data["c"] * edges.data["u_hat"]}
def cap_reduce(nodes): def cap_reduce(nodes):
return {'s': th.sum(nodes.mailbox['m'], dim=1)} return {"s": th.sum(nodes.mailbox["m"], dim=1)}
for r in range(routing_num): for r in range(routing_num):
# step 1 (line 4): normalize over out edges # step 1 (line 4): normalize over out edges
edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes) edges_b = self.g.edata["b"].view(self.in_nodes, self.out_nodes)
self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1) self.g.edata["c"] = F.softmax(edges_b, dim=1).view(-1, 1)
# Execute step 1 & 2 # Execute step 1 & 2
self.g.update_all(message_func=cap_message, reduce_func=cap_reduce) self.g.update_all(message_func=cap_message, reduce_func=cap_reduce)
# step 3 (line 6) # step 3 (line 6)
if self.batch_size: if self.batch_size:
self.g.nodes[self.out_indx].data['v'] = squash(self.g.nodes[self.out_indx].data['s'], dim=2) self.g.nodes[self.out_indx].data["v"] = squash(
self.g.nodes[self.out_indx].data["s"], dim=2
)
else: else:
self.g.nodes[self.out_indx].data['v'] = squash(self.g.nodes[self.out_indx].data['s'], dim=1) self.g.nodes[self.out_indx].data["v"] = squash(
self.g.nodes[self.out_indx].data["s"], dim=1
)
# step 4 (line 7) # step 4 (line 7)
v = th.cat([self.g.nodes[self.out_indx].data['v']] * self.in_nodes, dim=0) v = th.cat(
[self.g.nodes[self.out_indx].data["v"]] * self.in_nodes, dim=0
)
if self.batch_size: if self.batch_size:
self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).mean(dim=1).sum(dim=1, keepdim=True) self.g.edata["b"] = self.g.edata["b"] + (
self.g.edata["u_hat"] * v
).mean(dim=1).sum(dim=1, keepdim=True)
else: else:
self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True) self.g.edata["b"] = self.g.edata["b"] + (
self.g.edata["u_hat"] * v
).sum(dim=1, keepdim=True)
def squash(s, dim=1): def squash(s, dim=1):
sq = th.sum(s ** 2, dim=dim, keepdim=True) sq = th.sum(s**2, dim=dim, keepdim=True)
s_norm = th.sqrt(sq) s_norm = th.sqrt(sq)
s = (sq / (1.0 + sq)) * (s / s_norm) s = (sq / (1.0 + sq)) * (s / s_norm)
return s return s
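A quick sanity check of the squash nonlinearity above (toy tensor, not part of this change): it keeps each vector's direction and rescales its norm to sq / (1 + sq) < 1.

import torch as th

def squash(s, dim=1):
    sq = th.sum(s**2, dim=dim, keepdim=True)
    return (sq / (1.0 + sq)) * (s / th.sqrt(sq))

v = th.randn(4, 8)
print(squash(v).norm(dim=1))    # every norm lies in (0, 1)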
def init_graph(in_nodes, out_nodes, f_size, device='cpu'): def init_graph(in_nodes, out_nodes, f_size, device="cpu"):
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.set_n_initializer(dgl.frame.zero_initializer) g.set_n_initializer(dgl.frame.zero_initializer)
all_nodes = in_nodes + out_nodes all_nodes = in_nodes + out_nodes
...@@ -70,5 +81,5 @@ def init_graph(in_nodes, out_nodes, f_size, device='cpu'): ...@@ -70,5 +81,5 @@ def init_graph(in_nodes, out_nodes, f_size, device='cpu'):
g.add_edges(u, out_indx) g.add_edges(u, out_indx)
g = g.to(device) g = g.to(device)
g.edata['b'] = th.zeros(in_nodes * out_nodes, 1).to(device) g.edata["b"] = th.zeros(in_nodes * out_nodes, 1).to(device)
return g return g
import argparse import argparse
import torch import torch
import torch.optim as optim import torch.optim as optim
from torchvision import datasets, transforms
from model import Net from model import Net
from torchvision import datasets, transforms
def train(args, model, device, train_loader, optimizer, epoch): def train(args, model, device, train_loader, optimizer, epoch):
...@@ -16,9 +16,15 @@ def train(args, model, device, train_loader, optimizer, epoch): ...@@ -16,9 +16,15 @@ def train(args, model, device, train_loader, optimizer, epoch):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
if batch_idx % args.log_interval == 0: if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( print(
epoch, batch_idx * len(data), len(train_loader.dataset), "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
100. * batch_idx / len(train_loader), loss.item())) epoch,
batch_idx * len(data),
len(train_loader.dataset),
100.0 * batch_idx / len(train_loader),
loss.item(),
)
)
def test(args, model, device, test_loader): def test(args, model, device, test_loader):
...@@ -29,33 +35,76 @@ def test(args, model, device, test_loader): ...@@ -29,33 +35,76 @@ def test(args, model, device, test_loader):
for data, target in test_loader: for data, target in test_loader:
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
output = model(data) output = model(data)
test_loss += model.margin_loss(output, target).item() # sum up batch loss test_loss += model.margin_loss(
pred = output.norm(dim=2).squeeze().max(1, keepdim=True)[1] # get the index of the max log-probability output, target
).item() # sum up batch loss
pred = (
output.norm(dim=2).squeeze().max(1, keepdim=True)[1]
) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item() correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset) test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( print(
test_loss, correct, len(test_loader.dataset), "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
100. * correct / len(test_loader.dataset))) test_loss,
correct,
len(test_loader.dataset),
100.0 * correct / len(test_loader.dataset),
)
)
def main(): def main():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument('--batch-size', type=int, default=512, metavar='N', parser.add_argument(
help='input batch size for training (default: 64)') "--batch-size",
parser.add_argument('--test-batch-size', type=int, default=512, metavar='N', type=int,
help='input batch size for testing (default: 1000)') default=512,
parser.add_argument('--epochs', type=int, default=10, metavar='N', metavar="N",
help='number of epochs to train (default: 10)') help="input batch size for training (default: 64)",
parser.add_argument('--lr', type=float, default=0.01, metavar='LR', )
help='learning rate (default: 0.01)') parser.add_argument(
parser.add_argument('--no-cuda', action='store_true', default=False, "--test-batch-size",
help='disables CUDA training') type=int,
parser.add_argument('--seed', type=int, default=1, metavar='S', default=512,
help='random seed (default: 1)') metavar="N",
parser.add_argument('--log-interval', type=int, default=10, metavar='N', help="input batch size for testing (default: 1000)",
help='how many batches to wait before logging training status') )
parser.add_argument(
"--epochs",
type=int,
default=10,
metavar="N",
help="number of epochs to train (default: 10)",
)
parser.add_argument(
"--lr",
type=float,
default=0.01,
metavar="LR",
help="learning rate (default: 0.01)",
)
parser.add_argument(
"--no-cuda",
action="store_true",
default=False,
help="disables CUDA training",
)
parser.add_argument(
"--seed",
type=int,
default=1,
metavar="S",
help="random seed (default: 1)",
)
parser.add_argument(
"--log-interval",
type=int,
default=10,
metavar="N",
help="how many batches to wait before logging training status",
)
args = parser.parse_args() args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available() use_cuda = not args.no_cuda and torch.cuda.is_available()
...@@ -63,20 +112,38 @@ def main(): ...@@ -63,20 +112,38 @@ def main():
device = torch.device("cuda" if use_cuda else "cpu") device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True, datasets.MNIST(
transform=transforms.Compose([ "../data",
transforms.ToTensor(), train=True,
transforms.Normalize((0.1307,), (0.3081,)) download=True,
])), transform=transforms.Compose(
batch_size=args.batch_size, shuffle=True, **kwargs) [
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
),
batch_size=args.batch_size,
shuffle=True,
**kwargs
)
test_loader = torch.utils.data.DataLoader( test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([ datasets.MNIST(
transforms.ToTensor(), "../data",
transforms.Normalize((0.1307,), (0.3081,)) train=False,
])), transform=transforms.Compose(
batch_size=args.test_batch_size, shuffle=True, **kwargs) [
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs
)
model = Net(device=device).to(device) model = Net(device=device).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr) optimizer = optim.Adam(model.parameters(), lr=args.lr)
...@@ -86,5 +153,5 @@ def main(): ...@@ -86,5 +153,5 @@ def main():
test(args, model, device, test_loader) test(args, model, device, test_loader)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
import torch import torch
from torch import nn
from DGLDigitCapsule import DGLDigitCapsuleLayer from DGLDigitCapsule import DGLDigitCapsuleLayer
from DGLRoutingLayer import squash from DGLRoutingLayer import squash
from torch import nn
class Net(nn.Module): class Net(nn.Module):
def __init__(self, device='cpu'): def __init__(self, device="cpu"):
super(Net, self).__init__() super(Net, self).__init__()
self.device = device self.device = device
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, self.conv1 = nn.Sequential(
out_channels=256, nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1),
kernel_size=9, nn.ReLU(inplace=True),
stride=1), nn.ReLU(inplace=True)) )
self.primary = PrimaryCapsuleLayer(device=device) self.primary = PrimaryCapsuleLayer(device=device)
self.digits = DGLDigitCapsuleLayer(device=device) self.digits = DGLDigitCapsuleLayer(device=device)
...@@ -29,7 +28,7 @@ class Net(nn.Module): ...@@ -29,7 +28,7 @@ class Net(nn.Module):
for i in range(batch_s): for i in range(batch_s):
one_hot_vec[i, target[i]] = 1.0 one_hot_vec[i, target[i]] = 1.0
batch_size = input.size(0) batch_size = input.size(0)
v_c = torch.sqrt((input ** 2).sum(dim=2, keepdim=True)) v_c = torch.sqrt((input**2).sum(dim=2, keepdim=True))
zero = torch.zeros(1).to(self.device) zero = torch.zeros(1).to(self.device)
m_plus = 0.9 m_plus = 0.9
m_minus = 0.1 m_minus = 0.1
...@@ -43,15 +42,14 @@ class Net(nn.Module): ...@@ -43,15 +42,14 @@ class Net(nn.Module):
class PrimaryCapsuleLayer(nn.Module): class PrimaryCapsuleLayer(nn.Module):
def __init__(self, in_channel=256, num_unit=8, device="cpu"):
def __init__(self, in_channel=256, num_unit=8, device='cpu'):
super(PrimaryCapsuleLayer, self).__init__() super(PrimaryCapsuleLayer, self).__init__()
self.in_channel = in_channel self.in_channel = in_channel
self.num_unit = num_unit self.num_unit = num_unit
        self.device = device self.device = device
self.conv_units = nn.ModuleList([ self.conv_units = nn.ModuleList(
nn.Conv2d(self.in_channel, 32, 9, 2) for _ in range(self.num_unit) [nn.Conv2d(self.in_channel, 32, 9, 2) for _ in range(self.num_unit)]
]) )
def forward(self, x): def forward(self, x):
unit = [self.conv_units[i](x) for i, l in enumerate(self.conv_units)] unit = [self.conv_units[i](x) for i, l in enumerate(self.conv_units)]
......
import dgl
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F
from DGLRoutingLayer import DGLRoutingLayer from DGLRoutingLayer import DGLRoutingLayer
from torch.nn import functional as F
import dgl
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.graph_data = {} g.graph_data = {}
in_nodes = 20 in_nodes = 20
out_nodes = 10 out_nodes = 10
g.graph_data['in_nodes']=in_nodes g.graph_data["in_nodes"] = in_nodes
g.graph_data['out_nodes']=out_nodes g.graph_data["out_nodes"] = out_nodes
all_nodes = in_nodes + out_nodes all_nodes = in_nodes + out_nodes
g.add_nodes(all_nodes) g.add_nodes(all_nodes)
in_indx = list(range(in_nodes)) in_indx = list(range(in_nodes))
out_indx = list(range(in_nodes, in_nodes + out_nodes)) out_indx = list(range(in_nodes, in_nodes + out_nodes))
g.graph_data['in_indx']=in_indx g.graph_data["in_indx"] = in_indx
g.graph_data['out_indx']=out_indx g.graph_data["out_indx"] = out_indx
# add edges using edge broadcasting # add edges using edge broadcasting
for u in out_indx: for u in out_indx:
...@@ -28,17 +27,16 @@ for u in out_indx: ...@@ -28,17 +27,16 @@ for u in out_indx:
# init states # init states
f_size = 4 f_size = 4
g.ndata['v'] = th.zeros(all_nodes, f_size) g.ndata["v"] = th.zeros(all_nodes, f_size)
g.edata['u_hat'] = th.randn(in_nodes * out_nodes, f_size) g.edata["u_hat"] = th.randn(in_nodes * out_nodes, f_size)
g.edata['b'] = th.randn(in_nodes * out_nodes, 1) g.edata["b"] = th.randn(in_nodes * out_nodes, 1)
routing_layer = DGLRoutingLayer(g) routing_layer = DGLRoutingLayer(g)
entropy_list=[] entropy_list = []
for i in range(15): for i in range(15):
routing_layer() routing_layer()
dist_matrix = g.edata['c'].view(in_nodes, out_nodes) dist_matrix = g.edata["c"].view(in_nodes, out_nodes)
entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=0) entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=0)
entropy_list.append(entropy.data.numpy()) entropy_list.append(entropy.data.numpy())
std = dist_matrix.std(dim=0) std = dist_matrix.std(dim=0)
import argparse

import torch as th
import torch.optim as optim
from model import CAREGNN
from sklearn.metrics import recall_score, roc_auc_score
from torch.nn.functional import softmax
from utils import EarlyStopping

import dgl


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
@@ -18,45 +19,51 @@ def main(args):
    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # retrieve labels of ground truth
    labels = graph.ndata["label"].to(device)

    # Extract node features
    feat = graph.ndata["feature"].to(device)

    # retrieve masks for train/validation/test
    train_mask = graph.ndata["train_mask"]
    val_mask = graph.ndata["val_mask"]
    test_mask = graph.ndata["test_mask"]
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)

    # Reinforcement learning module only for positive training nodes
    rl_idx = th.nonzero(
        train_mask.to(device) & labels.bool(), as_tuple=False
    ).squeeze(1)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = CAREGNN(
        in_dim=feat.shape[-1],
        num_classes=num_classes,
        hid_dim=args.hid_dim,
        num_layers=args.num_layers,
        activation=th.tanh,
        step_size=args.step_size,
        edges=graph.canonical_etypes,
    )
    model = model.to(device)

    # Step 3: Create training components ===================================================== #
    _, cnt = th.unique(labels, return_counts=True)
    loss_fn = th.nn.CrossEntropyLoss(weight=1 / cnt)
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    if args.early_stop:
        stopper = EarlyStopping(patience=100)
@@ -67,17 +74,30 @@ def main(args):
        logits_gnn, logits_sim = model(graph, feat)

        # compute loss
        tr_loss = loss_fn(
            logits_gnn[train_idx], labels[train_idx]
        ) + args.sim_weight * loss_fn(logits_sim[train_idx], labels[train_idx])
        tr_recall = recall_score(
            labels[train_idx].cpu(),
            logits_gnn.data[train_idx].argmax(dim=1).cpu(),
        )
        tr_auc = roc_auc_score(
            labels[train_idx].cpu(),
            softmax(logits_gnn, dim=1).data[train_idx][:, 1].cpu(),
        )

        # validation
        val_loss = loss_fn(
            logits_gnn[val_idx], labels[val_idx]
        ) + args.sim_weight * loss_fn(logits_sim[val_idx], labels[val_idx])
        val_recall = recall_score(
            labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu()
        )
        val_auc = roc_auc_score(
            labels[val_idx].cpu(),
            softmax(logits_gnn, dim=1).data[val_idx][:, 1].cpu(),
        )

        # backward
        optimizer.zero_grad()
@@ -85,8 +105,17 @@ def main(args):
        optimizer.step()

        # Print out performance
        print(
            "Epoch {}, Train: Recall: {:.4f} AUC: {:.4f} Loss: {:.4f} | Val: Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}".format(
                epoch,
                tr_recall,
                tr_auc,
                tr_loss.item(),
                val_recall,
                val_auc,
                val_loss.item(),
            )
        )

        # Adjust p value with reinforcement learning module
        model.RLModule(graph, epoch, rl_idx)
@@ -98,32 +127,80 @@ def main(args):
    # Test after all epoch
    model.eval()
    if args.early_stop:
        model.load_state_dict(th.load("es_checkpoint.pt"))

    # forward
    logits_gnn, logits_sim = model.forward(graph, feat)

    # compute loss
    test_loss = loss_fn(
        logits_gnn[test_idx], labels[test_idx]
    ) + args.sim_weight * loss_fn(logits_sim[test_idx], labels[test_idx])
    test_recall = recall_score(
        labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu()
    )
    test_auc = roc_auc_score(
        labels[test_idx].cpu(),
        softmax(logits_gnn, dim=1).data[test_idx][:, 1].cpu(),
    )

    print(
        "Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}".format(
            test_recall, test_auc, test_loss.item()
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GCN-based Anti-Spam Model")
    parser.add_argument(
        "--dataset",
        type=str,
        default="amazon",
        help="DGL dataset for this model (yelp, or amazon)",
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    parser.add_argument(
        "--hid_dim", type=int, default=64, help="Hidden layer dimension"
    )
    parser.add_argument(
        "--num_layers", type=int, default=1, help="Number of layers"
    )
    parser.add_argument(
        "--max_epoch",
        type=int,
        default=30,
        help="The max number of epochs. Default: 30",
    )
    parser.add_argument(
        "--lr", type=float, default=0.01, help="Learning rate. Default: 0.01"
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.001,
        help="Weight decay. Default: 0.001",
    )
    parser.add_argument(
        "--step_size",
        type=float,
        default=0.02,
        help="RL action step size (lambda 2). Default: 0.02",
    )
    parser.add_argument(
        "--sim_weight",
        type=float,
        default=2,
        help="Similarity loss weight (lambda 1). Default: 2",
    )
    parser.add_argument(
        "--early-stop",
        action="store_true",
        default=False,
        help="indicates whether to use early stop",
    )

    args = parser.parse_args()
    print(args)
...
import argparse

import torch as th
import torch.optim as optim
from sklearn.metrics import recall_score, roc_auc_score
from torch.nn.functional import softmax
from utils import EarlyStopping
from model_sampling import CAREGNN, CARESampler, _l1_dist

import dgl


def evaluate(model, loss_fn, dataloader, device="cpu"):
    loss = 0
    auc = 0
    recall = 0
    num_blocks = 0
    for input_nodes, output_nodes, blocks in dataloader:
        blocks = [b.to(device) for b in blocks]
        feature = blocks[0].srcdata["feature"]
        label = blocks[-1].dstdata["label"]
        logits_gnn, logits_sim = model(blocks, feature)

        # compute loss
        loss += (
            loss_fn(logits_gnn, label).item()
            + args.sim_weight * loss_fn(logits_sim, label).item()
        )
        recall += recall_score(
            label.cpu(), logits_gnn.argmax(dim=1).detach().cpu()
        )
        auc += roc_auc_score(
            label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu()
        )
        num_blocks += 1

    return recall / num_blocks, auc / num_blocks, loss / num_blocks
@@ -38,47 +46,53 @@ def main(args):
    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
        args.num_workers = 0
    else:
        device = "cpu"

    # retrieve labels of ground truth
    labels = graph.ndata["label"].to(device)

    # Extract node features
    feat = graph.ndata["feature"].to(device)
    layers_feat = feat.expand(args.num_layers, -1, -1)

    # retrieve masks for train/validation/test
    train_mask = graph.ndata["train_mask"]
    val_mask = graph.ndata["val_mask"]
    test_mask = graph.ndata["test_mask"]
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)

    # Reinforcement learning module only for positive training nodes
    rl_idx = th.nonzero(
        train_mask.to(device) & labels.bool(), as_tuple=False
    ).squeeze(1)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = CAREGNN(
        in_dim=feat.shape[-1],
        num_classes=num_classes,
        hid_dim=args.hid_dim,
        num_layers=args.num_layers,
        activation=th.tanh,
        step_size=args.step_size,
        edges=graph.canonical_etypes,
    )
    model = model.to(device)

    # Step 3: Create training components ===================================================== #
    _, cnt = th.unique(labels, return_counts=True)
    loss_fn = th.nn.CrossEntropyLoss(weight=1 / cnt)
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    if args.early_stop:
        stopper = EarlyStopping(patience=100)
@@ -89,13 +103,13 @@ def main(args):
        p = []
        for i in range(args.num_layers):
            dist = {}
            graph.ndata["nd"] = th.tanh(model.layers[i].MLP(layers_feat[i]))
            for etype in graph.canonical_etypes:
                graph.apply_edges(_l1_dist, etype=etype)
                dist[etype] = graph.edges[etype].data.pop("ed").detach().cpu()
            dists.append(dist)
            p.append(model.layers[i].p)
            graph.ndata.pop("nd")

        sampler = CARESampler(p, dists, args.num_layers)

        # train
@@ -105,20 +119,33 @@ def main(args):
        tr_auc = 0
        tr_blk = 0
        train_dataloader = dgl.dataloading.DataLoader(
            graph,
            train_idx,
            sampler,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=args.num_workers,
        )

        for input_nodes, output_nodes, blocks in train_dataloader:
            blocks = [b.to(device) for b in blocks]
            train_feature = blocks[0].srcdata["feature"]
            train_label = blocks[-1].dstdata["label"]
            logits_gnn, logits_sim = model(blocks, train_feature)

            # compute loss
            blk_loss = loss_fn(
                logits_gnn, train_label
            ) + args.sim_weight * loss_fn(logits_sim, train_label)
            tr_loss += blk_loss.item()
            tr_recall += recall_score(
                train_label.cpu(), logits_gnn.argmax(dim=1).detach().cpu()
            )
            tr_auc += roc_auc_score(
                train_label.cpu(),
                softmax(logits_gnn, dim=1)[:, 1].detach().cpu(),
            )
            tr_blk += 1

            # backward
@@ -132,15 +159,32 @@ def main(args):
        # validation
        model.eval()
        val_dataloader = dgl.dataloading.DataLoader(
            graph,
            val_idx,
            sampler,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=args.num_workers,
        )
        val_recall, val_auc, val_loss = evaluate(
            model, loss_fn, val_dataloader, device
        )

        # Print out performance
        print(
            "In epoch {}, Train Recall: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; "
            "Valid Recall: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}".format(
                epoch,
                tr_recall / tr_blk,
                tr_auc / tr_blk,
                tr_loss / tr_blk,
                val_recall,
                val_auc,
                val_loss,
            )
        )

        if args.early_stop:
            if stopper.step(val_auc, model):
@@ -149,30 +193,84 @@ def main(args):
    # Test with mini batch after all epoch
    model.eval()
    if args.early_stop:
        model.load_state_dict(th.load("es_checkpoint.pt"))
    test_dataloader = dgl.dataloading.DataLoader(
        graph,
        test_idx,
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
    )

    test_recall, test_auc, test_loss = evaluate(
        model, loss_fn, test_dataloader, device
    )

    print(
        "Test Recall: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}".format(
            test_recall, test_auc, test_loss
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GCN-based Anti-Spam Model")
    parser.add_argument(
        "--dataset",
        type=str,
        default="amazon",
        help="DGL dataset for this model (yelp, or amazon)",
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    parser.add_argument(
        "--hid_dim", type=int, default=64, help="Hidden layer dimension"
    )
    parser.add_argument(
        "--num_layers", type=int, default=1, help="Number of layers"
    )
    parser.add_argument(
        "--batch_size", type=int, default=256, help="Size of mini-batch"
    )
    parser.add_argument(
        "--max_epoch",
        type=int,
        default=30,
        help="The max number of epochs. Default: 30",
    )
    parser.add_argument(
        "--lr", type=float, default=0.01, help="Learning rate. Default: 0.01"
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.001,
        help="Weight decay. Default: 0.001",
    )
    parser.add_argument(
        "--step_size",
        type=float,
        default=0.02,
        help="RL action step size (lambda 2). Default: 0.02",
    )
    parser.add_argument(
        "--sim_weight",
        type=float,
        default=2,
        help="Similarity loss weight (lambda 1). Default: 2",
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of node dataloader"
    )
    parser.add_argument(
        "--early-stop",
        action="store_true",
        default=False,
        help="indicates whether to use early stop",
    )

    args = parser.parse_args()
    th.manual_seed(717)
...
import numpy as np
import torch as th
import torch.nn as nn

import dgl.function as fn


class CAREConv(nn.Module):
    """One layer of CARE-GNN."""

    def __init__(
        self,
        in_dim,
        out_dim,
        num_classes,
        edges,
        activation=None,
        step_size=0.02,
    ):
        super(CAREConv, self).__init__()

        self.activation = activation
@@ -33,20 +42,27 @@ class CAREConv(nn.Module):
    def _calc_distance(self, edges):
        # formula 2
        d = th.norm(
            th.tanh(self.MLP(edges.src["h"]))
            - th.tanh(self.MLP(edges.dst["h"])),
            1,
            1,
        )
        return {"d": d}

    def _top_p_sampling(self, g, p):
        # this implementation is inefficient
        # optimization requires dgl.sampling.select_top_p requested in issue #3100
        dist = g.edata["d"]
        neigh_list = []
        for node in g.nodes():
            edges = g.in_edges(node, form="eid")
            num_neigh = th.ceil(g.in_degrees(node) * p).int().item()
            neigh_dist = dist[edges]
            if neigh_dist.shape[0] > num_neigh:
                neigh_index = np.argpartition(
                    neigh_dist.cpu().detach(), num_neigh
                )[:num_neigh]
            else:
                neigh_index = np.arange(num_neigh)
            neigh_list.append(edges[neigh_index])
@@ -54,22 +70,29 @@ class CAREConv(nn.Module):
    def forward(self, g, feat):
        with g.local_scope():
            g.ndata["h"] = feat

            hr = {}
            for i, etype in enumerate(g.canonical_etypes):
                g.apply_edges(self._calc_distance, etype=etype)
                self.dist[etype] = g.edges[etype].data["d"]
                sampled_edges = self._top_p_sampling(g[etype], self.p[etype])

                # formula 8
                g.send_and_recv(
                    sampled_edges,
                    fn.copy_u("h", "m"),
                    fn.mean("m", "h_%s" % etype[1]),
                    etype=etype,
                )
                hr[etype] = g.ndata["h_%s" % etype[1]]
                if self.activation is not None:
                    hr[etype] = self.activation(hr[etype])

            # formula 9 using mean as inter-relation aggregator
            p_tensor = (
                th.Tensor(list(self.p.values())).view(-1, 1, 1).to(g.device)
            )
            h_homo = th.sum(th.stack(list(hr.values())) * p_tensor, dim=0)
            h_homo += feat
            if self.activation is not None:
@@ -79,14 +102,16 @@ class CAREConv(nn.Module):


class CAREGNN(nn.Module):
    def __init__(
        self,
        in_dim,
        num_classes,
        hid_dim=64,
        edges=None,
        num_layers=2,
        activation=None,
        step_size=0.02,
    ):
        super(CAREGNN, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
@@ -100,38 +125,54 @@ class CAREGNN(nn.Module):
        if self.num_layers == 1:
            # Single layer
            self.layers.append(
                CAREConv(
                    self.in_dim,
                    self.num_classes,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )
        else:
            # Input layer
            self.layers.append(
                CAREConv(
                    self.in_dim,
                    self.hid_dim,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

            # Hidden layers with n - 2 layers
            for i in range(self.num_layers - 2):
                self.layers.append(
                    CAREConv(
                        self.hid_dim,
                        self.hid_dim,
                        self.num_classes,
                        self.edges,
                        activation=self.activation,
                        step_size=self.step_size,
                    )
                )

            # Output layer
            self.layers.append(
                CAREConv(
                    self.hid_dim,
                    self.num_classes,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

    def forward(self, graph, feat):
        # For full graph training, directly use the graph
@@ -149,7 +190,7 @@ class CAREGNN(nn.Module):
            for etype in self.edges:
                if not layer.cvg[etype]:
                    # formula 5
                    eid = graph.in_edges(idx, form="eid", etype=etype)
                    avg_dist = th.mean(layer.dist[etype][eid])

                    # formula 6
...
import numpy as np
import torch as th
import torch.nn as nn

import dgl
import dgl.function as fn


def _l1_dist(edges):
    # formula 2
    ed = th.norm(edges.src["nd"] - edges.dst["nd"], 1, 1)
    return {"ed": ed}


class CARESampler(dgl.dataloading.BlockSampler):
@@ -25,11 +26,20 @@ class CARESampler(dgl.dataloading.BlockSampler):
            edge_mask = th.zeros(g.number_of_edges(etype))
            # extract each node from dict because of single node type
            for node in seed_nodes:
                edges = g.in_edges(node, form="eid", etype=etype)
                num_neigh = (
                    th.ceil(
                        g.in_degrees(node, etype=etype)
                        * self.p[block_id][etype]
                    )
                    .int()
                    .item()
                )
                neigh_dist = self.dists[block_id][etype][edges]
                if neigh_dist.shape[0] > num_neigh:
                    neigh_index = np.argpartition(neigh_dist, num_neigh)[
                        :num_neigh
                    ]
                else:
                    neigh_index = np.arange(num_neigh)
                edge_mask[edges[neigh_index]] = 1
@@ -57,7 +67,15 @@ class CARESampler(dgl.dataloading.BlockSampler):


class CAREConv(nn.Module):
    """One layer of CARE-GNN."""

    def __init__(
        self,
        in_dim,
        out_dim,
        num_classes,
        edges,
        activation=None,
        step_size=0.02,
    ):
        super(CAREConv, self).__init__()

        self.activation = activation
@@ -82,20 +100,22 @@ class CAREConv(nn.Module):
            self.cvg[etype] = False

    def forward(self, g, feat):
        g.srcdata["h"] = feat

        # formula 8
        hr = {}
        for etype in g.canonical_etypes:
            g.update_all(fn.copy_u("h", "m"), fn.mean("m", "hr"), etype=etype)
            hr[etype] = g.dstdata["hr"]
            if self.activation is not None:
                hr[etype] = self.activation(hr[etype])

        # formula 9 using mean as inter-relation aggregator
        p_tensor = (
            th.Tensor(list(self.p.values())).view(-1, 1, 1).to(feat.device)
        )
        h_homo = th.sum(th.stack(list(hr.values())) * p_tensor, dim=0)
        h_homo += feat[: g.number_of_dst_nodes()]
        if self.activation is not None:
            h_homo = self.activation(h_homo)
@@ -103,14 +123,16 @@ class CAREConv(nn.Module):


class CAREGNN(nn.Module):
    def __init__(
        self,
        in_dim,
        num_classes,
        hid_dim=64,
        edges=None,
        num_layers=2,
        activation=None,
        step_size=0.02,
    ):
        super(CAREGNN, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
@@ -124,42 +146,58 @@ class CAREGNN(nn.Module):
        if self.num_layers == 1:
            # Single layer
            self.layers.append(
                CAREConv(
                    self.in_dim,
                    self.num_classes,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )
        else:
            # Input layer
            self.layers.append(
                CAREConv(
                    self.in_dim,
                    self.hid_dim,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

            # Hidden layers with n - 2 layers
            for i in range(self.num_layers - 2):
                self.layers.append(
                    CAREConv(
                        self.hid_dim,
                        self.hid_dim,
                        self.num_classes,
                        self.edges,
                        activation=self.activation,
                        step_size=self.step_size,
                    )
                )

            # Output layer
            self.layers.append(
                CAREConv(
                    self.hid_dim,
                    self.num_classes,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

    def forward(self, blocks, feat):
        # formula 4
        sim = th.tanh(self.layers[0].MLP(blocks[-1].dstdata["feature"].float()))

        # Forward of n layers of CARE-GNN
        for block, layer in zip(blocks, self.layers):
@@ -171,7 +209,7 @@ class CAREGNN(nn.Module):
            for etype in self.edges:
                if not layer.cvg[etype]:
                    # formula 5
                    eid = graph.in_edges(idx, form="eid", etype=etype)
                    avg_dist = th.mean(dists[i][etype][eid])

                    # formula 6
...
@@ -18,7 +18,9 @@ class EarlyStopping:
            self.save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
@@ -28,5 +30,5 @@ class EarlyStopping:
        return self.early_stop

    def save_checkpoint(self, model):
        """Saves model when validation loss decreases."""
        torch.save(model.state_dict(), "es_checkpoint.pt")