Unverified Commit 704bcaf6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files
parent 6bc82161
......@@ -17,12 +17,22 @@ from torch.utils.checkpoint import checkpoint
class MWEConv(nn.Module):
def __init__(self, in_feats, out_feats, activation, bias=True, num_channels=8, aggr_mode="sum"):
def __init__(
self,
in_feats,
out_feats,
activation,
bias=True,
num_channels=8,
aggr_mode="sum",
):
super(MWEConv, self).__init__()
self.num_channels = num_channels
self._in_feats = in_feats
self._out_feats = out_feats
self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats, num_channels))
self.weight = nn.Parameter(
torch.Tensor(in_feats, out_feats, num_channels)
)
if bias:
self.bias = nn.Parameter(torch.Tensor(out_feats, num_channels))
......@@ -59,11 +69,14 @@ class MWEConv(nn.Module):
for c in range(self.num_channels):
node_state_c = node_state
if self._out_feats < self._in_feats:
g.ndata["feat_" + str(c)] = torch.mm(node_state_c, self.weight[:, :, c])
g.ndata["feat_" + str(c)] = torch.mm(
node_state_c, self.weight[:, :, c]
)
else:
g.ndata["feat_" + str(c)] = node_state_c
g.update_all(
fn.u_mul_e("feat_" + str(c), "feat_" + str(c), "m"), fn.sum("m", "feat_" + str(c) + "_new")
fn.u_mul_e("feat_" + str(c), "feat_" + str(c), "m"),
fn.sum("m", "feat_" + str(c) + "_new"),
)
node_state_c = g.ndata.pop("feat_" + str(c) + "_new")
if self._out_feats >= self._in_feats:
......@@ -83,15 +96,36 @@ class MWEConv(nn.Module):
class MWE_GCN(nn.Module):
def __init__(self, n_input, n_hidden, n_output, n_layers, activation, dropout, aggr_mode="sum", device="cpu"):
def __init__(
self,
n_input,
n_hidden,
n_output,
n_layers,
activation,
dropout,
aggr_mode="sum",
device="cpu",
):
super(MWE_GCN, self).__init__()
self.dropout = dropout
self.activation = activation
self.layers = nn.ModuleList()
self.layers.append(MWEConv(n_input, n_hidden, activation=activation, aggr_mode=aggr_mode))
self.layers.append(
MWEConv(
n_input, n_hidden, activation=activation, aggr_mode=aggr_mode
)
)
for i in range(n_layers - 1):
self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation, aggr_mode=aggr_mode))
self.layers.append(
MWEConv(
n_hidden,
n_hidden,
activation=activation,
aggr_mode=aggr_mode,
)
)
self.pred_out = nn.Linear(n_hidden, n_output)
self.device = device
......@@ -100,7 +134,9 @@ class MWE_GCN(nn.Module):
node_state = torch.ones(g.number_of_nodes(), 1).float().to(self.device)
for layer in self.layers:
node_state = F.dropout(node_state, p=self.dropout, training=self.training)
node_state = F.dropout(
node_state, p=self.dropout, training=self.training
)
node_state = layer(g, node_state)
node_state = self.activation(node_state)
......@@ -110,7 +146,16 @@ class MWE_GCN(nn.Module):
class MWE_DGCN(nn.Module):
def __init__(
self, n_input, n_hidden, n_output, n_layers, activation, dropout, residual=False, aggr_mode="sum", device="cpu"
self,
n_input,
n_hidden,
n_output,
n_layers,
activation,
dropout,
residual=False,
aggr_mode="sum",
device="cpu",
):
super(MWE_DGCN, self).__init__()
self.n_layers = n_layers
......@@ -121,13 +166,26 @@ class MWE_DGCN(nn.Module):
self.layers = nn.ModuleList()
self.layer_norms = nn.ModuleList()
self.layers.append(MWEConv(n_input, n_hidden, activation=activation, aggr_mode=aggr_mode))
self.layers.append(
MWEConv(
n_input, n_hidden, activation=activation, aggr_mode=aggr_mode
)
)
for i in range(n_layers - 1):
self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation, aggr_mode=aggr_mode))
self.layers.append(
MWEConv(
n_hidden,
n_hidden,
activation=activation,
aggr_mode=aggr_mode,
)
)
for i in range(n_layers):
self.layer_norms.append(nn.LayerNorm(n_hidden, elementwise_affine=True))
self.layer_norms.append(
nn.LayerNorm(n_hidden, elementwise_affine=True)
)
self.pred_out = nn.Linear(n_hidden, n_output)
self.device = device
......@@ -140,7 +198,9 @@ class MWE_DGCN(nn.Module):
for layer in range(1, self.n_layers):
node_state_new = self.layer_norms[layer - 1](node_state)
node_state_new = self.activation(node_state_new)
node_state_new = F.dropout(node_state_new, p=self.dropout, training=self.training)
node_state_new = F.dropout(
node_state_new, p=self.dropout, training=self.training
)
if self.residual == "true":
node_state = node_state + self.layers[layer](g, node_state_new)
......@@ -149,7 +209,9 @@ class MWE_DGCN(nn.Module):
node_state = self.layer_norms[self.n_layers - 1](node_state)
node_state = self.activation(node_state)
node_state = F.dropout(node_state, p=self.dropout, training=self.training)
node_state = F.dropout(
node_state, p=self.dropout, training=self.training
)
out = self.pred_out(node_state)
......@@ -180,7 +242,9 @@ class GATConv(nn.Module):
self._use_symmetric_norm = use_symmetric_norm
# feat fc
self.src_fc = nn.Linear(self._in_src_feats, out_feats * n_heads, bias=False)
self.src_fc = nn.Linear(
self._in_src_feats, out_feats * n_heads, bias=False
)
if residual:
self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads)
self.bias = None
......@@ -191,7 +255,9 @@ class GATConv(nn.Module):
# attn fc
self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
if use_attn_dst:
self.attn_dst_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
self.attn_dst_fc = nn.Linear(
self._in_src_feats, n_heads, bias=False
)
else:
self.attn_dst_fc = None
if edge_feats > 0:
......@@ -243,8 +309,12 @@ class GATConv(nn.Module):
norm = torch.reshape(norm, shp)
feat_src = feat_src * norm
feat_src_fc = self.src_fc(feat_src).view(-1, self._n_heads, self._out_feats)
feat_dst_fc = self.dst_fc(feat_dst).view(-1, self._n_heads, self._out_feats)
feat_src_fc = self.src_fc(feat_src).view(
-1, self._n_heads, self._out_feats
)
feat_dst_fc = self.dst_fc(feat_dst).view(
-1, self._n_heads, self._out_feats
)
attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1)
# NOTE: GAT paper uses "first concatenation then linear projection"
......@@ -257,18 +327,24 @@ class GATConv(nn.Module):
# save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint.
graph.srcdata.update({"feat_src_fc": feat_src_fc, "attn_src": attn_src})
graph.srcdata.update(
{"feat_src_fc": feat_src_fc, "attn_src": attn_src}
)
if self.attn_dst_fc is not None:
attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1)
graph.dstdata.update({"attn_dst": attn_dst})
graph.apply_edges(fn.u_add_v("attn_src", "attn_dst", "attn_node"))
graph.apply_edges(
fn.u_add_v("attn_src", "attn_dst", "attn_node")
)
else:
graph.apply_edges(fn.copy_u("attn_src", "attn_node"))
e = graph.edata["attn_node"]
if feat_edge is not None:
attn_edge = self.attn_edge_fc(feat_edge).view(-1, self._n_heads, 1)
attn_edge = self.attn_edge_fc(feat_edge).view(
-1, self._n_heads, 1
)
graph.edata.update({"attn_edge": attn_edge})
e += graph.edata["attn_edge"]
e = self.leaky_relu(e)
......@@ -278,12 +354,16 @@ class GATConv(nn.Module):
bound = int(graph.number_of_edges() * self.edge_drop)
eids = perm[bound:]
graph.edata["a"] = torch.zeros_like(e)
graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids))
graph.edata["a"][eids] = self.attn_drop(
edge_softmax(graph, e[eids], eids=eids)
)
else:
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing
graph.update_all(fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc"))
graph.update_all(
fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc")
)
rst = graph.dstdata["feat_src_fc"]
......
......@@ -5,9 +5,14 @@ import random
import sys
import time
import dgl
import numpy as np
import torch
import torch.nn.functional as F
from dgl.dataloading import DataLoader, Sampler
from dgl.nn import GraphConv, SortPooling
from dgl.sampling import global_uniform_negative_sampling
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
from torch.nn import (
......@@ -20,11 +25,6 @@ from torch.nn import (
)
from tqdm import tqdm
import dgl
from dgl.dataloading import DataLoader, Sampler
from dgl.nn import GraphConv, SortPooling
from dgl.sampling import global_uniform_negative_sampling
class Logger(object):
def __init__(self, runs, info=None):
......
import dgl
import dgl.function as fn
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
import dgl
import dgl.function as fn
def get_ogb_evaluator(dataset):
"""
......
import argparse
import time
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
from dataset import load_dataset
import dgl
import dgl.function as fn
class FeedForwardNet(nn.Module):
def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):
......
import argparse
import os
import dgl
import dgl.function as fn
import numpy as np
import ogb
import torch
import tqdm
from ogb.lsc import MAG240MDataset
import dgl
import dgl.function as fn
parser = argparse.ArgumentParser()
parser.add_argument(
"--rootdir",
......
......@@ -4,6 +4,10 @@
import argparse
import time
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import numpy as np
import ogb
import torch
......@@ -12,10 +16,6 @@ import torch.nn.functional as F
import tqdm
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
import dgl
import dgl.function as fn
import dgl.nn as dglnn
class RGAT(nn.Module):
def __init__(
......
......@@ -5,6 +5,9 @@ import math
import sys
from collections import OrderedDict
import dgl
import dgl.nn as dglnn
import numpy as np
import torch
import torch.multiprocessing as mp
......@@ -14,9 +17,6 @@ import tqdm
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
from torch.nn.parallel import DistributedDataParallel
import dgl
import dgl.nn as dglnn
class RGAT(nn.Module):
def __init__(
......
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
import dgl
import dgl.function as fn
from dgl.nn.pytorch import SumPooling
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
### GIN convolution along the graph structure
......@@ -128,7 +127,6 @@ class GNN_node(nn.Module):
### computing input node embedding
h_list = [self.atom_encoder(x)]
for layer in range(self.num_layers):
h = self.convs[layer](g, h_list[layer], edge_attr)
h = self.batch_norms[layer](h)
......
......@@ -2,6 +2,8 @@ import argparse
import os
import random
import dgl
import numpy as np
import torch
import torch.optim as optim
......@@ -12,8 +14,6 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import dgl
reg_criterion = torch.nn.L1Loss()
......
......@@ -2,6 +2,8 @@ import argparse
import os
import random
import dgl
import numpy as np
import torch
from gnn import GNN
......@@ -10,8 +12,6 @@ from ogb.utils import smiles2graph
from torch.utils.data import DataLoader
from tqdm import tqdm
import dgl
def collate_dgl(graphs):
batched_graph = dgl.batch(graphs)
......
import networkx as nx
import torch
import dgl
import dgl.function as fn
import networkx as nx
import torch
N = 100
g = nx.erdos_renyi_graph(N, 0.05)
......@@ -10,15 +10,19 @@ g = dgl.DGLGraph(g)
DAMP = 0.85
K = 10
def compute_pagerank(g):
g.ndata['pv'] = torch.ones(N) / N
g.ndata["pv"] = torch.ones(N) / N
degrees = g.out_degrees(g.nodes()).type(torch.float32)
for k in range(K):
g.ndata['pv'] = g.ndata['pv'] / degrees
g.update_all(message_func=fn.copy_u(u='pv', out='m'),
reduce_func=fn.sum(msg='m', out='pv'))
g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']
return g.ndata['pv']
g.ndata["pv"] = g.ndata["pv"] / degrees
g.update_all(
message_func=fn.copy_u(u="pv", out="m"),
reduce_func=fn.sum(msg="m", out="pv"),
)
g.ndata["pv"] = (1 - DAMP) / N + DAMP * g.ndata["pv"]
return g.ndata["pv"]
pv = compute_pagerank(g)
print(pv)
"""Graph builder from pandas dataframes"""
from collections import namedtuple
import dgl
from pandas.api.types import (
is_categorical,
is_categorical_dtype,
is_numeric_dtype,
)
import dgl
__all__ = ["PandasGraphBuilder"]
......
import dask.dataframe as dd
import dgl
import numpy as np
import scipy.sparse as ssp
import torch
import tqdm
import dgl
# This is the train-test split method most of the recommender system papers running on MovieLens
# takes. It essentially follows the intuition of "training on the past and predict the future".
......
import argparse
import pickle
import dgl
import numpy as np
import torch
import dgl
def prec(recommendations, ground_truth):
n_users, n_items = ground_truth.shape
......
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
def disable_grad(module):
......
......@@ -2,6 +2,8 @@ import argparse
import os
import pickle
import dgl
import evaluation
import layers
import numpy as np
......@@ -14,8 +16,6 @@ from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import dgl
class PinSAGEModel(nn.Module):
def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
......
......@@ -2,6 +2,8 @@ import argparse
import os
import pickle
import dgl
import evaluation
import layers
import numpy as np
......@@ -14,8 +16,6 @@ from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import dgl
class PinSAGEModel(nn.Module):
def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
......
import dgl
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset
from torchtext.data.functional import numericalize_tokens_from_iterator
import dgl
def padding(array, yy, val):
"""
......
import numpy as np
import warnings
import os
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore")
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
......@@ -11,6 +14,7 @@ def pc_normalize(pc):
pc = pc / m
return pc
def farthest_point_sample(point, npoint):
"""
Farthest point sampler works as follows:
......@@ -25,7 +29,7 @@ def farthest_point_sample(point, npoint):
centroids: sampled pointcloud index, [npoint, D]
"""
N, D = point.shape
xyz = point[:,:3]
xyz = point[:, :3]
centroids = np.zeros((npoint,))
distance = np.ones((N,)) * 1e10
farthest = np.random.randint(0, N)
......@@ -41,11 +45,18 @@ def farthest_point_sample(point, npoint):
class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', fps=False,
normal_channel=True, cache_size=15000):
def __init__(
self,
root,
npoint=1024,
split="train",
fps=False,
normal_channel=True,
cache_size=15000,
):
"""
Input:
root: the root path to the local data files
root: the root path to the local data files
npoint: number of points from each cloud
split: which split of the data, 'train' or 'test'
fps: whether to sample points with farthest point sampler
......@@ -55,22 +66,34 @@ class ModelNetDataLoader(Dataset):
self.root = root
self.npoints = npoint
self.fps = fps
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel
shape_ids = {}
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
shape_ids["train"] = [
line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_train.txt"))
]
shape_ids["test"] = [
line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_test.txt"))
]
assert (split == 'train' or split == 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
assert split == "train" or split == "test"
shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d'%(split,len(self.datapath)))
self.datapath = [
(
shape_names[i],
os.path.join(self.root, shape_names[i], shape_ids[split][i])
+ ".txt",
)
for i in range(len(shape_ids[split]))
]
print("The size of %s data is %d" % (split, len(self.datapath)))
self.cache_size = cache_size
self.cache = {}
......@@ -85,11 +108,11 @@ class ModelNetDataLoader(Dataset):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
if self.fps:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints,:]
point_set = point_set[0 : self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
......
import dgl
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.modules.utils import _single
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn import Parameter
import dgl
from torch.nn.modules.utils import _single
class BinaryQuantize(Function):
......@@ -34,16 +34,27 @@ class BiLinearLSR(torch.nn.Linear):
# hence, init scale with Parameter
# however, Parameter(None) actually has size [0], not [] as a scalar
# hence, init it using the following trick
self.register_parameter('scale', Parameter(torch.Tensor([0.0]).squeeze()))
self.register_parameter(
"scale", Parameter(torch.Tensor([0.0]).squeeze())
)
def reset_scale(self, input):
bw = self.weight
ba = input
bw = bw - bw.mean()
self.scale = Parameter((F.linear(ba, bw).std() / F.linear(torch.sign(ba), torch.sign(bw)).std()).float().to(ba.device))
self.scale = Parameter(
(
F.linear(ba, bw).std()
/ F.linear(torch.sign(ba), torch.sign(bw)).std()
)
.float()
.to(ba.device)
)
# corner case when ba is all 0.0
if torch.isnan(self.scale):
self.scale = Parameter((bw.std() / torch.sign(bw).std()).float().to(ba.device))
self.scale = Parameter(
(bw.std() / torch.sign(bw).std()).float().to(ba.device)
)
def forward(self, input):
bw = self.weight
......@@ -79,87 +90,134 @@ class BiLinear(torch.nn.Linear):
class BiConv2d(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros'):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
):
super(BiConv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, bias, padding_mode)
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
)
def forward(self, input):
bw = self.weight
ba = input
bw = bw - bw.mean()
bw = BinaryQuantize().apply(bw)
ba = BinaryQuantize().apply(ba)
if self.padding_mode == 'circular':
expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
return F.conv2d(F.pad(ba, expanded_padding, mode='circular'),
bw, self.bias, self.stride,
_single(0), self.dilation, self.groups)
return F.conv2d(ba, bw, self.bias, self.stride,
self.padding, self.dilation, self.groups)
if self.padding_mode == "circular":
expanded_padding = (
(self.padding[0] + 1) // 2,
self.padding[0] // 2,
)
return F.conv2d(
F.pad(ba, expanded_padding, mode="circular"),
bw,
self.bias,
self.stride,
_single(0),
self.dilation,
self.groups,
)
return F.conv2d(
ba,
bw,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
def square_distance(src, dst):
'''
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist
def index_points(points, idx):
'''
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
batch_indices = (
torch.arange(B, dtype=torch.long)
.to(device)
.view(view_shape)
.repeat(repeat_shape)
)
new_points = points[batch_indices, idx, :]
return new_points
class FixedRadiusNearNeighbors(nn.Module):
'''
"""
Ball Query - Find the neighbors with-in a fixed radius
'''
"""
def __init__(self, radius, n_neighbor):
super(FixedRadiusNearNeighbors, self).__init__()
self.radius = radius
self.n_neighbor = n_neighbor
def forward(self, pos, centroids):
'''
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
"""
device = pos.device
B, N, _ = pos.shape
center_pos = index_points(pos, centroids)
_, S, _ = center_pos.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
group_idx = (
torch.arange(N, dtype=torch.long)
.to(device)
.view(1, 1, N)
.repeat([B, S, 1])
)
sqrdists = square_distance(center_pos, pos)
group_idx[sqrdists > self.radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :self.n_neighbor]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])
group_idx[sqrdists > self.radius**2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, : self.n_neighbor]
group_first = (
group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])
)
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
class FixedRadiusNNGraph(nn.Module):
'''
"""
Build NN graph
'''
"""
def __init__(self, radius, n_neighbor):
super(FixedRadiusNNGraph, self).__init__()
self.radius = radius
......@@ -179,31 +237,32 @@ class FixedRadiusNNGraph(nn.Module):
unified = torch.cat([src, dst])
uniq, inv_idx = torch.unique(unified, return_inverse=True)
src_idx = inv_idx[:src.shape[0]]
dst_idx = inv_idx[src.shape[0]:]
src_idx = inv_idx[: src.shape[0]]
dst_idx = inv_idx[src.shape[0] :]
g = dgl.graph((src_idx, dst_idx))
g.ndata['pos'] = pos[i][uniq]
g.ndata['center'] = center[uniq]
g.ndata["pos"] = pos[i][uniq]
g.ndata["center"] = center[uniq]
if feat is not None:
g.ndata['feat'] = feat[i][uniq]
g.ndata["feat"] = feat[i][uniq]
glist.append(g)
bg = dgl.batch(glist)
return bg
class RelativePositionMessage(nn.Module):
'''
"""
Compute the input feature from neighbors
'''
"""
def __init__(self, n_neighbor):
super(RelativePositionMessage, self).__init__()
self.n_neighbor = n_neighbor
def forward(self, edges):
pos = edges.src['pos'] - edges.dst['pos']
if 'feat' in edges.src:
res = torch.cat([pos, edges.src['feat']], 1)
pos = edges.src["pos"] - edges.dst["pos"]
if "feat" in edges.src:
res = torch.cat([pos, edges.src["feat"]], 1)
else:
res = pos
return {'agg_feat': res}
return {"agg_feat": res}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment