"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "ed90a1b21de6f0682e1cfb370434b5df4f6fbeaf"
Commit 2f8f1501 authored by rusty1s's avatar rusty1s
Browse files

ignore long edges in spline convolution

parent 2cde9023
...@@ -30,8 +30,9 @@ const long* kernel_size, const long* is_open_spline, int num_threads) { ...@@ -30,8 +30,9 @@ const long* kernel_size, const long* is_open_spline, int num_threads) {
k_idx_mod = k_idx % 2; k_idx_mod = k_idx % 2;
k_idx >>= 1; k_idx >>= 1;
value = input[e_idx * ${dim} + d_idx] * value = input[e_idx * ${dim} + d_idx];
(kernel_size[d_idx] - is_open_spline[d_idx]); if (value >= 1) { a = 0; i = 0; break; }
value *= kernel_size[d_idx] - is_open_spline[d_idx];
frac = value - floor(value); frac = value - floor(value);
...@@ -170,22 +171,22 @@ def compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel): ...@@ -170,22 +171,22 @@ def compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel):
input = input.unsqueeze(1) if len(input.size()) < 2 else input input = input.unsqueeze(1) if len(input.size()) < 2 else input
num_edges, dim = input.size() num_edges, dim = input.size()
k_max = 2 ** dim k_max = 2**dim
amount = input.new(num_edges, k_max) amount = input.new(num_edges, k_max)
index = input.new(num_edges, k_max).long() index = input.new(num_edges, k_max).long()
num_threads = amount.numel() num_threads = amount.numel()
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
basis_kernel(block=(cuda_num_threads, 1, 1), basis_kernel(
block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1), grid=(get_blocks(num_threads), 1, 1),
args=[ args=[
input.data_ptr(), input.data_ptr(),
amount.data_ptr(), amount.data_ptr(),
index.data_ptr(), index.data_ptr(),
kernel_size.data_ptr(), kernel_size.data_ptr(),
is_open_spline.data_ptr(), is_open_spline.data_ptr(), num_threads
num_threads
], ],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
......
import torch import torch
from torch.autograd import Function from torch.autograd import Function
from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel, from ....utils.cuda import (cuda_num_threads, Stream, load_kernel, kernel_loop,
kernel_loop, get_blocks) get_blocks)
_edgewise_spline_weighting_forward_kernel = kernel_loop + ''' _edgewise_spline_weighting_forward_kernel = kernel_loop + '''
extern "C" extern "C"
...@@ -144,18 +144,17 @@ class EdgewiseSplineWeightingGPU(Function): ...@@ -144,18 +144,17 @@ class EdgewiseSplineWeightingGPU(Function):
num_threads = output.numel() num_threads = output.numel()
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
self.f_fw(block=(cuda_num_threads, 1, 1), self.f_fw(
block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1), grid=(get_blocks(num_threads), 1, 1),
args=[ args=[
input.data_ptr(), input.data_ptr(),
weight.data_ptr(), weight.data_ptr(),
output.data_ptr(), output.data_ptr(),
self.amount.data_ptr(), self.amount.data_ptr(),
self.index.data_ptr(), self.index.data_ptr(), num_threads
num_threads
], ],
stream=Stream( stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
ptr=torch.cuda.current_stream().cuda_stream))
return output return output
...@@ -168,7 +167,8 @@ class EdgewiseSplineWeightingGPU(Function): ...@@ -168,7 +167,8 @@ class EdgewiseSplineWeightingGPU(Function):
num_threads = grad_output.numel() num_threads = grad_output.numel()
with torch.cuda.device_of(grad_output): with torch.cuda.device_of(grad_output):
self.f_bw(block=(cuda_num_threads, 1, 1), self.f_bw(
block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1), grid=(get_blocks(num_threads), 1, 1),
args=[ args=[
grad_output.data_ptr(), grad_output.data_ptr(),
...@@ -177,10 +177,8 @@ class EdgewiseSplineWeightingGPU(Function): ...@@ -177,10 +177,8 @@ class EdgewiseSplineWeightingGPU(Function):
input.data_ptr(), input.data_ptr(),
weight.data_ptr(), weight.data_ptr(),
self.amount.data_ptr(), self.amount.data_ptr(),
self.index.data_ptr(), self.index.data_ptr(), num_threads
num_threads
], ],
stream=Stream( stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
ptr=torch.cuda.current_stream().cuda_stream))
return grad_input, grad_weight return grad_input, grad_weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment