Commit 62815576 authored by rusty1s

moved extensions to torch.ops

parent 0a221ab8
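For context on the change: PyTorch lets a compiled extension register custom operators under a namespace (in C++, for instance via torch::RegisterOperators), after which they become callable from Python as torch.ops.<namespace>.<op>. A minimal sketch of that call path is below; the library path is illustrative and not taken from this commit, while the op signature mirrors the calls in the diff.

import torch

# Illustrative path only; in practice the package loads its own compiled library.
torch.ops.load_library("build/torch_scatter_cpu.so")

src = torch.randn(6, 4)
index = torch.tensor([0, 0, 1, 1, 2, 2])
out = torch.zeros(3, 4)

# Registered ops are reachable through the torch.ops namespace; the
# signature (src, index, out, reduce) mirrors the calls in the diff below.
out, arg_out = torch.ops.torch_scatter_cpu.segment_coo(src, index, out, 'sum')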
 import torch
-from torch_scatter import segment_cpu, gather_cpu
 from torch_scatter.helpers import min_value, max_value
 
-if torch.cuda.is_available():
-    from torch_scatter import segment_cuda, gather_cuda
-
-
-def seg(is_cuda):
-    return segment_cuda if is_cuda else segment_cpu
-
-
-def gat(is_cuda):
-    return gather_cuda if is_cuda else gather_cpu
 
 
 class SegmentCOO(torch.autograd.Function):
     @staticmethod
@@ -37,7 +24,12 @@ class SegmentCOO(torch.autograd.Function):
         out = src.new_full(size, fill_value)
 
-        out, arg_out = seg(src.is_cuda).segment_coo(src, index, out, reduce)
+        if src.is_cuda:
+            out, arg_out = torch.ops.torch_scatter_cuda.segment_coo(
+                src, index, out, reduce)
+        else:
+            out, arg_out = torch.ops.torch_scatter_cpu.segment_coo(
+                src, index, out, reduce)
 
         if fill_value != 0:
             out.masked_fill_(out == fill_value, 0)
@@ -56,25 +48,39 @@ class SegmentCOO(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'sum' or ctx.reduce == 'add':
-                grad_src = gat(grad_out.is_cuda).gather_coo(
-                    grad_out, index, grad_out.new_empty(src_size))
+                if grad_out.is_cuda:
+                    grad_src = torch.ops.torch_scatter_cuda.gather_coo(
+                        grad_out, index, grad_out.new_empty(src_size))
+                else:
+                    grad_src = torch.ops.torch_scatter_cpu.gather_coo(
+                        grad_out, index, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gat(grad_out.is_cuda).gather_coo(
-                    grad_out, index, grad_out.new_empty(src_size))
+                if grad_out.is_cuda:
+                    grad_src = torch.ops.torch_scatter_cuda.gather_coo(
+                        grad_out, index, grad_out.new_empty(src_size))
+                else:
+                    grad_src = torch.ops.torch_scatter_cpu.gather_coo(
+                        grad_out, index, grad_out.new_empty(src_size))
                 count = arg_out  # Gets pre-computed on GPU but not on CPU.
                 if count is None:
                     size = list(index.size())
                     size[-1] = grad_out.size(index.dim() - 1)
-                    count = segment_cpu.segment_coo(
+                    count = torch.ops.torch_scatter_cpu.segment_coo(
                         torch.ones_like(index, dtype=grad_out.dtype), index,
                         grad_out.new_zeros(size), 'sum')[0].clamp_(min=1)
-                count = gat(grad_out.is_cuda).gather_coo(
-                    count, index, count.new_empty(src_size[:index.dim()]))
+                if grad_out.is_cuda:
+                    count = torch.ops.torch_scatter_cuda.gather_coo(
+                        count, index, count.new_empty(src_size[:index.dim()]))
+                else:
+                    count = torch.ops.torch_scatter_cpu.gather_coo(
+                        count, index, count.new_empty(src_size[:index.dim()]))
                 for _ in range(grad_out.dim() - index.dim()):
                     count = count.unsqueeze(-1)
                 grad_src.div_(count)
             elif ctx.reduce == 'min' or ctx.reduce == 'max':
                 src_size[index.dim() - 1] += 1
                 grad_src = grad_out.new_zeros(src_size).scatter_(
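Aside on the 'mean' branch above: the gradient of a segment mean is the upstream gradient divided by the segment's element count, which is why count is gathered back to the source layout before grad_src.div_(count). A toy pure-PyTorch illustration of that normalization, using scatter_add_ in place of the extension ops:

import torch

index = torch.tensor([0, 0, 1, 2, 2, 2])
grad_out = torch.tensor([3.0, 6.0, 12.0])  # one gradient value per segment

# Per-segment element counts, clamped as in the backward pass above.
count = torch.zeros(3).scatter_add_(0, index, torch.ones(6)).clamp_(min=1)

# Gather both back to the source layout and normalize.
grad_src = grad_out.gather(0, index) / count.gather(0, index)
print(grad_src)  # tensor([1.5000, 1.5000, 6.0000, 4.0000, 4.0000, 4.0000])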
@@ -95,7 +101,13 @@ class SegmentCSR(torch.autograd.Function):
         ctx.reduce = reduce
         ctx.src_size = list(src.size())
 
-        out, arg_out = seg(src.is_cuda).segment_csr(src, indptr, out, reduce)
+        if src.is_cuda:
+            out, arg_out = torch.ops.torch_scatter_cuda.segment_csr(
+                src, indptr, out, reduce)
+        else:
+            out, arg_out = torch.ops.torch_scatter_cpu.segment_csr(
+                src, indptr, out, reduce)
 
         ctx.save_for_backward(indptr, arg_out)
         return out if arg_out is None else (out, arg_out)
@@ -106,16 +118,31 @@ class SegmentCSR(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'sum' or ctx.reduce == 'add':
-                grad_src = gat(grad_out.is_cuda).gather_csr(
-                    grad_out, indptr, grad_out.new_empty(src_size))
+                if grad_out.is_cuda:
+                    grad_src = torch.ops.torch_scatter_cuda.gather_csr(
+                        grad_out, indptr, grad_out.new_empty(src_size))
+                else:
+                    grad_src = torch.ops.torch_scatter_cpu.gather_csr(
+                        grad_out, indptr, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gat(grad_out.is_cuda).gather_csr(
-                    grad_out, indptr, grad_out.new_empty(src_size))
+                if grad_out.is_cuda:
+                    grad_src = torch.ops.torch_scatter_cuda.gather_csr(
+                        grad_out, indptr, grad_out.new_empty(src_size))
+                else:
+                    grad_src = torch.ops.torch_scatter_cpu.gather_csr(
+                        grad_out, indptr, grad_out.new_empty(src_size))
                 indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
                 indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
                 count = (indptr2 - indptr1).to(grad_src.dtype)
-                count = gat(grad_out.is_cuda).gather_csr(
-                    count, indptr, count.new_empty(src_size[:indptr.dim()]))
+                if grad_out.is_cuda:
+                    count = torch.ops.torch_scatter_cuda.gather_csr(
+                        count, indptr,
+                        count.new_empty(src_size[:indptr.dim()]))
+                else:
+                    count = torch.ops.torch_scatter_cpu.gather_csr(
+                        count, indptr,
+                        count.new_empty(src_size[:indptr.dim()]))
                 for _ in range(grad_out.dim() - indptr.dim()):
                     count = count.unsqueeze(-1)
                 grad_src.div_(count)
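Note on the CSR counts above: segment sizes fall directly out of the index pointer, since consecutive entries delimit each segment. A small worked example of the narrow trick:

import torch

# indptr [0, 2, 3, 6] delimits segments [0:2], [2:3] and [3:6].
indptr = torch.tensor([0, 2, 3, 6])
indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)  # tensor([0, 2, 3])
indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)  # tensor([2, 3, 6])
print(indptr2 - indptr1)  # tensor([2, 1, 3]) -- the segment lengths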
...
-import torch
-import torch_scatter.scatter_cpu
-
-if torch.cuda.is_available():
-    import torch_scatter.scatter_cuda
-
-
-def get_func(name, tensor):
-    if tensor.is_cuda:
-        module = torch_scatter.scatter_cuda
-    else:
-        module = torch_scatter.scatter_cpu
-    return getattr(module, name)
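The removed helper above dispatched by looking the function up on the CPU or CUDA extension module at call time. Under torch.ops an equivalent device-based lookup could be written against the op namespaces instead; a sketch under that assumption, not code from this commit:

import torch

def get_op(name, tensor):
    # Resolve an operator by name from the device-appropriate namespace.
    ns = (torch.ops.torch_scatter_cuda if tensor.is_cuda
          else torch.ops.torch_scatter_cpu)
    return getattr(ns, name)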