Commit 92d409f8 authored by rusty1s's avatar rusty1s
Browse files

fill value

parent bed12976
...@@ -13,16 +13,34 @@ devices = [torch.device('cuda')] ...@@ -13,16 +13,34 @@ devices = [torch.device('cuda')]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available') @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device): def test_forward(dtype, device):
src = tensor([1, 2, 3, 4], dtype, device)
src = tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype, device) src = tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype, device)
src = tensor([1, 2, 3, 4], dtype, device)
src.requires_grad_()
indptr = tensor([0, 2, 5, 5, 6], torch.long, device) indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
out = gather_csr(src, indptr)
print('CSR', out)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device) index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out = gather_coo(src, index)
print('COO', out)
out = src.index_select(0, index) out = src.index_select(0, index)
print('Expected', out) grad_out = torch.randn_like(out)
out.backward(grad_out)
print('EXPECTED')
print(out)
print(src.grad)
src.grad = None
out = gather_csr(src, indptr)
out.backward(grad_out)
print('CSR')
print(out)
print(src.grad)
# print('CSR', out)
# out = gather_coo(src, index)
# print('COO', out)
# print('Expected', out)
src.grad = None
out = gather_coo(src, index)
out.backward(grad_out)
print('COO')
print(out)
print(src.grad)
...@@ -14,17 +14,14 @@ devices = [torch.device('cuda')] ...@@ -14,17 +14,14 @@ devices = [torch.device('cuda')]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available') @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device): def test_forward(dtype, device):
# src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype, src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
# device) device)
src = tensor([1, 2, 3, 4, 5, 6], dtype, device) src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
src.requires_grad_() src.requires_grad_()
indptr = tensor([0, 2, 5, 5, 6], torch.long, device) indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device) index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out = tensor([0, 0, 0, 0], dtype, device)
out.scatter_(0, index, src)
out = scatter_min(src, index, dim=0)[0] out = scatter_min(src, index, dim=0)[0]
grad_out = torch.randn_like(out) grad_out = torch.randn_like(out)
print(grad_out) print(grad_out)
......
import torch import torch
if torch.cuda.is_available(): if torch.cuda.is_available():
from torch_scatter import gather_cuda from torch_scatter import gather_cuda, segment_cuda
class GatherCOO(torch.autograd.Function):
    """Differentiable COO-index gather backed by the `gather_cuda` extension."""

    @staticmethod
    def forward(ctx, src, index, out):
        # A caller-supplied output buffer is modified in place, so autograd
        # must be told it is dirty.
        if out is not None:
            ctx.mark_dirty(out)
        ctx.src_size = list(src.size())
        ctx.save_for_backward(index)
        return gather_cuda.gather_coo(src, index, out)

    @staticmethod
    def backward(ctx, grad_out):
        if not ctx.needs_input_grad[0]:
            return None, None, None
        index, = ctx.saved_tensors
        # The gradient of a gather is a scatter-add back onto `src`;
        # the accumulation buffer therefore starts at zero.
        acc = grad_out.new_zeros(ctx.src_size)
        grad_src, _ = segment_cuda.segment_coo(grad_out, index, acc, 'add')
        return grad_src, None, None
class GatherCSR(torch.autograd.Function):
    """Differentiable CSR-pointer gather backed by the `gather_cuda` extension."""

    @staticmethod
    def forward(ctx, src, indptr, out):
        # A caller-supplied output buffer is modified in place, so autograd
        # must be told it is dirty.
        if out is not None:
            ctx.mark_dirty(out)
        ctx.src_size = list(src.size())
        ctx.save_for_backward(indptr)
        return gather_cuda.gather_csr(src, indptr, out)

    @staticmethod
    def backward(ctx, grad_out):
        if not ctx.needs_input_grad[0]:
            return None, None, None
        indptr, = ctx.saved_tensors
        # Unlike the COO case this uses an uninitialized buffer
        # (`new_empty`), presumably because segment_csr writes every
        # output slot — TODO(review): confirm against the kernel.
        buf = grad_out.new_empty(ctx.src_size)
        grad_src, _ = segment_cuda.segment_csr(grad_out, indptr, buf, 'add')
        return grad_src, None, None
def gather_coo(src, index, out=None):
    """Gather rows of `src` according to COO `index`, with autograd support.

    Dispatches to the :class:`GatherCOO` autograd Function so gradients
    flow back to `src`. If `out` is given, the result is written into it.
    """
    return GatherCOO.apply(src, index, out)
def gather_csr(src, indptr, out=None):
    """Gather rows of `src` according to CSR `indptr`, with autograd support.

    Dispatches to the :class:`GatherCSR` autograd Function so gradients
    flow back to `src`. If `out` is given, the result is written into it.
    """
    return GatherCSR.apply(src, indptr, out)
import torch import torch
from torch_scatter.utils import min_value, max_value
if torch.cuda.is_available(): if torch.cuda.is_available():
from torch_scatter import segment_cuda, gather_cuda from torch_scatter import segment_cuda, gather_cuda
...@@ -48,12 +50,24 @@ class SegmentCSR(torch.autograd.Function): ...@@ -48,12 +50,24 @@ class SegmentCSR(torch.autograd.Function):
def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
    """Reduce values of `src` that share the same COO `index` entry.

    When no output buffer is given, one is allocated with a fill value
    that is the identity of the reduction ('min' starts from the dtype
    maximum, 'max' from the dtype minimum, everything else from zero),
    so untouched slots do not corrupt the result.
    """
    assert reduce in ['any', 'add', 'mean', 'min', 'max']
    fill_value = 0
    if out is None:
        dim_size = index.max().item() + 1 if dim_size is None else dim_size
        size = list(src.size())
        size[index.dim() - 1] = dim_size
        # The initial value of the buffer depends on the reduction.
        if reduce == 'min':
            fill_value = max_value(src.dtype)
        elif reduce == 'max':
            fill_value = min_value(src.dtype)
        out = src.new_full(size, fill_value)
    out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
    if fill_value != 0:
        # Slots no index pointed at still hold the sentinel; reset them
        # to zero. NOTE(review): this also zeroes genuine results that
        # happen to equal the sentinel — accepted upstream behavior.
        out.masked_fill_(out == fill_value, 0)
    return out if arg_out is None else (out, arg_out)
......
import torch
def min_value(dtype):
    """Return the smallest representable value of the given torch dtype.

    Uses `torch.finfo` for floating-point dtypes and `torch.iinfo` for
    integer dtypes. The original `try/except AttributeError` fallback was
    doubly broken: `torch.finfo` raises `TypeError` (not `AttributeError`)
    on integer dtypes, and `torch.info` does not exist.
    """
    if dtype.is_floating_point:
        return torch.finfo(dtype).min
    return torch.iinfo(dtype).min
def max_value(dtype):
    """Return the largest representable value of the given torch dtype.

    Uses `torch.finfo` for floating-point dtypes and `torch.iinfo` for
    integer dtypes. The original `try/except AttributeError` fallback was
    doubly broken: `torch.finfo` raises `TypeError` (not `AttributeError`)
    on integer dtypes, and `torch.info` does not exist.
    """
    if dtype.is_floating_point:
        return torch.finfo(dtype).max
    return torch.iinfo(dtype).max
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment