Unverified commit 704bcaf6, authored by Hongzhi (Steve), Chen and committed by GitHub
parent 6bc82161
@@ -17,12 +17,22 @@ from torch.utils.checkpoint import checkpoint
 class MWEConv(nn.Module):
-    def __init__(self, in_feats, out_feats, activation, bias=True, num_channels=8, aggr_mode="sum"):
+    def __init__(
+        self,
+        in_feats,
+        out_feats,
+        activation,
+        bias=True,
+        num_channels=8,
+        aggr_mode="sum",
+    ):
         super(MWEConv, self).__init__()
         self.num_channels = num_channels
         self._in_feats = in_feats
         self._out_feats = out_feats
-        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats, num_channels))
+        self.weight = nn.Parameter(
+            torch.Tensor(in_feats, out_feats, num_channels)
+        )
         if bias:
             self.bias = nn.Parameter(torch.Tensor(out_feats, num_channels))
@@ -59,11 +69,14 @@ class MWEConv(nn.Module):
         for c in range(self.num_channels):
             node_state_c = node_state
             if self._out_feats < self._in_feats:
-                g.ndata["feat_" + str(c)] = torch.mm(node_state_c, self.weight[:, :, c])
+                g.ndata["feat_" + str(c)] = torch.mm(
+                    node_state_c, self.weight[:, :, c]
+                )
             else:
                 g.ndata["feat_" + str(c)] = node_state_c
             g.update_all(
-                fn.u_mul_e("feat_" + str(c), "feat_" + str(c), "m"), fn.sum("m", "feat_" + str(c) + "_new")
+                fn.u_mul_e("feat_" + str(c), "feat_" + str(c), "m"),
+                fn.sum("m", "feat_" + str(c) + "_new"),
             )
             node_state_c = g.ndata.pop("feat_" + str(c) + "_new")
             if self._out_feats >= self._in_feats:
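For readers skimming the hunk above: fn.u_mul_e multiplies each source node's feature by the matching edge feature, and fn.sum aggregates the resulting messages per destination node. A minimal self-contained sketch of that pair (illustrative only, not part of this commit; the toy graph and names are made up):

import dgl
import dgl.function as fn
import torch

# Toy graph: edges 0->1, 1->2, 2->0.
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata["h"] = torch.ones(3, 2)                     # node features
g.edata["w"] = torch.tensor([[1.0], [2.0], [3.0]])  # per-edge weights

# Message: source feature * edge weight; reduce: sum over in-edges.
g.update_all(fn.u_mul_e("h", "w", "m"), fn.sum("m", "h_new"))
print(g.ndata["h_new"])  # node 0 gets 3.0, node 1 gets 1.0, node 2 gets 2.0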
@@ -83,15 +96,36 @@ class MWEConv(nn.Module):
 class MWE_GCN(nn.Module):
-    def __init__(self, n_input, n_hidden, n_output, n_layers, activation, dropout, aggr_mode="sum", device="cpu"):
+    def __init__(
+        self,
+        n_input,
+        n_hidden,
+        n_output,
+        n_layers,
+        activation,
+        dropout,
+        aggr_mode="sum",
+        device="cpu",
+    ):
         super(MWE_GCN, self).__init__()
         self.dropout = dropout
         self.activation = activation
         self.layers = nn.ModuleList()
-        self.layers.append(MWEConv(n_input, n_hidden, activation=activation, aggr_mode=aggr_mode))
+        self.layers.append(
+            MWEConv(
+                n_input, n_hidden, activation=activation, aggr_mode=aggr_mode
+            )
+        )
         for i in range(n_layers - 1):
-            self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation, aggr_mode=aggr_mode))
+            self.layers.append(
+                MWEConv(
+                    n_hidden,
+                    n_hidden,
+                    activation=activation,
+                    aggr_mode=aggr_mode,
+                )
+            )
         self.pred_out = nn.Linear(n_hidden, n_output)
         self.device = device
@@ -100,7 +134,9 @@ class MWE_GCN(nn.Module):
         node_state = torch.ones(g.number_of_nodes(), 1).float().to(self.device)
         for layer in self.layers:
-            node_state = F.dropout(node_state, p=self.dropout, training=self.training)
+            node_state = F.dropout(
+                node_state, p=self.dropout, training=self.training
+            )
             node_state = layer(g, node_state)
             node_state = self.activation(node_state)
@@ -110,7 +146,16 @@ class MWE_GCN(nn.Module):
 class MWE_DGCN(nn.Module):
     def __init__(
-        self, n_input, n_hidden, n_output, n_layers, activation, dropout, residual=False, aggr_mode="sum", device="cpu"
+        self,
+        n_input,
+        n_hidden,
+        n_output,
+        n_layers,
+        activation,
+        dropout,
+        residual=False,
+        aggr_mode="sum",
+        device="cpu",
     ):
         super(MWE_DGCN, self).__init__()
         self.n_layers = n_layers
@@ -121,13 +166,26 @@ class MWE_DGCN(nn.Module):
         self.layers = nn.ModuleList()
         self.layer_norms = nn.ModuleList()
-        self.layers.append(MWEConv(n_input, n_hidden, activation=activation, aggr_mode=aggr_mode))
+        self.layers.append(
+            MWEConv(
+                n_input, n_hidden, activation=activation, aggr_mode=aggr_mode
+            )
+        )
         for i in range(n_layers - 1):
-            self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation, aggr_mode=aggr_mode))
+            self.layers.append(
+                MWEConv(
+                    n_hidden,
+                    n_hidden,
+                    activation=activation,
+                    aggr_mode=aggr_mode,
+                )
+            )
         for i in range(n_layers):
-            self.layer_norms.append(nn.LayerNorm(n_hidden, elementwise_affine=True))
+            self.layer_norms.append(
+                nn.LayerNorm(n_hidden, elementwise_affine=True)
+            )
         self.pred_out = nn.Linear(n_hidden, n_output)
         self.device = device
@@ -140,7 +198,9 @@ class MWE_DGCN(nn.Module):
         for layer in range(1, self.n_layers):
             node_state_new = self.layer_norms[layer - 1](node_state)
             node_state_new = self.activation(node_state_new)
-            node_state_new = F.dropout(node_state_new, p=self.dropout, training=self.training)
+            node_state_new = F.dropout(
+                node_state_new, p=self.dropout, training=self.training
+            )
             if self.residual == "true":
                 node_state = node_state + self.layers[layer](g, node_state_new)
@@ -149,7 +209,9 @@ class MWE_DGCN(nn.Module):
         node_state = self.layer_norms[self.n_layers - 1](node_state)
         node_state = self.activation(node_state)
-        node_state = F.dropout(node_state, p=self.dropout, training=self.training)
+        node_state = F.dropout(
+            node_state, p=self.dropout, training=self.training
+        )
         out = self.pred_out(node_state)
@@ -180,7 +242,9 @@ class GATConv(nn.Module):
         self._use_symmetric_norm = use_symmetric_norm
         # feat fc
-        self.src_fc = nn.Linear(self._in_src_feats, out_feats * n_heads, bias=False)
+        self.src_fc = nn.Linear(
+            self._in_src_feats, out_feats * n_heads, bias=False
+        )
         if residual:
             self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads)
             self.bias = None
@@ -191,7 +255,9 @@ class GATConv(nn.Module):
         # attn fc
         self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
         if use_attn_dst:
-            self.attn_dst_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
+            self.attn_dst_fc = nn.Linear(
+                self._in_src_feats, n_heads, bias=False
+            )
         else:
             self.attn_dst_fc = None
         if edge_feats > 0:
@@ -243,8 +309,12 @@ class GATConv(nn.Module):
             norm = torch.reshape(norm, shp)
             feat_src = feat_src * norm
-        feat_src_fc = self.src_fc(feat_src).view(-1, self._n_heads, self._out_feats)
-        feat_dst_fc = self.dst_fc(feat_dst).view(-1, self._n_heads, self._out_feats)
+        feat_src_fc = self.src_fc(feat_src).view(
+            -1, self._n_heads, self._out_feats
+        )
+        feat_dst_fc = self.dst_fc(feat_dst).view(
+            -1, self._n_heads, self._out_feats
+        )
         attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1)

         # NOTE: GAT paper uses "first concatenation then linear projection"
@@ -257,18 +327,24 @@ class GATConv(nn.Module):
         # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
         # addition could be optimized with DGL's built-in function u_add_v,
         # which further speeds up computation and saves memory footprint.
-        graph.srcdata.update({"feat_src_fc": feat_src_fc, "attn_src": attn_src})
+        graph.srcdata.update(
+            {"feat_src_fc": feat_src_fc, "attn_src": attn_src}
+        )
         if self.attn_dst_fc is not None:
             attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1)
             graph.dstdata.update({"attn_dst": attn_dst})
-            graph.apply_edges(fn.u_add_v("attn_src", "attn_dst", "attn_node"))
+            graph.apply_edges(
+                fn.u_add_v("attn_src", "attn_dst", "attn_node")
+            )
         else:
             graph.apply_edges(fn.copy_u("attn_src", "attn_node"))
         e = graph.edata["attn_node"]
         if feat_edge is not None:
-            attn_edge = self.attn_edge_fc(feat_edge).view(-1, self._n_heads, 1)
+            attn_edge = self.attn_edge_fc(feat_edge).view(
+                -1, self._n_heads, 1
+            )
             graph.edata.update({"attn_edge": attn_edge})
             e += graph.edata["attn_edge"]
         e = self.leaky_relu(e)
@@ -278,12 +354,16 @@ class GATConv(nn.Module):
             bound = int(graph.number_of_edges() * self.edge_drop)
             eids = perm[bound:]
             graph.edata["a"] = torch.zeros_like(e)
-            graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids))
+            graph.edata["a"][eids] = self.attn_drop(
+                edge_softmax(graph, e[eids], eids=eids)
+            )
         else:
             graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))

         # message passing
-        graph.update_all(fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc"))
+        graph.update_all(
+            fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc")
+        )
         rst = graph.dstdata["feat_src_fc"]
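The NOTE in the hunk above relies on the linearity of the GAT scoring function: a^T [Wh_i || Wh_j] equals a_src^T Wh_i + a_dst^T Wh_j once a is split in half, which is why the code can project source and destination features separately. A quick numeric check (an illustrative sketch, not part of this commit):

import torch

d = 8                                  # per-head feature size (arbitrary)
a = torch.randn(2 * d)                 # attention vector over [Wh_i || Wh_j]
a_src, a_dst = a[:d], a[d:]
wh_i, wh_j = torch.randn(d), torch.randn(d)

concat_then_project = a @ torch.cat([wh_i, wh_j])
project_then_add = a_src @ wh_i + a_dst @ wh_j
assert torch.allclose(concat_then_project, project_then_add)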
......
@@ -5,9 +5,14 @@ import random
 import sys
 import time

+import dgl
 import numpy as np
 import torch
 import torch.nn.functional as F
+from dgl.dataloading import DataLoader, Sampler
+from dgl.nn import GraphConv, SortPooling
+from dgl.sampling import global_uniform_negative_sampling
 from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
 from scipy.sparse.csgraph import shortest_path
 from torch.nn import (
@@ -20,11 +25,6 @@ from torch.nn import (
 )
 from tqdm import tqdm

-import dgl
-from dgl.dataloading import DataLoader, Sampler
-from dgl.nn import GraphConv, SortPooling
-from dgl.sampling import global_uniform_negative_sampling
-

 class Logger(object):
     def __init__(self, runs, info=None):
......
+import dgl
+import dgl.function as fn
 import numpy as np
 import torch
 from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

-import dgl
-import dgl.function as fn
-

 def get_ogb_evaluator(dataset):
     """
......
 import argparse
 import time

+import dgl
+import dgl.function as fn
 import numpy as np
 import torch
 import torch.nn as nn
 from dataset import load_dataset

-import dgl
-import dgl.function as fn
-

 class FeedForwardNet(nn.Module):
     def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):
......
 import argparse
 import os

+import dgl
+import dgl.function as fn
 import numpy as np
 import ogb
 import torch
 import tqdm
 from ogb.lsc import MAG240MDataset

-import dgl
-import dgl.function as fn
-

 parser = argparse.ArgumentParser()
 parser.add_argument(
     "--rootdir",
......
@@ -4,6 +4,10 @@
 import argparse
 import time

+import dgl
+import dgl.function as fn
+import dgl.nn as dglnn
 import numpy as np
 import ogb
 import torch
@@ -12,10 +16,6 @@ import torch.nn.functional as F
 import tqdm
 from ogb.lsc import MAG240MDataset, MAG240MEvaluator

-import dgl
-import dgl.function as fn
-import dgl.nn as dglnn
-

 class RGAT(nn.Module):
     def __init__(
......
@@ -5,6 +5,9 @@ import math
 import sys
 from collections import OrderedDict

+import dgl
+import dgl.nn as dglnn
 import numpy as np
 import torch
 import torch.multiprocessing as mp
@@ -14,9 +17,6 @@ import tqdm
 from ogb.lsc import MAG240MDataset, MAG240MEvaluator
 from torch.nn.parallel import DistributedDataParallel

-import dgl
-import dgl.nn as dglnn
-

 class RGAT(nn.Module):
     def __init__(
......
+import dgl
+import dgl.function as fn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
-
-import dgl
-import dgl.function as fn
 from dgl.nn.pytorch import SumPooling
+from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

 ### GIN convolution along the graph structure
@@ -128,7 +127,6 @@ class GNN_node(nn.Module):
         ### computing input node embedding
         h_list = [self.atom_encoder(x)]
         for layer in range(self.num_layers):
-
             h = self.convs[layer](g, h_list[layer], edge_attr)
             h = self.batch_norms[layer](h)
......
@@ -2,6 +2,8 @@ import argparse
 import os
 import random
+
+import dgl
 import numpy as np
 import torch
 import torch.optim as optim
@@ -12,8 +14,6 @@ from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

-import dgl
-

 reg_criterion = torch.nn.L1Loss()
......
@@ -2,6 +2,8 @@ import argparse
 import os
 import random
+
+import dgl
 import numpy as np
 import torch
 from gnn import GNN
@@ -10,8 +12,6 @@ from ogb.utils import smiles2graph
 from torch.utils.data import DataLoader
 from tqdm import tqdm

-import dgl
-

 def collate_dgl(graphs):
     batched_graph = dgl.batch(graphs)
......
-import networkx as nx
-import torch
 import dgl
 import dgl.function as fn
+import networkx as nx
+import torch

 N = 100
 g = nx.erdos_renyi_graph(N, 0.05)
@@ -10,15 +10,19 @@ g = dgl.DGLGraph(g)
 DAMP = 0.85
 K = 10

+
 def compute_pagerank(g):
-    g.ndata['pv'] = torch.ones(N) / N
+    g.ndata["pv"] = torch.ones(N) / N
     degrees = g.out_degrees(g.nodes()).type(torch.float32)
     for k in range(K):
-        g.ndata['pv'] = g.ndata['pv'] / degrees
-        g.update_all(message_func=fn.copy_u(u='pv', out='m'),
-                     reduce_func=fn.sum(msg='m', out='pv'))
-        g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']
-    return g.ndata['pv']
+        g.ndata["pv"] = g.ndata["pv"] / degrees
+        g.update_all(
+            message_func=fn.copy_u(u="pv", out="m"),
+            reduce_func=fn.sum(msg="m", out="pv"),
+        )
+        g.ndata["pv"] = (1 - DAMP) / N + DAMP * g.ndata["pv"]
+    return g.ndata["pv"]

+
 pv = compute_pagerank(g)
 print(pv)
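The loop above runs K = 10 rounds of the damped PageRank update PV(u) = (1 - d)/N + d * sum over in-neighbors v of PV(v)/deg_out(v), expressed as DGL message passing. A dense numpy cross-check of the same iteration (a sketch, not part of the commit, assuming every node has at least one out-edge so the degree division is safe):

import networkx as nx
import numpy as np

g_nx = nx.erdos_renyi_graph(100, 0.05)
A = nx.to_numpy_array(g_nx)             # A[i, j] = 1 for edge i -> j
P = A / A.sum(axis=1, keepdims=True)    # row-stochastic transition matrix
pv = np.full(100, 1 / 100)
for _ in range(10):
    pv = (1 - 0.85) / 100 + 0.85 * (P.T @ pv)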
"""Graph builder from pandas dataframes""" """Graph builder from pandas dataframes"""
from collections import namedtuple from collections import namedtuple
import dgl
from pandas.api.types import ( from pandas.api.types import (
is_categorical, is_categorical,
is_categorical_dtype, is_categorical_dtype,
is_numeric_dtype, is_numeric_dtype,
) )
import dgl
__all__ = ["PandasGraphBuilder"] __all__ = ["PandasGraphBuilder"]
......
 import dask.dataframe as dd
+import dgl
 import numpy as np
 import scipy.sparse as ssp
 import torch
 import tqdm

-import dgl
-

 # This is the train-test split method most of the recommender system papers running on MovieLens
 # takes. It essentially follows the intuition of "training on the past and predict the future".
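The comment in this hunk refers to the common leave-latest-out temporal protocol. As a hypothetical illustration (the dataframe and column names are made up, not from this repository), the split can be sketched with pandas:

import pandas as pd

df = pd.DataFrame({
    "user": [0, 0, 0, 1, 1],
    "item": [10, 11, 12, 10, 13],
    "timestamp": [1, 2, 3, 1, 2],
})
df = df.sort_values("timestamp")
test = df.groupby("user").tail(1)  # most recent interaction per user
train = df.drop(test.index)        # everything earlier: "the past"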
......
 import argparse
 import pickle

+import dgl
 import numpy as np
 import torch

-import dgl
-

 def prec(recommendations, ground_truth):
     n_users, n_items = ground_truth.shape
......
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
 import dgl
 import dgl.function as fn
 import dgl.nn.pytorch as dglnn
+import torch
+import torch.nn as nn
+import torch.nn.functional as F

 def disable_grad(module):
......
@@ -2,6 +2,8 @@ import argparse
 import os
 import pickle
+
+import dgl
 import evaluation
 import layers
 import numpy as np
@@ -14,8 +16,6 @@ from torch.utils.data import DataLoader
 from torchtext.data.utils import get_tokenizer
 from torchtext.vocab import build_vocab_from_iterator

-import dgl
-

 class PinSAGEModel(nn.Module):
     def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
......
@@ -2,6 +2,8 @@ import argparse
 import os
 import pickle
+
+import dgl
 import evaluation
 import layers
 import numpy as np
@@ -14,8 +16,6 @@ from torch.utils.data import DataLoader
 from torchtext.data.utils import get_tokenizer
 from torchtext.vocab import build_vocab_from_iterator

-import dgl
-

 class PinSAGEModel(nn.Module):
     def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
......
+import dgl
 import numpy as np
 import torch
 from torch.utils.data import DataLoader, IterableDataset
 from torchtext.data.functional import numericalize_tokens_from_iterator

-import dgl
-

 def padding(array, yy, val):
     """
......
-import numpy as np
-import warnings
 import os
+import warnings
+
+import numpy as np
 from torch.utils.data import Dataset

-warnings.filterwarnings('ignore')
+warnings.filterwarnings("ignore")

 def pc_normalize(pc):
     centroid = np.mean(pc, axis=0)
@@ -11,6 +14,7 @@ def pc_normalize(pc):
     pc = pc / m
     return pc

+
 def farthest_point_sample(point, npoint):
     """
     Farthest point sampler works as follows:
@@ -25,7 +29,7 @@ def farthest_point_sample(point, npoint):
         centroids: sampled pointcloud index, [npoint, D]
     """
     N, D = point.shape
-    xyz = point[:,:3]
+    xyz = point[:, :3]
     centroids = np.zeros((npoint,))
     distance = np.ones((N,)) * 1e10
     farthest = np.random.randint(0, N)
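The docstring above describes farthest point sampling: greedily pick the point farthest from everything selected so far. A compact numpy sketch of that loop (illustrative only; the real function in this file also carries the feature columns beyond xyz):

import numpy as np

def fps_sketch(xyz, npoint):
    # xyz: (N, 3) coordinates; returns indices of npoint sampled points.
    N = xyz.shape[0]
    centroids = np.zeros(npoint, dtype=np.int64)
    distance = np.full(N, np.inf)   # distance to nearest chosen centroid
    farthest = np.random.randint(N)
    for i in range(npoint):
        centroids[i] = farthest
        dist = ((xyz - xyz[farthest]) ** 2).sum(axis=1)
        distance = np.minimum(distance, dist)
        farthest = int(distance.argmax())  # farthest from all chosen so far
    return centroids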
@@ -41,11 +45,18 @@ def farthest_point_sample(point, npoint):

 class ModelNetDataLoader(Dataset):
-    def __init__(self, root, npoint=1024, split='train', fps=False,
-                 normal_channel=True, cache_size=15000):
+    def __init__(
+        self,
+        root,
+        npoint=1024,
+        split="train",
+        fps=False,
+        normal_channel=True,
+        cache_size=15000,
+    ):
         """
         Input:
             root: the root path to the local data files
             npoint: number of points from each cloud
             split: which split of the data, 'train' or 'test'
             fps: whether to sample points with farthest point sampler
@@ -55,22 +66,34 @@ class ModelNetDataLoader(Dataset):
         self.root = root
         self.npoints = npoint
         self.fps = fps
-        self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
+        self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
         self.cat = [line.rstrip() for line in open(self.catfile)]
         self.classes = dict(zip(self.cat, range(len(self.cat))))
         self.normal_channel = normal_channel

         shape_ids = {}
-        shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
-        shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
+        shape_ids["train"] = [
+            line.rstrip()
+            for line in open(os.path.join(self.root, "modelnet40_train.txt"))
+        ]
+        shape_ids["test"] = [
+            line.rstrip()
+            for line in open(os.path.join(self.root, "modelnet40_test.txt"))
+        ]

-        assert (split == 'train' or split == 'test')
-        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
+        assert split == "train" or split == "test"
+        shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
         # list of (shape_name, shape_txt_file_path) tuple
-        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
-                         in range(len(shape_ids[split]))]
-        print('The size of %s data is %d'%(split,len(self.datapath)))
+        self.datapath = [
+            (
+                shape_names[i],
+                os.path.join(self.root, shape_names[i], shape_ids[split][i])
+                + ".txt",
+            )
+            for i in range(len(shape_ids[split]))
+        ]
+        print("The size of %s data is %d" % (split, len(self.datapath)))

         self.cache_size = cache_size
         self.cache = {}
@@ -85,11 +108,11 @@ class ModelNetDataLoader(Dataset):
             fn = self.datapath[index]
             cls = self.classes[self.datapath[index][0]]
             cls = np.array([cls]).astype(np.int32)
-            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
+            point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
             if self.fps:
                 point_set = farthest_point_sample(point_set, self.npoints)
             else:
-                point_set = point_set[0:self.npoints,:]
+                point_set = point_set[0 : self.npoints, :]

             point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
......
+import dgl
 import torch
-import torch.nn.functional as F
 import torch.nn as nn
-from torch.nn.modules.utils import _single
+import torch.nn.functional as F
 from torch.autograd import Function
 from torch.nn import Parameter
-import dgl
+from torch.nn.modules.utils import _single

 class BinaryQuantize(Function):
@@ -34,16 +34,27 @@ class BiLinearLSR(torch.nn.Linear):
         # hence, init scale with Parameter
         # however, Parameter(None) actually has size [0], not [] as a scalar
         # hence, init it using the following trick
-        self.register_parameter('scale', Parameter(torch.Tensor([0.0]).squeeze()))
+        self.register_parameter(
+            "scale", Parameter(torch.Tensor([0.0]).squeeze())
+        )

     def reset_scale(self, input):
         bw = self.weight
         ba = input
         bw = bw - bw.mean()
-        self.scale = Parameter((F.linear(ba, bw).std() / F.linear(torch.sign(ba), torch.sign(bw)).std()).float().to(ba.device))
+        self.scale = Parameter(
+            (
+                F.linear(ba, bw).std()
+                / F.linear(torch.sign(ba), torch.sign(bw)).std()
+            )
+            .float()
+            .to(ba.device)
+        )
         # corner case when ba is all 0.0
         if torch.isnan(self.scale):
-            self.scale = Parameter((bw.std() / torch.sign(bw).std()).float().to(ba.device))
+            self.scale = Parameter(
+                (bw.std() / torch.sign(bw).std()).float().to(ba.device)
+            )

     def forward(self, input):
         bw = self.weight
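The comment in this hunk is about a PyTorch quirk: a Parameter built from an empty tensor has size [0] (an empty vector), while squeezing a one-element tensor yields a true 0-dim scalar. A two-line check (illustrative, not part of the commit):

import torch
from torch.nn import Parameter

p_empty = Parameter(torch.Tensor())                  # size [0]: empty vector
p_scalar = Parameter(torch.Tensor([0.0]).squeeze())  # size []: 0-dim scalar
print(p_empty.shape, p_scalar.shape)  # torch.Size([0]) torch.Size([])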
@@ -79,87 +90,134 @@ class BiLinear(torch.nn.Linear):
 class BiConv2d(torch.nn.Conv2d):
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1,
-                 bias=True, padding_mode='zeros'):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+        padding_mode="zeros",
+    ):
         super(BiConv2d, self).__init__(
-            in_channels, out_channels, kernel_size, stride, padding, dilation,
-            groups, bias, padding_mode)
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+            padding_mode,
+        )

     def forward(self, input):
         bw = self.weight
         ba = input
         bw = bw - bw.mean()
         bw = BinaryQuantize().apply(bw)
         ba = BinaryQuantize().apply(ba)
-        if self.padding_mode == 'circular':
-            expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
-            return F.conv2d(F.pad(ba, expanded_padding, mode='circular'),
-                            bw, self.bias, self.stride,
-                            _single(0), self.dilation, self.groups)
-        return F.conv2d(ba, bw, self.bias, self.stride,
-                        self.padding, self.dilation, self.groups)
+        if self.padding_mode == "circular":
+            expanded_padding = (
+                (self.padding[0] + 1) // 2,
+                self.padding[0] // 2,
+            )
+            return F.conv2d(
+                F.pad(ba, expanded_padding, mode="circular"),
+                bw,
+                self.bias,
+                self.stride,
+                _single(0),
+                self.dilation,
+                self.groups,
+            )
+        return F.conv2d(
+            ba,
+            bw,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
 def square_distance(src, dst):
-    '''
+    """
     Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
-    '''
+    """
     B, N, _ = src.shape
     _, M, _ = dst.shape
     dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
-    dist += torch.sum(src ** 2, -1).view(B, N, 1)
-    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
+    dist += torch.sum(src**2, -1).view(B, N, 1)
+    dist += torch.sum(dst**2, -1).view(B, 1, M)
     return dist
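square_distance uses the expansion |s - d|^2 = |s|^2 + |d|^2 - 2 s.d, so its output should match squared pairwise distances from torch.cdist. A quick sanity sketch (not part of the commit):

import torch

src, dst = torch.randn(2, 5, 3), torch.randn(2, 7, 3)  # (B, N, C), (B, M, C)
ref = torch.cdist(src, dst) ** 2                       # (B, N, M)
assert torch.allclose(square_distance(src, dst), ref, atol=1e-5)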
 def index_points(points, idx):
-    '''
+    """
     Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
-    '''
+    """
     device = points.device
     B = points.shape[0]
     view_shape = list(idx.shape)
     view_shape[1:] = [1] * (len(view_shape) - 1)
     repeat_shape = list(idx.shape)
     repeat_shape[0] = 1
-    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
+    batch_indices = (
+        torch.arange(B, dtype=torch.long)
+        .to(device)
+        .view(view_shape)
+        .repeat(repeat_shape)
+    )
     new_points = points[batch_indices, idx, :]
     return new_points
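index_points is a batched gather: new_points[b, s] == points[b, idx[b, s]]. A toy check (illustrative, not part of the commit):

import torch

points = torch.arange(24.0).view(2, 4, 3)  # (B=2, N=4, C=3)
idx = torch.tensor([[0, 2], [3, 1]])       # (B, S=2) indices per batch
out = index_points(points, idx)            # (B, S, C)
assert torch.equal(out[1, 0], points[1, 3])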
 class FixedRadiusNearNeighbors(nn.Module):
-    '''
+    """
     Ball Query - Find the neighbors with-in a fixed radius
-    '''
+    """

     def __init__(self, radius, n_neighbor):
         super(FixedRadiusNearNeighbors, self).__init__()
         self.radius = radius
         self.n_neighbor = n_neighbor

     def forward(self, pos, centroids):
-        '''
+        """
         Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
-        '''
+        """
         device = pos.device
         B, N, _ = pos.shape
         center_pos = index_points(pos, centroids)
         _, S, _ = center_pos.shape
-        group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
+        group_idx = (
+            torch.arange(N, dtype=torch.long)
+            .to(device)
+            .view(1, 1, N)
+            .repeat([B, S, 1])
+        )
         sqrdists = square_distance(center_pos, pos)
-        group_idx[sqrdists > self.radius ** 2] = N
-        group_idx = group_idx.sort(dim=-1)[0][:, :, :self.n_neighbor]
-        group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])
+        group_idx[sqrdists > self.radius**2] = N
+        group_idx = group_idx.sort(dim=-1)[0][:, :, : self.n_neighbor]
+        group_first = (
+            group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])
+        )
         mask = group_idx == N
         group_idx[mask] = group_first[mask]
         return group_idx
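Typical usage of the ball query above, as a sketch (shapes and values are illustrative): for each sampled centroid it returns the indices of up to n_neighbor points within the radius, padding with the first hit when fewer are found.

import torch

pos = torch.rand(1, 16, 3)          # (B, N, 3) point cloud
centroids = torch.tensor([[0, 5]])  # (B, S) sampled point indices
frnn = FixedRadiusNearNeighbors(radius=0.5, n_neighbor=4)
neigh_idx = frnn(pos, centroids)    # (B, S, 4) neighbor indices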
 class FixedRadiusNNGraph(nn.Module):
-    '''
+    """
     Build NN graph
-    '''
+    """

     def __init__(self, radius, n_neighbor):
         super(FixedRadiusNNGraph, self).__init__()
         self.radius = radius
@@ -179,31 +237,32 @@ class FixedRadiusNNGraph(nn.Module):
             unified = torch.cat([src, dst])
             uniq, inv_idx = torch.unique(unified, return_inverse=True)
-            src_idx = inv_idx[:src.shape[0]]
-            dst_idx = inv_idx[src.shape[0]:]
+            src_idx = inv_idx[: src.shape[0]]
+            dst_idx = inv_idx[src.shape[0] :]

             g = dgl.graph((src_idx, dst_idx))
-            g.ndata['pos'] = pos[i][uniq]
-            g.ndata['center'] = center[uniq]
+            g.ndata["pos"] = pos[i][uniq]
+            g.ndata["center"] = center[uniq]
             if feat is not None:
-                g.ndata['feat'] = feat[i][uniq]
+                g.ndata["feat"] = feat[i][uniq]
             glist.append(g)
         bg = dgl.batch(glist)
         return bg

 class RelativePositionMessage(nn.Module):
-    '''
+    """
     Compute the input feature from neighbors
-    '''
+    """

     def __init__(self, n_neighbor):
         super(RelativePositionMessage, self).__init__()
         self.n_neighbor = n_neighbor

     def forward(self, edges):
-        pos = edges.src['pos'] - edges.dst['pos']
-        if 'feat' in edges.src:
-            res = torch.cat([pos, edges.src['feat']], 1)
+        pos = edges.src["pos"] - edges.dst["pos"]
+        if "feat" in edges.src:
+            res = torch.cat([pos, edges.src["feat"]], 1)
         else:
             res = pos
-        return {'agg_feat': res}
+        return {"agg_feat": res}