"git@developer.sourcefind.cn:change/sglang.git" did not exist on "d774acad5cef7a538da33d39207f9e2bc51474eb"
Commit 07804abc authored by rusty1s's avatar rusty1s
Browse files

python impl

parent 5c2b664b
import torch
def node_degree(index, out=None):
one = torch.ones(index.size(1), out)
zero = torch.zeros(index.size(1), out)
return zero.scatter_add_(0, index[0], one)
import torch
# from torch.autograd import Variable as Var
from .degree import node_degree
from .utils import spline_bases, spline_weighting
def spline_conv(x,
index,
pseudo,
weight,
kernel_size,
is_open_spline,
root_weight=None,
degree=1,
bias=None):
x = x.unsqueeze(-1) if x.dim() == 1 else x
# Get features for every target node => |E| x M_in
output = x[index[1]]
# Get B-spline basis products and weight indices for each edge.
basis, weight_index = spline_bases(pseudo, kernel_size, is_open_spline,
degree)
# Weight gathered features based on B-spline basis and trainable weights.
output = spline_weighting(output, weight, basis, weight_index)
# Perform the real convolution => Convert |E| x M_out to N x M_out output.
row = index[0].unsqueeze(-1).expand(-1, output.size(1))
# zero = x if torch.is_tensor(x) else x.data
zero = x.new(row.size()).fill_(0)
# row, zero = row, zero if torch.is_tensor(x) else Var(row), Var(zero)
output = zero.scatter_add_(0, row, output)
# Normalize output by node degree.
output /= node_degree(index, out=x.new()).unsqueeze(-1).clamp_(min=1)
# Weight root node separately (if wished).
if root_weight is not None:
output += torch.mm(x, root_weight)
# Add bias (if wished).
if bias is not None:
output += bias
return output
import torch
from torch.autograd import Function
from .._ext import ffi
def get_func(name, tensor):
typename = type(tensor).__name__.replace('Tensor', '')
cuda = 'cuda_' if tensor.is_cuda else ''
func = getattr(ffi, 'spline_{}_{}{}'.format(name, cuda, typename))
return func
def spline_bases(pseudo, kernel_size, is_open_spline, degree):
# raise NotImplementedError for degree > 3
pass
def spline_weighting_forward(x, weight, basis, weight_index):
pass
def spline_weighting_backward(x, weight, basis, weight_index):
pass
class SplineWeighting(Function):
def __init__(self, basis, weight_index):
super(SplineWeighting, self).__init__()
self.basis = basis
self.weight_index = weight_index
def forward(self, x, weight):
pass
def backward(self, grad_output):
pass
def spline_weighting(x, weight, basis, weight_index):
if torch.is_tensor(x):
return spline_weighting_forward(x, weight, basis, weight_index)
else:
return SplineWeighting(basis, weight_index)(x, weight)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment