Unverified Commit 704bcaf6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files
parent 6bc82161
......@@ -2,6 +2,8 @@ import argparse
import time
import traceback
import dgl
import networkx as nx
import numpy as np
import torch
......@@ -11,12 +13,10 @@ from dataloader import (
MultiBodyTrainDataset,
MultiBodyValidDataset,
)
from models import MLP, InteractionNet, PrepareLayer
from models import InteractionNet, MLP, PrepareLayer
from torch.utils.data import DataLoader
from utils import make_video
import dgl
def train(
optimizer, loss_fn, reg_fn, model, prep, dataloader, lambda_reg, device
......
import torch
from modules import MSA, BiLSTM, GraphTrans
from modules import BiLSTM, GraphTrans, MSA
from torch import nn
from utlis import *
......
......@@ -2,11 +2,11 @@ import json
import pickle
import random
import dgl
import numpy as np
import torch
import dgl
NODE_TYPE = {"entity": 0, "root": 1, "relation": 2}
......
from typing import Optional
import dgl
import torch
import torch.nn
from torch import Tensor
import dgl
from dgl import DGLGraph
from dgl.nn import GraphConv
from torch import Tensor
class GraphConvWithDropout(GraphConv):
......
......@@ -3,18 +3,18 @@ import os
from datetime import datetime
from time import time
import dgl
import torch
import torch.nn.functional as F
from data_preprocess import degree_as_feature, node_label_as_feature
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from networks import GraphClassifier
from torch import Tensor
from torch.utils.data import random_split
from utils import get_stats, parse_args
import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
def compute_loss(
cls_logits: Tensor,
......
......@@ -3,18 +3,18 @@ import os
from datetime import datetime
from time import time
import dgl
import torch
import torch.nn.functional as F
from data_preprocess import degree_as_feature, node_label_as_feature
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from networks import GraphClassifier
from torch import Tensor
from torch.utils.data import random_split
from utils import get_stats, parse_args
import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
def compute_loss(
cls_logits: Tensor,
......
from typing import List, Tuple, Union
from layers import *
import dgl.function as fn
import torch
import torch.nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch.glob import SortPooling
......@@ -35,9 +35,18 @@ class GraphCrossModule(torch.nn.Module):
The weight parameter used at the end of GXN for channel fusion.
Default: :obj:`1.0`
"""
def __init__(self, pool_ratios:Union[float, List[float]], in_dim:int,
out_dim:int, hidden_dim:int, cross_weight:float=1.,
fuse_weight:float=1., dist:int=1, num_cross_layers:int=2):
def __init__(
self,
pool_ratios: Union[float, List[float]],
in_dim: int,
out_dim: int,
hidden_dim: int,
cross_weight: float = 1.0,
fuse_weight: float = 1.0,
dist: int = 1,
num_cross_layers: int = 2,
):
super(GraphCrossModule, self).__init__()
if isinstance(pool_ratios, float):
pool_ratios = (pool_ratios, pool_ratios)
......@@ -50,8 +59,12 @@ class GraphCrossModule(torch.nn.Module):
self.start_gcn_scale2 = GraphConvWithDropout(hidden_dim, hidden_dim)
self.end_gcn = GraphConvWithDropout(2 * hidden_dim, out_dim)
self.index_select_scale1 = IndexSelect(pool_ratios[0], hidden_dim, act="prelu", dist=dist)
self.index_select_scale2 = IndexSelect(pool_ratios[1], hidden_dim, act="prelu", dist=dist)
self.index_select_scale1 = IndexSelect(
pool_ratios[0], hidden_dim, act="prelu", dist=dist
)
self.index_select_scale2 = IndexSelect(
pool_ratios[1], hidden_dim, act="prelu", dist=dist
)
self.start_pool_s12 = GraphPool(hidden_dim)
self.start_pool_s23 = GraphPool(hidden_dim)
self.end_unpool_s21 = GraphUnpool(hidden_dim)
......@@ -85,21 +98,45 @@ class GraphCrossModule(torch.nn.Module):
graph_scale1 = graph
feat_scale1 = self.start_gcn_scale1(graph_scale1, feat)
feat_origin = feat_scale1
feat_scale1_neg = feat_scale1[torch.randperm(feat_scale1.size(0))] # negative samples
logit_s1, scores_s1, select_idx_s1, non_select_idx_s1, feat_down_s1 = \
self.index_select_scale1(graph_scale1, feat_scale1, feat_scale1_neg)
feat_scale2, graph_scale2 = self.start_pool_s12(graph_scale1, feat_scale1,
select_idx_s1, non_select_idx_s1,
scores_s1, pool_graph=True)
feat_scale1_neg = feat_scale1[
torch.randperm(feat_scale1.size(0))
] # negative samples
(
logit_s1,
scores_s1,
select_idx_s1,
non_select_idx_s1,
feat_down_s1,
) = self.index_select_scale1(graph_scale1, feat_scale1, feat_scale1_neg)
feat_scale2, graph_scale2 = self.start_pool_s12(
graph_scale1,
feat_scale1,
select_idx_s1,
non_select_idx_s1,
scores_s1,
pool_graph=True,
)
# start of scale-2
feat_scale2 = self.start_gcn_scale2(graph_scale2, feat_scale2)
feat_scale2_neg = feat_scale2[torch.randperm(feat_scale2.size(0))] # negative samples
logit_s2, scores_s2, select_idx_s2, non_select_idx_s2, feat_down_s2 = \
self.index_select_scale2(graph_scale2, feat_scale2, feat_scale2_neg)
feat_scale3, graph_scale3 = self.start_pool_s23(graph_scale2, feat_scale2,
select_idx_s2, non_select_idx_s2,
scores_s2, pool_graph=True)
feat_scale2_neg = feat_scale2[
torch.randperm(feat_scale2.size(0))
] # negative samples
(
logit_s2,
scores_s2,
select_idx_s2,
non_select_idx_s2,
feat_down_s2,
) = self.index_select_scale2(graph_scale2, feat_scale2, feat_scale2_neg)
feat_scale3, graph_scale3 = self.start_pool_s23(
graph_scale2,
feat_scale2,
select_idx_s2,
non_select_idx_s2,
scores_s2,
pool_graph=True,
)
# layer-1
res_s1_0, res_s2_0, res_s3_0 = feat_scale1, feat_scale2, feat_scale3
......@@ -109,18 +146,38 @@ class GraphCrossModule(torch.nn.Module):
feat_scale3 = F.relu(self.s3_l1_gcn(graph_scale3, feat_scale3))
if self.num_cross_layers >= 1:
feat_s12_fu = self.pool_s12_1(graph_scale1, feat_scale1,
select_idx_s1, non_select_idx_s1,
scores_s1)
feat_s21_fu = self.unpool_s21_1(graph_scale1, feat_scale2, select_idx_s1)
feat_s23_fu = self.pool_s23_1(graph_scale2, feat_scale2,
select_idx_s2, non_select_idx_s2,
scores_s2)
feat_s32_fu = self.unpool_s32_1(graph_scale2, feat_scale3, select_idx_s2)
feat_scale1 = feat_scale1 + self.cross_weight * feat_s21_fu + res_s1_0
feat_scale2 = feat_scale2 + self.cross_weight * (feat_s12_fu + feat_s32_fu) / 2 + res_s2_0
feat_scale3 = feat_scale3 + self.cross_weight * feat_s23_fu + res_s3_0
feat_s12_fu = self.pool_s12_1(
graph_scale1,
feat_scale1,
select_idx_s1,
non_select_idx_s1,
scores_s1,
)
feat_s21_fu = self.unpool_s21_1(
graph_scale1, feat_scale2, select_idx_s1
)
feat_s23_fu = self.pool_s23_1(
graph_scale2,
feat_scale2,
select_idx_s2,
non_select_idx_s2,
scores_s2,
)
feat_s32_fu = self.unpool_s32_1(
graph_scale2, feat_scale3, select_idx_s2
)
feat_scale1 = (
feat_scale1 + self.cross_weight * feat_s21_fu + res_s1_0
)
feat_scale2 = (
feat_scale2
+ self.cross_weight * (feat_s12_fu + feat_s32_fu) / 2
+ res_s2_0
)
feat_scale3 = (
feat_scale3 + self.cross_weight * feat_s23_fu + res_s3_0
)
# layer-2
feat_scale1 = F.relu(self.s1_l2_gcn(graph_scale1, feat_scale1))
......@@ -128,18 +185,32 @@ class GraphCrossModule(torch.nn.Module):
feat_scale3 = F.relu(self.s3_l2_gcn(graph_scale3, feat_scale3))
if self.num_cross_layers >= 2:
feat_s12_fu = self.pool_s12_2(graph_scale1, feat_scale1,
select_idx_s1, non_select_idx_s1,
scores_s1)
feat_s21_fu = self.unpool_s21_2(graph_scale1, feat_scale2, select_idx_s1)
feat_s23_fu = self.pool_s23_2(graph_scale2, feat_scale2,
select_idx_s2, non_select_idx_s2,
scores_s2)
feat_s32_fu = self.unpool_s32_2(graph_scale2, feat_scale3, select_idx_s2)
feat_s12_fu = self.pool_s12_2(
graph_scale1,
feat_scale1,
select_idx_s1,
non_select_idx_s1,
scores_s1,
)
feat_s21_fu = self.unpool_s21_2(
graph_scale1, feat_scale2, select_idx_s1
)
feat_s23_fu = self.pool_s23_2(
graph_scale2,
feat_scale2,
select_idx_s2,
non_select_idx_s2,
scores_s2,
)
feat_s32_fu = self.unpool_s32_2(
graph_scale2, feat_scale3, select_idx_s2
)
cross_weight = self.cross_weight * 0.05
feat_scale1 = feat_scale1 + cross_weight * feat_s21_fu
feat_scale2 = feat_scale2 + cross_weight * (feat_s12_fu + feat_s32_fu) / 2
feat_scale2 = (
feat_scale2 + cross_weight * (feat_s12_fu + feat_s32_fu) / 2
)
feat_scale3 = feat_scale3 + cross_weight * feat_s23_fu
# layer-3
......@@ -148,9 +219,18 @@ class GraphCrossModule(torch.nn.Module):
feat_scale3 = F.relu(self.s3_l3_gcn(graph_scale3, feat_scale3))
# final layers
feat_s3_out = self.end_unpool_s32(graph_scale2, feat_scale3, select_idx_s2) + feat_down_s2
feat_s2_out = self.end_unpool_s21(graph_scale1, feat_scale2 + feat_s3_out, select_idx_s1)
feat_agg = feat_scale1 + self.fuse_weight * feat_s2_out + self.fuse_weight * feat_down_s1
feat_s3_out = (
self.end_unpool_s32(graph_scale2, feat_scale3, select_idx_s2)
+ feat_down_s2
)
feat_s2_out = self.end_unpool_s21(
graph_scale1, feat_scale2 + feat_s3_out, select_idx_s1
)
feat_agg = (
feat_scale1
+ self.fuse_weight * feat_s2_out
+ self.fuse_weight * feat_down_s1
)
feat_agg = torch.cat((feat_agg, feat_origin), dim=1)
feat_agg = self.end_gcn(graph_scale1, feat_agg)
......@@ -198,11 +278,21 @@ class GraphCrossNet(torch.nn.Module):
The weight parameter used at the end of GXN for channel fusion.
Default: :obj:`1.0`
"""
def __init__(self, in_dim:int, out_dim:int, edge_feat_dim:int=0,
hidden_dim:int=96, pool_ratios:Union[List[float], float]=[0.9, 0.7],
readout_nodes:int=30, conv1d_dims:List[int]=[16, 32],
conv1d_kws:List[int]=[5],
cross_weight:float=1., fuse_weight:float=1., dist:int=1):
def __init__(
self,
in_dim: int,
out_dim: int,
edge_feat_dim: int = 0,
hidden_dim: int = 96,
pool_ratios: Union[List[float], float] = [0.9, 0.7],
readout_nodes: int = 30,
conv1d_dims: List[int] = [16, 32],
conv1d_kws: List[int] = [5],
cross_weight: float = 1.0,
fuse_weight: float = 1.0,
dist: int = 1,
):
super(GraphCrossNet, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
......@@ -217,20 +307,29 @@ class GraphCrossNet(torch.nn.Module):
else:
self.e2l_lin = None
self.gxn = GraphCrossModule(pool_ratios, in_dim=self.in_dim, out_dim=hidden_dim,
hidden_dim=hidden_dim//2, cross_weight=cross_weight,
fuse_weight=fuse_weight, dist=dist)
self.gxn = GraphCrossModule(
pool_ratios,
in_dim=self.in_dim,
out_dim=hidden_dim,
hidden_dim=hidden_dim // 2,
cross_weight=cross_weight,
fuse_weight=fuse_weight,
dist=dist,
)
self.sortpool = SortPooling(readout_nodes)
# final updates
self.final_conv1 = torch.nn.Conv1d(1, conv1d_dims[0],
kernel_size=conv1d_kws[0],
stride=conv1d_kws[0])
self.final_conv1 = torch.nn.Conv1d(
1, conv1d_dims[0], kernel_size=conv1d_kws[0], stride=conv1d_kws[0]
)
self.final_maxpool = torch.nn.MaxPool1d(2, 2)
self.final_conv2 = torch.nn.Conv1d(conv1d_dims[0], conv1d_dims[1],
kernel_size=conv1d_kws[1], stride=1)
self.final_conv2 = torch.nn.Conv1d(
conv1d_dims[0], conv1d_dims[1], kernel_size=conv1d_kws[1], stride=1
)
self.final_dense_dim = int((readout_nodes - 2) / 2 + 1)
self.final_dense_dim = (self.final_dense_dim - conv1d_kws[1] + 1) * conv1d_dims[1]
self.final_dense_dim = (
self.final_dense_dim - conv1d_kws[1] + 1
) * conv1d_dims[1]
if self.out_dim > 0:
self.out_lin = torch.nn.Linear(self.final_dense_dim, out_dim)
......@@ -245,7 +344,12 @@ class GraphCrossNet(torch.nn.Module):
if self.out_dim > 0:
torch.nn.init.xavier_normal_(self.out_lin.weight)
def forward(self, graph:DGLGraph, node_feat:Tensor, edge_feat:Optional[Tensor]=None):
def forward(
self,
graph: DGLGraph,
node_feat: Tensor,
edge_feat: Optional[Tensor] = None,
):
num_batch = graph.batch_size
if edge_feat is not None:
edge_feat = self.e2l_lin(edge_feat)
......@@ -280,9 +384,11 @@ class GraphClassifier(torch.nn.Module):
Graph Classifier for graph classification.
GXN + MLP
"""
def __init__(self, args):
super(GraphClassifier, self).__init__()
self.gxn = GraphCrossNet(in_dim=args.in_dim,
self.gxn = GraphCrossNet(
in_dim=args.in_dim,
out_dim=args.embed_dim,
edge_feat_dim=args.edge_feat_dim,
hidden_dim=args.hidden_dim,
......@@ -291,12 +397,18 @@ class GraphClassifier(torch.nn.Module):
conv1d_dims=args.conv1d_dims,
conv1d_kws=args.conv1d_kws,
cross_weight=args.cross_weight,
fuse_weight=args.fuse_weight)
fuse_weight=args.fuse_weight,
)
self.lin1 = torch.nn.Linear(args.embed_dim, args.final_dense_hidden_dim)
self.lin2 = torch.nn.Linear(args.final_dense_hidden_dim, args.out_dim)
self.dropout = args.dropout
def forward(self, graph:DGLGraph, node_feat:Tensor, edge_feat:Optional[Tensor]=None):
def forward(
self,
graph: DGLGraph,
node_feat: Tensor,
edge_feat: Optional[Tensor] = None,
):
embed, logits1, logits2 = self.gxn(graph, node_feat, edge_feat)
logits = F.relu(self.lin1(embed))
if self.dropout > 0:
......
......@@ -7,11 +7,10 @@ constructed another dataset from ACM with a different set of papers, connections
labels.
"""
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GATConv
......
......@@ -6,19 +6,19 @@ so we sampled twice as many neighbors during val/test than training.
"""
import argparse
import dgl
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
from dgl.sampling import RandomWalkNeighborSampler
from model_hetero import SemanticAttention
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from utils import EarlyStopping, set_random_seed
import dgl
from dgl.nn.pytorch import GATConv
from dgl.sampling import RandomWalkNeighborSampler
class HANLayer(torch.nn.Module):
"""
......
......@@ -5,13 +5,12 @@ import pickle
import random
from pprint import pprint
import dgl
import numpy as np
import torch
from scipy import io as sio
from scipy import sparse
import dgl
from dgl.data.utils import _get_dgl_url, download, get_download_dir
from scipy import io as sio, sparse
def set_random_seed(seed=0):
......
......@@ -7,12 +7,12 @@ Paper: https://arxiv.org/abs/1907.04652
from functools import partial
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.base import DGLError
from dgl.nn.pytorch import edge_softmax
from dgl.nn.pytorch.utils import Identity
......
......@@ -8,19 +8,19 @@ Paper: https://arxiv.org/abs/1907.04652
import argparse
import time
import dgl
import numpy as np
import torch
import torch.nn.functional as F
from hgao import HardGAT
from utils import EarlyStopping
import dgl
from dgl.data import (
CiteseerGraphDataset,
CoraGraphDataset,
PubmedGraphDataset,
register_data_args,
)
from hgao import HardGAT
from utils import EarlyStopping
def accuracy(logits, labels):
......@@ -161,7 +161,6 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GAT")
register_data_args(parser)
parser.add_argument(
......
......@@ -7,15 +7,14 @@ for detailed description.
Here we implement a graph-edge version of sparsemax where we perform sparsemax for all edges
with the same node as end-node in graphs.
"""
import torch
from torch import Tensor
from torch.autograd import Function
import dgl
import torch
from dgl.backend import astype
from dgl.base import ALL, is_all
from dgl.heterograph_index import HeteroGraphIndex
from dgl.sparse import _gsddmm, _gspmm
from torch import Tensor
from torch.autograd import Function
def _neighbor_sort(
......
......@@ -7,10 +7,10 @@ import torch.nn.functional as F
from dgl import DGLGraph
from dgl.nn import AvgPooling, GraphConv, MaxPooling
from dgl.ops import edge_softmax
from torch import Tensor
from torch.nn import Parameter
from functions import edge_sparsemax
from torch import Tensor
from torch.nn import Parameter
from utils import get_batch_id, topk
......@@ -30,7 +30,8 @@ class WeightedGraphConv(GraphConv):
e_feat : torch.Tensor, optional
The edge features. Default: :obj:`None`
"""
def forward(self, graph:DGLGraph, n_feat, e_feat=None):
def forward(self, graph: DGLGraph, n_feat, e_feat=None):
if e_feat is None:
return super(WeightedGraphConv, self).forward(graph, n_feat)
......@@ -44,8 +45,7 @@ class WeightedGraphConv(GraphConv):
n_feat = n_feat * src_norm
graph.ndata["h"] = n_feat
graph.edata["e"] = e_feat
graph.update_all(fn.u_mul_e("h", "e", "m"),
fn.sum("m", "h"))
graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
n_feat = graph.ndata.pop("h")
n_feat = n_feat * dst_norm
if self.bias is not None:
......@@ -83,16 +83,21 @@ class NodeInfoScoreLayer(nn.Module):
Tensor
Score for each node.
"""
def __init__(self, sym_norm:bool=True):
def __init__(self, sym_norm: bool = True):
super(NodeInfoScoreLayer, self).__init__()
self.sym_norm = sym_norm
def forward(self, graph:dgl.DGLGraph, feat:Tensor, e_feat:Tensor):
def forward(self, graph: dgl.DGLGraph, feat: Tensor, e_feat: Tensor):
with graph.local_scope():
if self.sym_norm:
src_norm = torch.pow(graph.out_degrees().float().clamp(min=1), -0.5)
src_norm = torch.pow(
graph.out_degrees().float().clamp(min=1), -0.5
)
src_norm = src_norm.view(-1, 1).to(feat.device)
dst_norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
dst_norm = torch.pow(
graph.in_degrees().float().clamp(min=1), -0.5
)
dst_norm = dst_norm.view(-1, 1).to(feat.device)
src_feat = feat * src_norm
......@@ -105,7 +110,7 @@ class NodeInfoScoreLayer(nn.Module):
dst_feat = graph.ndata.pop("h") * dst_norm
feat = feat - dst_feat
else:
dst_norm = 1. / graph.in_degrees().float().clamp(min=1)
dst_norm = 1.0 / graph.in_degrees().float().clamp(min=1)
dst_norm = dst_norm.view(-1, 1)
graph.ndata["h"] = feat
......@@ -159,9 +164,19 @@ class HGPSLPool(nn.Module):
torch.Tensor
Permutation index
"""
def __init__(self, in_feat:int, ratio=0.8, sample=True,
sym_score_norm=True, sparse=True, sl=True,
lamb=1.0, negative_slop=0.2, k_hop=3):
def __init__(
self,
in_feat: int,
ratio=0.8,
sample=True,
sym_score_norm=True,
sparse=True,
sl=True,
lamb=1.0,
negative_slop=0.2,
k_hop=3,
):
super(HGPSLPool, self).__init__()
self.in_feat = in_feat
self.ratio = ratio
......@@ -180,16 +195,17 @@ class HGPSLPool(nn.Module):
def reset_parameters(self):
nn.init.xavier_normal_(self.att.data)
def forward(self, graph:DGLGraph, feat:Tensor, e_feat=None):
def forward(self, graph: DGLGraph, feat: Tensor, e_feat=None):
# top-k pool first
if e_feat is None:
e_feat = torch.ones((graph.number_of_edges(),),
dtype=feat.dtype, device=feat.device)
e_feat = torch.ones(
(graph.number_of_edges(),), dtype=feat.dtype, device=feat.device
)
batch_num_nodes = graph.batch_num_nodes()
x_score = self.calc_info_score(graph, feat, e_feat)
perm, next_batch_num_nodes = topk(x_score, self.ratio,
get_batch_id(batch_num_nodes),
batch_num_nodes)
perm, next_batch_num_nodes = topk(
x_score, self.ratio, get_batch_id(batch_num_nodes), batch_num_nodes
)
feat = feat[perm]
pool_graph = None
if not self.sample or not self.sl:
......@@ -215,31 +231,43 @@ class HGPSLPool(nn.Module):
row, col = graph.all_edges()
num_nodes = graph.num_nodes()
scipy_adj = scipy.sparse.coo_matrix((e_feat.detach().cpu(), (row.detach().cpu(), col.detach().cpu())), shape=(num_nodes, num_nodes))
scipy_adj = scipy.sparse.coo_matrix(
(
e_feat.detach().cpu(),
(row.detach().cpu(), col.detach().cpu()),
),
shape=(num_nodes, num_nodes),
)
for _ in range(self.k_hop):
two_hop = scipy_adj ** 2
two_hop = scipy_adj**2
two_hop = two_hop * (1e-5 / two_hop.max())
scipy_adj = two_hop + scipy_adj
row, col = scipy_adj.nonzero()
row = torch.tensor(row, dtype=torch.long, device=graph.device)
col = torch.tensor(col, dtype=torch.long, device=graph.device)
e_feat = torch.tensor(scipy_adj.data, dtype=torch.float, device=feat.device)
e_feat = torch.tensor(
scipy_adj.data, dtype=torch.float, device=feat.device
)
# perform pooling on multi-hop graph
mask = perm.new_full((num_nodes, ), -1)
mask = perm.new_full((num_nodes,), -1)
i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
mask[perm] = i
row, col = mask[row], mask[col]
mask = (row >=0 ) & (col >= 0)
mask = (row >= 0) & (col >= 0)
row, col = row[mask], col[mask]
e_feat = e_feat[mask]
# add remaining self loops
mask = row != col
num_nodes = perm.size(0) # num nodes after pool
loop_index = torch.arange(0, num_nodes, dtype=row.dtype, device=row.device)
loop_index = torch.arange(
0, num_nodes, dtype=row.dtype, device=row.device
)
inv_mask = ~mask
loop_weight = torch.full((num_nodes, ), 0, dtype=e_feat.dtype, device=e_feat.device)
loop_weight = torch.full(
(num_nodes,), 0, dtype=e_feat.dtype, device=e_feat.device
)
remaining_e_feat = e_feat[inv_mask]
if remaining_e_feat.numel() > 0:
loop_weight[row[inv_mask]] = remaining_e_feat
......@@ -249,8 +277,12 @@ class HGPSLPool(nn.Module):
col = torch.cat([col, loop_index], dim=0)
# attention scores
weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
weights = F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb
weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(
dim=-1
)
weights = (
F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb
)
# sl and normalization
sl_graph = dgl.graph((row, col))
......@@ -274,19 +306,27 @@ class HGPSLPool(nn.Module):
# use dense to build, then transform to sparse.
# maybe there's more efficient way?
batch_num_nodes = next_batch_num_nodes
block_begin_idx = torch.cat([batch_num_nodes.new_zeros(1),
batch_num_nodes.cumsum(dim=0)[:-1]], dim=0)
block_begin_idx = torch.cat(
[
batch_num_nodes.new_zeros(1),
batch_num_nodes.cumsum(dim=0)[:-1],
],
dim=0,
)
block_end_idx = batch_num_nodes.cumsum(dim=0)
dense_adj = torch.zeros((pool_graph.num_nodes(),
pool_graph.num_nodes()),
dense_adj = torch.zeros(
(pool_graph.num_nodes(), pool_graph.num_nodes()),
dtype=torch.float,
device=feat.device)
device=feat.device,
)
for idx_b, idx_e in zip(block_begin_idx, block_end_idx):
dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.
dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.0
row, col = torch.nonzero(dense_adj).t().contiguous()
# compute weights for node-pairs
weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(
dim=-1
)
weights = F.leaky_relu(weights, self.negative_slop)
dense_adj[row, col] = weights
......@@ -316,15 +356,30 @@ class HGPSLPool(nn.Module):
class ConvPoolReadout(torch.nn.Module):
"""A helper class. (GraphConv -> Pooling -> Readout)"""
def __init__(self, in_feat:int, out_feat:int, pool_ratio=0.8,
sample:bool=False, sparse:bool=True, sl:bool=True,
lamb:float=1., pool:bool=True):
def __init__(
self,
in_feat: int,
out_feat: int,
pool_ratio=0.8,
sample: bool = False,
sparse: bool = True,
sl: bool = True,
lamb: float = 1.0,
pool: bool = True,
):
super(ConvPoolReadout, self).__init__()
self.use_pool = pool
self.conv = WeightedGraphConv(in_feat, out_feat)
if pool:
self.pool = HGPSLPool(out_feat, ratio=pool_ratio, sparse=sparse,
sample=sample, sl=sl, lamb=lamb)
self.pool = HGPSLPool(
out_feat,
ratio=pool_ratio,
sparse=sparse,
sample=sample,
sl=sl,
lamb=lamb,
)
else:
self.pool = None
self.avgpool = AvgPooling()
......@@ -334,5 +389,7 @@ class ConvPoolReadout(torch.nn.Module):
out = F.relu(self.conv(graph, feature, e_feat))
if self.use_pool:
graph, out, e_feat, _ = self.pool(graph, out, e_feat)
readout = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
readout = torch.cat(
[self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1
)
return graph, out, e_feat, readout
......@@ -4,17 +4,17 @@ import logging
import os
from time import time
import dgl
import torch
import torch.nn
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from networks import HGPSLModel
from torch.utils.data import random_split
from utils import get_stats
import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
def parse_args():
parser = argparse.ArgumentParser(description="HGP-SL-DGL")
......
import torch
import torch.nn
import torch.nn.functional as F
from layers import ConvPoolReadout
from dgl.nn import AvgPooling, MaxPooling
from layers import ConvPoolReadout
class HGPSLModel(torch.nn.Module):
......
import math
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax
......
......@@ -20,6 +20,8 @@ from torch import nn
from tqdm import tqdm
"""============================================================================================================="""
################### TensorBoard Settings ###################
def args2exp_name(args):
exp_name = f"{args.dataset}_{args.loss}_{args.lr}_bs{args.bs}_spc{args.samples_per_class}_embed{args.embed_dim}_arch{args.arch}_decay{args.decay}_fclr{args.fc_lr_mul}_anneal{args.sigmoid_temperature}"
......@@ -381,6 +383,8 @@ def eval_metrics_query_and_gallery_dataset(
"""============================================================================================================="""
####### RECOVER CLOSEST EXAMPLE IMAGES #######
def recover_closest_one_dataset(
feature_matrix_all, image_paths, save_path, n_image_samples=10, n_closest=3
......@@ -489,6 +493,8 @@ def recover_closest_inshop(
"""============================================================================================================="""
################## SET NETWORK TRAINING CHECKPOINT #####################
def set_checkpoint(model, opt, progress_saver, savepath):
"""
......@@ -514,6 +520,8 @@ def set_checkpoint(model, opt, progress_saver, savepath):
"""============================================================================================================="""
################## WRITE TO CSV FILE #####################
class CSV_Writer:
"""
......
......@@ -20,6 +20,8 @@ from torch.utils.data import Dataset
from torchvision import transforms
"""============================================================================"""
################ FUNCTION TO RETURN ALL DATALOADERS NECESSARY ####################
def give_dataloaders(dataset, trainset, testset, opt, cluster_path=""):
"""
......
......@@ -8,7 +8,6 @@ import netlib as netlib
import torch
if __name__ == "__main__":
################## INPUT ARGUMENTS ###################
parser = argparse.ArgumentParser()
####### Main Parameter: Dataset to use for Training
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment