Unverified Commit 704bcaf6 authored by Hongzhi (Steve), Chen, committed by GitHub
parent 6bc82161
@@ -2,6 +2,8 @@ import argparse
 import time
 import traceback
+
+import dgl
 import networkx as nx
 import numpy as np
 import torch
@@ -11,12 +13,10 @@ from dataloader import (
     MultiBodyTrainDataset,
     MultiBodyValidDataset,
 )
-from models import MLP, InteractionNet, PrepareLayer
+from models import InteractionNet, MLP, PrepareLayer
 from torch.utils.data import DataLoader
 from utils import make_video
-import dgl
 
 def train(
     optimizer, loss_fn, reg_fn, model, prep, dataloader, lambda_reg, device
...
 import torch
-from modules import MSA, BiLSTM, GraphTrans
+from modules import BiLSTM, GraphTrans, MSA
 from torch import nn
 from utlis import *
...
@@ -2,11 +2,11 @@ import json
 import pickle
 import random
+
+import dgl
 import numpy as np
 import torch
-import dgl
 
 NODE_TYPE = {"entity": 0, "root": 1, "relation": 2}
...
 from typing import Optional
+
+import dgl
 import torch
 import torch.nn
-from torch import Tensor
-import dgl
 from dgl import DGLGraph
 from dgl.nn import GraphConv
+from torch import Tensor
 
 class GraphConvWithDropout(GraphConv):
...
@@ -3,18 +3,18 @@ import os
 from datetime import datetime
 from time import time
+
+import dgl
 import torch
 import torch.nn.functional as F
 from data_preprocess import degree_as_feature, node_label_as_feature
+from dgl.data import LegacyTUDataset
+from dgl.dataloading import GraphDataLoader
 from networks import GraphClassifier
 from torch import Tensor
 from torch.utils.data import random_split
 from utils import get_stats, parse_args
-import dgl
-from dgl.data import LegacyTUDataset
-from dgl.dataloading import GraphDataLoader
 
 def compute_loss(
     cls_logits: Tensor,
...
@@ -3,18 +3,18 @@ import os
 from datetime import datetime
 from time import time
+
+import dgl
 import torch
 import torch.nn.functional as F
 from data_preprocess import degree_as_feature, node_label_as_feature
+from dgl.data import LegacyTUDataset
+from dgl.dataloading import GraphDataLoader
 from networks import GraphClassifier
 from torch import Tensor
 from torch.utils.data import random_split
 from utils import get_stats, parse_args
-import dgl
-from dgl.data import LegacyTUDataset
-from dgl.dataloading import GraphDataLoader
 
 def compute_loss(
     cls_logits: Tensor,
...
 from typing import List, Tuple, Union
-from layers import *
+
+import dgl.function as fn
 import torch
 import torch.nn
 import torch.nn.functional as F
-import dgl.function as fn
 from dgl.nn.pytorch.glob import SortPooling
+
+from layers import *
@@ -35,9 +35,18 @@ class GraphCrossModule(torch.nn.Module):
         The weight parameter used at the end of GXN for channel fusion.
         Default: :obj:`1.0`
     """
-    def __init__(self, pool_ratios:Union[float, List[float]], in_dim:int,
-                 out_dim:int, hidden_dim:int, cross_weight:float=1.,
-                 fuse_weight:float=1., dist:int=1, num_cross_layers:int=2):
+
+    def __init__(
+        self,
+        pool_ratios: Union[float, List[float]],
+        in_dim: int,
+        out_dim: int,
+        hidden_dim: int,
+        cross_weight: float = 1.0,
+        fuse_weight: float = 1.0,
+        dist: int = 1,
+        num_cross_layers: int = 2,
+    ):
         super(GraphCrossModule, self).__init__()
         if isinstance(pool_ratios, float):
             pool_ratios = (pool_ratios, pool_ratios)
@@ -50,8 +59,12 @@ class GraphCrossModule(torch.nn.Module):
         self.start_gcn_scale2 = GraphConvWithDropout(hidden_dim, hidden_dim)
         self.end_gcn = GraphConvWithDropout(2 * hidden_dim, out_dim)
-        self.index_select_scale1 = IndexSelect(pool_ratios[0], hidden_dim, act="prelu", dist=dist)
-        self.index_select_scale2 = IndexSelect(pool_ratios[1], hidden_dim, act="prelu", dist=dist)
+        self.index_select_scale1 = IndexSelect(
+            pool_ratios[0], hidden_dim, act="prelu", dist=dist
+        )
+        self.index_select_scale2 = IndexSelect(
+            pool_ratios[1], hidden_dim, act="prelu", dist=dist
+        )
         self.start_pool_s12 = GraphPool(hidden_dim)
         self.start_pool_s23 = GraphPool(hidden_dim)
         self.end_unpool_s21 = GraphUnpool(hidden_dim)
@@ -85,72 +98,139 @@ class GraphCrossModule(torch.nn.Module):
         graph_scale1 = graph
         feat_scale1 = self.start_gcn_scale1(graph_scale1, feat)
         feat_origin = feat_scale1
-        feat_scale1_neg = feat_scale1[torch.randperm(feat_scale1.size(0))] # negative samples
-        logit_s1, scores_s1, select_idx_s1, non_select_idx_s1, feat_down_s1 = \
-            self.index_select_scale1(graph_scale1, feat_scale1, feat_scale1_neg)
-        feat_scale2, graph_scale2 = self.start_pool_s12(graph_scale1, feat_scale1,
-                                                        select_idx_s1, non_select_idx_s1,
-                                                        scores_s1, pool_graph=True)
+        feat_scale1_neg = feat_scale1[
+            torch.randperm(feat_scale1.size(0))
+        ]  # negative samples
+        (
+            logit_s1,
+            scores_s1,
+            select_idx_s1,
+            non_select_idx_s1,
+            feat_down_s1,
+        ) = self.index_select_scale1(graph_scale1, feat_scale1, feat_scale1_neg)
+        feat_scale2, graph_scale2 = self.start_pool_s12(
+            graph_scale1,
+            feat_scale1,
+            select_idx_s1,
+            non_select_idx_s1,
+            scores_s1,
+            pool_graph=True,
+        )
         # start of scale-2
         feat_scale2 = self.start_gcn_scale2(graph_scale2, feat_scale2)
-        feat_scale2_neg = feat_scale2[torch.randperm(feat_scale2.size(0))] # negative samples
-        logit_s2, scores_s2, select_idx_s2, non_select_idx_s2, feat_down_s2 = \
-            self.index_select_scale2(graph_scale2, feat_scale2, feat_scale2_neg)
-        feat_scale3, graph_scale3 = self.start_pool_s23(graph_scale2, feat_scale2,
-                                                        select_idx_s2, non_select_idx_s2,
-                                                        scores_s2, pool_graph=True)
+        feat_scale2_neg = feat_scale2[
+            torch.randperm(feat_scale2.size(0))
+        ]  # negative samples
+        (
+            logit_s2,
+            scores_s2,
+            select_idx_s2,
+            non_select_idx_s2,
+            feat_down_s2,
+        ) = self.index_select_scale2(graph_scale2, feat_scale2, feat_scale2_neg)
+        feat_scale3, graph_scale3 = self.start_pool_s23(
+            graph_scale2,
+            feat_scale2,
+            select_idx_s2,
+            non_select_idx_s2,
+            scores_s2,
+            pool_graph=True,
+        )
         # layer-1
         res_s1_0, res_s2_0, res_s3_0 = feat_scale1, feat_scale2, feat_scale3
         feat_scale1 = F.relu(self.s1_l1_gcn(graph_scale1, feat_scale1))
         feat_scale2 = F.relu(self.s2_l1_gcn(graph_scale2, feat_scale2))
         feat_scale3 = F.relu(self.s3_l1_gcn(graph_scale3, feat_scale3))
 
         if self.num_cross_layers >= 1:
-            feat_s12_fu = self.pool_s12_1(graph_scale1, feat_scale1,
-                                          select_idx_s1, non_select_idx_s1,
-                                          scores_s1)
-            feat_s21_fu = self.unpool_s21_1(graph_scale1, feat_scale2, select_idx_s1)
-            feat_s23_fu = self.pool_s23_1(graph_scale2, feat_scale2,
-                                          select_idx_s2, non_select_idx_s2,
-                                          scores_s2)
-            feat_s32_fu = self.unpool_s32_1(graph_scale2, feat_scale3, select_idx_s2)
-
-            feat_scale1 = feat_scale1 + self.cross_weight * feat_s21_fu + res_s1_0
-            feat_scale2 = feat_scale2 + self.cross_weight * (feat_s12_fu + feat_s32_fu) / 2 + res_s2_0
-            feat_scale3 = feat_scale3 + self.cross_weight * feat_s23_fu + res_s3_0
+            feat_s12_fu = self.pool_s12_1(
+                graph_scale1,
+                feat_scale1,
+                select_idx_s1,
+                non_select_idx_s1,
+                scores_s1,
+            )
+            feat_s21_fu = self.unpool_s21_1(
+                graph_scale1, feat_scale2, select_idx_s1
+            )
+            feat_s23_fu = self.pool_s23_1(
+                graph_scale2,
+                feat_scale2,
+                select_idx_s2,
+                non_select_idx_s2,
+                scores_s2,
+            )
+            feat_s32_fu = self.unpool_s32_1(
+                graph_scale2, feat_scale3, select_idx_s2
+            )
+
+            feat_scale1 = (
+                feat_scale1 + self.cross_weight * feat_s21_fu + res_s1_0
+            )
+            feat_scale2 = (
+                feat_scale2
+                + self.cross_weight * (feat_s12_fu + feat_s32_fu) / 2
+                + res_s2_0
+            )
+            feat_scale3 = (
+                feat_scale3 + self.cross_weight * feat_s23_fu + res_s3_0
+            )
         # layer-2
         feat_scale1 = F.relu(self.s1_l2_gcn(graph_scale1, feat_scale1))
         feat_scale2 = F.relu(self.s2_l2_gcn(graph_scale2, feat_scale2))
         feat_scale3 = F.relu(self.s3_l2_gcn(graph_scale3, feat_scale3))
 
         if self.num_cross_layers >= 2:
-            feat_s12_fu = self.pool_s12_2(graph_scale1, feat_scale1,
-                                          select_idx_s1, non_select_idx_s1,
-                                          scores_s1)
-            feat_s21_fu = self.unpool_s21_2(graph_scale1, feat_scale2, select_idx_s1)
-            feat_s23_fu = self.pool_s23_2(graph_scale2, feat_scale2,
-                                          select_idx_s2, non_select_idx_s2,
-                                          scores_s2)
-            feat_s32_fu = self.unpool_s32_2(graph_scale2, feat_scale3, select_idx_s2)
+            feat_s12_fu = self.pool_s12_2(
+                graph_scale1,
+                feat_scale1,
+                select_idx_s1,
+                non_select_idx_s1,
+                scores_s1,
+            )
+            feat_s21_fu = self.unpool_s21_2(
+                graph_scale1, feat_scale2, select_idx_s1
+            )
+            feat_s23_fu = self.pool_s23_2(
+                graph_scale2,
+                feat_scale2,
+                select_idx_s2,
+                non_select_idx_s2,
+                scores_s2,
+            )
+            feat_s32_fu = self.unpool_s32_2(
+                graph_scale2, feat_scale3, select_idx_s2
+            )
 
             cross_weight = self.cross_weight * 0.05
             feat_scale1 = feat_scale1 + cross_weight * feat_s21_fu
-            feat_scale2 = feat_scale2 + cross_weight * (feat_s12_fu + feat_s32_fu) / 2
+            feat_scale2 = (
+                feat_scale2 + cross_weight * (feat_s12_fu + feat_s32_fu) / 2
+            )
             feat_scale3 = feat_scale3 + cross_weight * feat_s23_fu
         # layer-3
         feat_scale1 = F.relu(self.s1_l3_gcn(graph_scale1, feat_scale1))
         feat_scale2 = F.relu(self.s2_l3_gcn(graph_scale2, feat_scale2))
         feat_scale3 = F.relu(self.s3_l3_gcn(graph_scale3, feat_scale3))
 
         # final layers
-        feat_s3_out = self.end_unpool_s32(graph_scale2, feat_scale3, select_idx_s2) + feat_down_s2
-        feat_s2_out = self.end_unpool_s21(graph_scale1, feat_scale2 + feat_s3_out, select_idx_s1)
-        feat_agg = feat_scale1 + self.fuse_weight * feat_s2_out + self.fuse_weight * feat_down_s1
+        feat_s3_out = (
+            self.end_unpool_s32(graph_scale2, feat_scale3, select_idx_s2)
+            + feat_down_s2
+        )
+        feat_s2_out = self.end_unpool_s21(
+            graph_scale1, feat_scale2 + feat_s3_out, select_idx_s1
+        )
+        feat_agg = (
+            feat_scale1
+            + self.fuse_weight * feat_s2_out
+            + self.fuse_weight * feat_down_s1
+        )
         feat_agg = torch.cat((feat_agg, feat_origin), dim=1)
         feat_agg = self.end_gcn(graph_scale1, feat_agg)
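
In equation form, the first cross layer above computes, with w = cross_weight, P/U the pool/unpool operators, and r* the saved layer-1 residuals: feat1 <- feat1 + w * U21(feat2) + r1; feat2 <- feat2 + w * (P12(feat1) + U32(feat3)) / 2 + r2; feat3 <- feat3 + w * P23(feat2) + r3. The second cross layer repeats the same fusion with w scaled by 0.05 and without residuals.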
@@ -198,11 +278,21 @@ class GraphCrossNet(torch.nn.Module):
         The weight parameter used at the end of GXN for channel fusion.
         Default: :obj:`1.0`
     """
-    def __init__(self, in_dim:int, out_dim:int, edge_feat_dim:int=0,
-                 hidden_dim:int=96, pool_ratios:Union[List[float], float]=[0.9, 0.7],
-                 readout_nodes:int=30, conv1d_dims:List[int]=[16, 32],
-                 conv1d_kws:List[int]=[5],
-                 cross_weight:float=1., fuse_weight:float=1., dist:int=1):
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        edge_feat_dim: int = 0,
+        hidden_dim: int = 96,
+        pool_ratios: Union[List[float], float] = [0.9, 0.7],
+        readout_nodes: int = 30,
+        conv1d_dims: List[int] = [16, 32],
+        conv1d_kws: List[int] = [5],
+        cross_weight: float = 1.0,
+        fuse_weight: float = 1.0,
+        dist: int = 1,
+    ):
         super(GraphCrossNet, self).__init__()
         self.in_dim = in_dim
         self.out_dim = out_dim
@@ -217,26 +307,35 @@ class GraphCrossNet(torch.nn.Module):
         else:
             self.e2l_lin = None
-        self.gxn = GraphCrossModule(pool_ratios, in_dim=self.in_dim, out_dim=hidden_dim,
-                                    hidden_dim=hidden_dim//2, cross_weight=cross_weight,
-                                    fuse_weight=fuse_weight, dist=dist)
+        self.gxn = GraphCrossModule(
+            pool_ratios,
+            in_dim=self.in_dim,
+            out_dim=hidden_dim,
+            hidden_dim=hidden_dim // 2,
+            cross_weight=cross_weight,
+            fuse_weight=fuse_weight,
+            dist=dist,
+        )
         self.sortpool = SortPooling(readout_nodes)
 
         # final updates
-        self.final_conv1 = torch.nn.Conv1d(1, conv1d_dims[0],
-                                           kernel_size=conv1d_kws[0],
-                                           stride=conv1d_kws[0])
+        self.final_conv1 = torch.nn.Conv1d(
+            1, conv1d_dims[0], kernel_size=conv1d_kws[0], stride=conv1d_kws[0]
+        )
         self.final_maxpool = torch.nn.MaxPool1d(2, 2)
-        self.final_conv2 = torch.nn.Conv1d(conv1d_dims[0], conv1d_dims[1],
-                                           kernel_size=conv1d_kws[1], stride=1)
+        self.final_conv2 = torch.nn.Conv1d(
+            conv1d_dims[0], conv1d_dims[1], kernel_size=conv1d_kws[1], stride=1
+        )
         self.final_dense_dim = int((readout_nodes - 2) / 2 + 1)
-        self.final_dense_dim = (self.final_dense_dim - conv1d_kws[1] + 1) * conv1d_dims[1]
+        self.final_dense_dim = (
+            self.final_dense_dim - conv1d_kws[1] + 1
+        ) * conv1d_dims[1]
 
         if self.out_dim > 0:
             self.out_lin = torch.nn.Linear(self.final_dense_dim, out_dim)
 
         self.init_weights()
 
     def init_weights(self):
         if self.e2l_lin is not None:
             torch.nn.init.xavier_normal_(self.e2l_lin.weight)
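
As a worked check of the two final_dense_dim lines (example values, not necessarily this model's runtime configuration): with readout_nodes = 30, conv1d_kws[1] = 5 and conv1d_dims[1] = 32, the MaxPool1d(2, 2) output length is (30 - 2) / 2 + 1 = 15, the second Conv1d then produces 15 - 5 + 1 = 11 positions, so the flattened dense input has 11 * 32 = 352 features.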
@@ -245,7 +344,12 @@ class GraphCrossNet(torch.nn.Module):
         if self.out_dim > 0:
             torch.nn.init.xavier_normal_(self.out_lin.weight)
 
-    def forward(self, graph:DGLGraph, node_feat:Tensor, edge_feat:Optional[Tensor]=None):
+    def forward(
+        self,
+        graph: DGLGraph,
+        node_feat: Tensor,
+        edge_feat: Optional[Tensor] = None,
+    ):
         num_batch = graph.batch_size
         if edge_feat is not None:
             edge_feat = self.e2l_lin(edge_feat)
@@ -263,13 +367,13 @@ class GraphCrossNet(torch.nn.Module):
         conv1d_result = F.relu(self.final_conv1(to_conv1d))
         conv1d_result = self.final_maxpool(conv1d_result)
         conv1d_result = F.relu(self.final_conv2(conv1d_result))
 
         to_dense = conv1d_result.view(num_batch, -1)
         if self.out_dim > 0:
             out = F.relu(self.out_lin(to_dense))
         else:
             out = to_dense
 
         return out, logits1, logits2
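
For orientation, here is a minimal, self-contained sketch of the SortPooling -> Conv1d readout used above (sizes are invented for illustration, not the module's defaults):

    import dgl
    import torch
    from dgl.nn.pytorch.glob import SortPooling

    # SortPooling gives every graph in the batch a fixed-length vector,
    # which is then read as a 1-channel sequence by the Conv1d stack.
    g = dgl.batch([dgl.rand_graph(12, 30), dgl.rand_graph(8, 20)])
    feat = torch.randn(g.num_nodes(), 4)
    pooled = SortPooling(k=10)(g, feat)  # shape [2, 10 * 4]
    to_conv1d = pooled.view(2, 1, -1)    # one channel per graph
    conv = torch.nn.Conv1d(1, 16, kernel_size=4, stride=4)
    print(conv(to_conv1d).shape)         # torch.Size([2, 16, 10])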
@@ -280,23 +384,31 @@ class GraphClassifier(torch.nn.Module):
     Graph Classifier for graph classification.
     GXN + MLP
     """
 
     def __init__(self, args):
         super(GraphClassifier, self).__init__()
-        self.gxn = GraphCrossNet(in_dim=args.in_dim,
-                                 out_dim=args.embed_dim,
-                                 edge_feat_dim=args.edge_feat_dim,
-                                 hidden_dim=args.hidden_dim,
-                                 pool_ratios=args.pool_ratios,
-                                 readout_nodes=args.readout_nodes,
-                                 conv1d_dims=args.conv1d_dims,
-                                 conv1d_kws=args.conv1d_kws,
-                                 cross_weight=args.cross_weight,
-                                 fuse_weight=args.fuse_weight)
+        self.gxn = GraphCrossNet(
+            in_dim=args.in_dim,
+            out_dim=args.embed_dim,
+            edge_feat_dim=args.edge_feat_dim,
+            hidden_dim=args.hidden_dim,
+            pool_ratios=args.pool_ratios,
+            readout_nodes=args.readout_nodes,
+            conv1d_dims=args.conv1d_dims,
+            conv1d_kws=args.conv1d_kws,
+            cross_weight=args.cross_weight,
+            fuse_weight=args.fuse_weight,
+        )
         self.lin1 = torch.nn.Linear(args.embed_dim, args.final_dense_hidden_dim)
         self.lin2 = torch.nn.Linear(args.final_dense_hidden_dim, args.out_dim)
         self.dropout = args.dropout
 
-    def forward(self, graph:DGLGraph, node_feat:Tensor, edge_feat:Optional[Tensor]=None):
+    def forward(
+        self,
+        graph: DGLGraph,
+        node_feat: Tensor,
+        edge_feat: Optional[Tensor] = None,
+    ):
         embed, logits1, logits2 = self.gxn(graph, node_feat, edge_feat)
         logits = F.relu(self.lin1(embed))
         if self.dropout > 0:
...
@@ -7,11 +7,10 @@ constructed another dataset from ACM with a different set of papers, connections and
 labels.
 """
+import dgl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl
 from dgl.nn.pytorch import GATConv
...
@@ -6,19 +6,19 @@ so we sampled twice as many neighbors during val/test than training.
 """
 import argparse
 
+import dgl
 import numpy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from dgl.nn.pytorch import GATConv
+from dgl.sampling import RandomWalkNeighborSampler
 from model_hetero import SemanticAttention
 from sklearn.metrics import f1_score
 from torch.utils.data import DataLoader
 from utils import EarlyStopping, set_random_seed
-import dgl
-from dgl.nn.pytorch import GATConv
-from dgl.sampling import RandomWalkNeighborSampler
 
 
 class HANLayer(torch.nn.Module):
     """
...
@@ -5,13 +5,12 @@ import pickle
 import random
 from pprint import pprint
 
+import dgl
 import numpy as np
 import torch
-from scipy import io as sio
-from scipy import sparse
-import dgl
 from dgl.data.utils import _get_dgl_url, download, get_download_dir
+from scipy import io as sio, sparse
 
 
 def set_random_seed(seed=0):
...
@@ -7,12 +7,12 @@ Paper: https://arxiv.org/abs/1907.04652
 from functools import partial
 
+import dgl
+import dgl.function as fn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl
-import dgl.function as fn
 from dgl.base import DGLError
 from dgl.nn.pytorch import edge_softmax
 from dgl.nn.pytorch.utils import Identity
...
@@ -8,19 +8,19 @@ Paper: https://arxiv.org/abs/1907.04652
 import argparse
 import time
 
+import dgl
 import numpy as np
 import torch
 import torch.nn.functional as F
-from hgao import HardGAT
-from utils import EarlyStopping
-import dgl
 from dgl.data import (
     CiteseerGraphDataset,
     CoraGraphDataset,
     PubmedGraphDataset,
     register_data_args,
 )
+from hgao import HardGAT
+from utils import EarlyStopping
 
 
 def accuracy(logits, labels):
@@ -161,7 +161,6 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="GAT")
     register_data_args(parser)
-
     parser.add_argument(
...
@@ -7,15 +7,14 @@ for detailed description.
 Here we implement a graph-edge version of sparsemax where we perform sparsemax for all edges
 with the same node as end-node in graphs.
 """
-import torch
-from torch import Tensor
-from torch.autograd import Function
 import dgl
+import torch
 from dgl.backend import astype
 from dgl.base import ALL, is_all
 from dgl.heterograph_index import HeteroGraphIndex
 from dgl.sparse import _gsddmm, _gspmm
+from torch import Tensor
+from torch.autograd import Function
 
 
 def _neighbor_sort(
...
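As background for this file: sparsemax (Martins & Astudillo, 2016) is the Euclidean projection of a score vector onto the probability simplex, which zeroes out low scores entirely; the graph-edge version applies it to each node's incoming-edge scores. A minimal dense sketch of the 1-D operator (not this file's autograd implementation):

    import torch

    def sparsemax_1d(z: torch.Tensor) -> torch.Tensor:
        # Project z onto the simplex: p = max(z - tau, 0), with tau chosen
        # so that the nonzero entries sum to 1.
        z_sorted, _ = torch.sort(z, descending=True)
        k = torch.arange(1, z.numel() + 1, dtype=z.dtype)
        cumsum = z_sorted.cumsum(dim=0)
        support = 1 + k * z_sorted > cumsum  # entries that stay nonzero
        k_z = support.sum()
        tau = (cumsum[k_z - 1] - 1) / k_z
        return torch.clamp(z - tau, min=0.0)

    print(sparsemax_1d(torch.tensor([1.0, 0.8, 0.1])))
    # tensor([0.6000, 0.4000, 0.0000]) -- note the exact zero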
@@ -7,10 +7,10 @@ import torch.nn.functional as F
 from dgl import DGLGraph
 from dgl.nn import AvgPooling, GraphConv, MaxPooling
 from dgl.ops import edge_softmax
-from torch import Tensor
-from torch.nn import Parameter
 from functions import edge_sparsemax
+from torch import Tensor
+from torch.nn import Parameter
 from utils import get_batch_id, topk
@@ -30,10 +30,11 @@ class WeightedGraphConv(GraphConv):
     e_feat : torch.Tensor, optional
         The edge features. Default: :obj:`None`
     """
-    def forward(self, graph:DGLGraph, n_feat, e_feat=None):
+
+    def forward(self, graph: DGLGraph, n_feat, e_feat=None):
         if e_feat is None:
             return super(WeightedGraphConv, self).forward(graph, n_feat)
 
         with graph.local_scope():
             if self.weight is not None:
                 n_feat = torch.matmul(n_feat, self.weight)
@@ -44,8 +45,7 @@ class WeightedGraphConv(GraphConv):
             n_feat = n_feat * src_norm
             graph.ndata["h"] = n_feat
             graph.edata["e"] = e_feat
-            graph.update_all(fn.u_mul_e("h", "e", "m"),
-                             fn.sum("m", "h"))
+            graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
             n_feat = graph.ndata.pop("h")
             n_feat = n_feat * dst_norm
             if self.bias is not None:
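
The update_all call above is DGL's edge-weighted neighborhood sum: each node receives the sum of its in-neighbors' features multiplied by the corresponding edge weights. A small self-contained check:

    import dgl
    import dgl.function as fn
    import torch

    # Directed 3-cycle: 0 -> 1 -> 2 -> 0, with per-edge weights.
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    g.ndata["h"] = torch.tensor([[1.0], [2.0], [3.0]])
    g.edata["e"] = torch.tensor([[0.5], [1.0], [2.0]])
    # h_v <- sum over edges (u, v) of h_u * e_uv
    g.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
    print(g.ndata["h"])  # [[6.0], [0.5], [2.0]]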
@@ -77,35 +77,40 @@ class NodeInfoScoreLayer(nn.Module):
         The node features
     e_feat : torch.Tensor, optional
         The edge features. Default: :obj:`None`
 
     Returns
     -------
     Tensor
         Score for each node.
     """
-    def __init__(self, sym_norm:bool=True):
+
+    def __init__(self, sym_norm: bool = True):
         super(NodeInfoScoreLayer, self).__init__()
         self.sym_norm = sym_norm
 
-    def forward(self, graph:dgl.DGLGraph, feat:Tensor, e_feat:Tensor):
+    def forward(self, graph: dgl.DGLGraph, feat: Tensor, e_feat: Tensor):
         with graph.local_scope():
             if self.sym_norm:
-                src_norm = torch.pow(graph.out_degrees().float().clamp(min=1), -0.5)
+                src_norm = torch.pow(
+                    graph.out_degrees().float().clamp(min=1), -0.5
+                )
                 src_norm = src_norm.view(-1, 1).to(feat.device)
-                dst_norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
+                dst_norm = torch.pow(
+                    graph.in_degrees().float().clamp(min=1), -0.5
+                )
                 dst_norm = dst_norm.view(-1, 1).to(feat.device)
 
                 src_feat = feat * src_norm
 
                 graph.ndata["h"] = src_feat
                 graph.edata["e"] = e_feat
                 graph = dgl.remove_self_loop(graph)
                 graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
 
                 dst_feat = graph.ndata.pop("h") * dst_norm
                 feat = feat - dst_feat
             else:
-                dst_norm = 1. / graph.in_degrees().float().clamp(min=1)
+                dst_norm = 1.0 / graph.in_degrees().float().clamp(min=1)
                 dst_norm = dst_norm.view(-1, 1)
 
                 graph.ndata["h"] = feat
@@ -124,7 +129,7 @@ class HGPSLPool(nn.Module):
     Description
     -----------
     The HGP-SL pooling layer from
     `Hierarchical Graph Pooling with Structure Learning <https://arxiv.org/pdf/1911.05954.pdf>`
 
     Parameters
@@ -134,7 +139,7 @@ class HGPSLPool(nn.Module):
     ratio : float, optional
         Pooling ratio. Default: 0.8
     sample : bool, optional
         Whether use k-hop union graph to increase efficiency.
         Currently we only support full graph. Default: :obj:`False`
     sym_score_norm : bool, optional
         Use symmetric norm for adjacency or not. Default: :obj:`True`
@@ -147,7 +152,7 @@ class HGPSLPool(nn.Module):
         HGP-SL paper. Default: 1.0
     negative_slop : float, optional
         Negative slop for leaky_relu. Default: 0.2
 
     Returns
     -------
     DGLGraph
@@ -159,9 +164,19 @@ class HGPSLPool(nn.Module):
     torch.Tensor
         Permutation index
     """
-    def __init__(self, in_feat:int, ratio=0.8, sample=True,
-                 sym_score_norm=True, sparse=True, sl=True,
-                 lamb=1.0, negative_slop=0.2, k_hop=3):
+
+    def __init__(
+        self,
+        in_feat: int,
+        ratio=0.8,
+        sample=True,
+        sym_score_norm=True,
+        sparse=True,
+        sl=True,
+        lamb=1.0,
+        negative_slop=0.2,
+        k_hop=3,
+    ):
         super(HGPSLPool, self).__init__()
         self.in_feat = in_feat
         self.ratio = ratio
@@ -180,16 +195,17 @@ class HGPSLPool(nn.Module):
     def reset_parameters(self):
         nn.init.xavier_normal_(self.att.data)
 
-    def forward(self, graph:DGLGraph, feat:Tensor, e_feat=None):
+    def forward(self, graph: DGLGraph, feat: Tensor, e_feat=None):
         # top-k pool first
         if e_feat is None:
-            e_feat = torch.ones((graph.number_of_edges(),),
-                                dtype=feat.dtype, device=feat.device)
+            e_feat = torch.ones(
+                (graph.number_of_edges(),), dtype=feat.dtype, device=feat.device
+            )
         batch_num_nodes = graph.batch_num_nodes()
         x_score = self.calc_info_score(graph, feat, e_feat)
-        perm, next_batch_num_nodes = topk(x_score, self.ratio,
-                                          get_batch_id(batch_num_nodes),
-                                          batch_num_nodes)
+        perm, next_batch_num_nodes = topk(
+            x_score, self.ratio, get_batch_id(batch_num_nodes), batch_num_nodes
+        )
         feat = feat[perm]
         pool_graph = None
         if not self.sample or not self.sl:
@@ -210,36 +226,48 @@ class HGPSLPool(nn.Module):
             # pair of nodes is time consuming. To accelerate this process,
             # we sample it's K-Hop neighbors for each node and then learn the
            # edge weights between them.
 
             # first build multi-hop graph
             row, col = graph.all_edges()
             num_nodes = graph.num_nodes()
-            scipy_adj = scipy.sparse.coo_matrix((e_feat.detach().cpu(), (row.detach().cpu(), col.detach().cpu())), shape=(num_nodes, num_nodes))
+            scipy_adj = scipy.sparse.coo_matrix(
+                (
+                    e_feat.detach().cpu(),
+                    (row.detach().cpu(), col.detach().cpu()),
+                ),
+                shape=(num_nodes, num_nodes),
+            )
             for _ in range(self.k_hop):
-                two_hop = scipy_adj ** 2
+                two_hop = scipy_adj**2
                 two_hop = two_hop * (1e-5 / two_hop.max())
                 scipy_adj = two_hop + scipy_adj
             row, col = scipy_adj.nonzero()
             row = torch.tensor(row, dtype=torch.long, device=graph.device)
             col = torch.tensor(col, dtype=torch.long, device=graph.device)
-            e_feat = torch.tensor(scipy_adj.data, dtype=torch.float, device=feat.device)
+            e_feat = torch.tensor(
+                scipy_adj.data, dtype=torch.float, device=feat.device
+            )
 
             # perform pooling on multi-hop graph
-            mask = perm.new_full((num_nodes, ), -1)
+            mask = perm.new_full((num_nodes,), -1)
             i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
             mask[perm] = i
             row, col = mask[row], mask[col]
-            mask = (row >=0 ) & (col >= 0)
+            mask = (row >= 0) & (col >= 0)
             row, col = row[mask], col[mask]
             e_feat = e_feat[mask]
 
             # add remaining self loops
             mask = row != col
             num_nodes = perm.size(0)  # num nodes after pool
-            loop_index = torch.arange(0, num_nodes, dtype=row.dtype, device=row.device)
+            loop_index = torch.arange(
+                0, num_nodes, dtype=row.dtype, device=row.device
+            )
             inv_mask = ~mask
-            loop_weight = torch.full((num_nodes, ), 0, dtype=e_feat.dtype, device=e_feat.device)
+            loop_weight = torch.full(
+                (num_nodes,), 0, dtype=e_feat.dtype, device=e_feat.device
+            )
             remaining_e_feat = e_feat[inv_mask]
             if remaining_e_feat.numel() > 0:
                 loop_weight[row[inv_mask]] = remaining_e_feat
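
The adjacency-power loop above is worth a closer look: each iteration adds a heavily damped A @ A to A, so the sparsity pattern grows toward the k-hop union graph while the original one-hop weights stay dominant. A toy standalone version of the trick:

    import numpy as np
    import scipy.sparse

    # Directed 4-cycle: 0 -> 1 -> 2 -> 3 -> 0.
    adj = scipy.sparse.csr_matrix(
        (np.ones(4), ([0, 1, 2, 3], [1, 2, 3, 0])), shape=(4, 4)
    )
    for _ in range(3):  # k_hop = 3
        two_hop = adj**2  # matrix power: 2-step reachability
        two_hop = two_hop * (1e-5 / two_hop.max())  # damp the new entries
        adj = two_hop + adj
    print(adj.nnz)  # 16: every node pair is now in the pattern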
@@ -249,16 +277,20 @@ class HGPSLPool(nn.Module):
             col = torch.cat([col, loop_index], dim=0)
 
             # attention scores
-            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
-            weights = F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb
+            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(
+                dim=-1
+            )
+            weights = (
+                F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb
+            )
 
             # sl and normalization
             sl_graph = dgl.graph((row, col))
             if self.sparse:
                 weights = edge_sparsemax(sl_graph, weights)
             else:
                 weights = edge_softmax(sl_graph, weights)
 
             # get final graph
             mask = torch.abs(weights) > 0
             row, col, weights = row[mask], col[mask], weights[mask]
@@ -266,7 +298,7 @@ class HGPSLPool(nn.Module):
             pool_graph.set_batch_num_nodes(next_batch_num_nodes)
             e_feat = weights
         else:
             # Learning the possible edge weights between each pair of
             # nodes in the pooled subgraph, relative slower.
@@ -274,19 +306,27 @@ class HGPSLPool(nn.Module):
             # use dense to build, then transform to sparse.
             # maybe there's more efficient way?
             batch_num_nodes = next_batch_num_nodes
-            block_begin_idx = torch.cat([batch_num_nodes.new_zeros(1),
-                                         batch_num_nodes.cumsum(dim=0)[:-1]], dim=0)
+            block_begin_idx = torch.cat(
+                [
+                    batch_num_nodes.new_zeros(1),
+                    batch_num_nodes.cumsum(dim=0)[:-1],
+                ],
+                dim=0,
+            )
             block_end_idx = batch_num_nodes.cumsum(dim=0)
-            dense_adj = torch.zeros((pool_graph.num_nodes(),
-                                     pool_graph.num_nodes()),
-                                    dtype=torch.float,
-                                    device=feat.device)
+            dense_adj = torch.zeros(
+                (pool_graph.num_nodes(), pool_graph.num_nodes()),
+                dtype=torch.float,
+                device=feat.device,
+            )
             for idx_b, idx_e in zip(block_begin_idx, block_end_idx):
-                dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.
+                dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.0
             row, col = torch.nonzero(dense_adj).t().contiguous()
 
             # compute weights for node-pairs
-            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
+            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(
+                dim=-1
+            )
             weights = F.leaky_relu(weights, self.negative_slop)
             dense_adj[row, col] = weights
@@ -316,15 +356,30 @@ class HGPSLPool(nn.Module):
 class ConvPoolReadout(torch.nn.Module):
     """A helper class. (GraphConv -> Pooling -> Readout)"""
-    def __init__(self, in_feat:int, out_feat:int, pool_ratio=0.8,
-                 sample:bool=False, sparse:bool=True, sl:bool=True,
-                 lamb:float=1., pool:bool=True):
+
+    def __init__(
+        self,
+        in_feat: int,
+        out_feat: int,
+        pool_ratio=0.8,
+        sample: bool = False,
+        sparse: bool = True,
+        sl: bool = True,
+        lamb: float = 1.0,
+        pool: bool = True,
+    ):
         super(ConvPoolReadout, self).__init__()
         self.use_pool = pool
         self.conv = WeightedGraphConv(in_feat, out_feat)
         if pool:
-            self.pool = HGPSLPool(out_feat, ratio=pool_ratio, sparse=sparse,
-                                  sample=sample, sl=sl, lamb=lamb)
+            self.pool = HGPSLPool(
+                out_feat,
+                ratio=pool_ratio,
+                sparse=sparse,
+                sample=sample,
+                sl=sl,
+                lamb=lamb,
+            )
         else:
             self.pool = None
         self.avgpool = AvgPooling()
@@ -334,5 +389,7 @@ class ConvPoolReadout(torch.nn.Module):
         out = F.relu(self.conv(graph, feature, e_feat))
         if self.use_pool:
             graph, out, e_feat, _ = self.pool(graph, out, e_feat)
-        readout = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
+        readout = torch.cat(
+            [self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1
+        )
         return graph, out, e_feat, readout
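
For orientation, a hypothetical smoke test of the helper above (it assumes this example's layers.py is importable and that the pooling code paths not shown in this diff behave as documented; the graph and sizes are made up):

    import dgl
    import torch
    from layers import ConvPoolReadout

    # One conv -> pool -> readout block on a random graph.
    g = dgl.add_self_loop(dgl.rand_graph(20, 60))
    feat = torch.randn(20, 16)
    block = ConvPoolReadout(16, 32, pool_ratio=0.8)
    g2, out, e_feat, readout = block(g, feat)
    print(out.shape, readout.shape)  # pooled node feats; [1, 64] avg||max readout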
@@ -4,17 +4,17 @@ import logging
 import os
 from time import time
 
+import dgl
 import torch
 import torch.nn
 import torch.nn.functional as F
+from dgl.data import LegacyTUDataset
+from dgl.dataloading import GraphDataLoader
 from networks import HGPSLModel
 from torch.utils.data import random_split
 from utils import get_stats
-import dgl
-from dgl.data import LegacyTUDataset
-from dgl.dataloading import GraphDataLoader
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description="HGP-SL-DGL")
...
 import torch
 import torch.nn
 import torch.nn.functional as F
-from layers import ConvPoolReadout
 from dgl.nn import AvgPooling, MaxPooling
+from layers import ConvPoolReadout
 
 
 class HGPSLModel(torch.nn.Module):
...
 import math
 
+import dgl
+import dgl.function as fn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl
-import dgl.function as fn
 from dgl.nn.functional import edge_softmax
...
@@ -20,6 +20,8 @@ from torch import nn
 from tqdm import tqdm
 
 """============================================================================================================="""
 ################### TensorBoard Settings ###################
+
+
 def args2exp_name(args):
     exp_name = f"{args.dataset}_{args.loss}_{args.lr}_bs{args.bs}_spc{args.samples_per_class}_embed{args.embed_dim}_arch{args.arch}_decay{args.decay}_fclr{args.fc_lr_mul}_anneal{args.sigmoid_temperature}"
@@ -381,6 +383,8 @@ def eval_metrics_query_and_gallery_dataset(
 """============================================================================================================="""
 ####### RECOVER CLOSEST EXAMPLE IMAGES #######
+
+
 def recover_closest_one_dataset(
     feature_matrix_all, image_paths, save_path, n_image_samples=10, n_closest=3
@@ -489,6 +493,8 @@ def recover_closest_inshop(
 """============================================================================================================="""
 ################## SET NETWORK TRAINING CHECKPOINT #####################
+
+
 def set_checkpoint(model, opt, progress_saver, savepath):
     """
@@ -514,6 +520,8 @@ def set_checkpoint(model, opt, progress_saver, savepath):
 """============================================================================================================="""
 ################## WRITE TO CSV FILE #####################
+
+
 class CSV_Writer:
     """
...
@@ -20,6 +20,8 @@ from torch.utils.data import Dataset
 from torchvision import transforms
 
 """============================================================================"""
 ################ FUNCTION TO RETURN ALL DATALOADERS NECESSARY ####################
+
+
 def give_dataloaders(dataset, trainset, testset, opt, cluster_path=""):
     """
...
@@ -8,7 +8,6 @@ import netlib as netlib
 import torch
 
 if __name__ == "__main__":
     ################## INPUT ARGUMENTS ###################
     parser = argparse.ArgumentParser()
-
     ####### Main Parameter: Dataset to use for Training
...