Commit ff285368 authored by rusty1s's avatar rusty1s
Browse files

remove @torch.jit.script annotations (move jit compatibility to the test suite)

parent 3341dbeb
...@@ -18,3 +18,6 @@ def test_logsumexp(): ...@@ -18,3 +18,6 @@ def test_logsumexp():
assert out.tolist() == torch.logsumexp(src, dim=0).tolist() assert out.tolist() == torch.logsumexp(src, dim=0).tolist()
outputs.backward(torch.randn_like(outputs)) outputs.backward(torch.randn_like(outputs))
jit = torch.jit.script(scatter_logsumexp)
assert jit(inputs, index).tolist() == outputs.tolist()
...@@ -22,6 +22,9 @@ def test_softmax(): ...@@ -22,6 +22,9 @@ def test_softmax():
out.backward(torch.randn_like(out)) out.backward(torch.randn_like(out))
jit = torch.jit.script(scatter_softmax)
assert jit(src, index).tolist() == out.tolist()
def test_log_softmax(): def test_log_softmax():
src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')]) src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
...@@ -42,3 +45,6 @@ def test_log_softmax(): ...@@ -42,3 +45,6 @@ def test_log_softmax():
assert torch.allclose(out, expected) assert torch.allclose(out, expected)
out.backward(torch.randn_like(out)) out.backward(torch.randn_like(out))
jit = torch.jit.script(scatter_log_softmax)
assert jit(src, index).tolist() == out.tolist()
...@@ -13,3 +13,6 @@ def test_std(): ...@@ -13,3 +13,6 @@ def test_std():
assert torch.allclose(out, expected) assert torch.allclose(out, expected)
out.backward(torch.randn_like(out)) out.backward(torch.randn_like(out))
jit = torch.jit.script(scatter_std)
assert jit(src, index, dim=-1, unbiased=True).tolist() == out.tolist()
...@@ -99,12 +99,18 @@ def test_forward(test, reduce, dtype, device): ...@@ -99,12 +99,18 @@ def test_forward(test, reduce, dtype, device):
dim = test['dim'] dim = test['dim']
expected = tensor(test[reduce], dtype, device) expected = tensor(test[reduce], dtype, device)
out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim) fn = getattr(torch_scatter, 'scatter_' + reduce)
if isinstance(out, tuple): jit = torch.jit.script(fn)
out, arg_out = out out1 = fn(src, index, dim)
out2 = jit(src, index, dim)
if isinstance(out1, tuple):
out1, arg_out1 = out1
out2, arg_out2 = out2
arg_expected = tensor(test['arg_' + reduce], torch.long, device) arg_expected = tensor(test['arg_' + reduce], torch.long, device)
assert torch.all(arg_out == arg_expected) assert torch.all(arg_out1 == arg_expected)
assert torch.all(out == expected) assert arg_out1.tolist() == arg_out1.tolist()
assert torch.all(out1 == expected)
assert out1.tolist() == out2.tolist()
@pytest.mark.parametrize('test,reduce,device', @pytest.mark.parametrize('test,reduce,device',
......
...@@ -91,19 +91,31 @@ def test_forward(test, reduce, dtype, device): ...@@ -91,19 +91,31 @@ def test_forward(test, reduce, dtype, device):
indptr = tensor(test['indptr'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device) expected = tensor(test[reduce], dtype, device)
out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr) fn = getattr(torch_scatter, 'segment_' + reduce + '_csr')
if isinstance(out, tuple): jit = torch.jit.script(fn)
out, arg_out = out out1 = fn(src, indptr)
out2 = jit(src, indptr)
if isinstance(out1, tuple):
out1, arg_out1 = out1
out2, arg_out2 = out2
arg_expected = tensor(test['arg_' + reduce], torch.long, device) arg_expected = tensor(test['arg_' + reduce], torch.long, device)
assert torch.all(arg_out == arg_expected) assert torch.all(arg_out1 == arg_expected)
assert torch.all(out == expected) assert arg_out1.tolist() == arg_out2.tolist()
assert torch.all(out1 == expected)
out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index) assert out1.tolist() == out2.tolist()
if isinstance(out, tuple):
out, arg_out = out fn = getattr(torch_scatter, 'segment_' + reduce + '_coo')
jit = torch.jit.script(fn)
out1 = fn(src, index)
out2 = jit(src, index)
if isinstance(out1, tuple):
out1, arg_out1 = out1
out2, arg_out2 = out2
arg_expected = tensor(test['arg_' + reduce], torch.long, device) arg_expected = tensor(test['arg_' + reduce], torch.long, device)
assert torch.all(arg_out == arg_expected) assert torch.all(arg_out1 == arg_expected)
assert torch.all(out == expected) assert arg_out1.tolist() == arg_out2.tolist()
assert torch.all(out1 == expected)
assert out1.tolist() == out2.tolist()
@pytest.mark.parametrize('test,reduce,device', @pytest.mark.parametrize('test,reduce,device',
......
...@@ -6,7 +6,6 @@ from torch_scatter import scatter_sum, scatter_max ...@@ -6,7 +6,6 @@ from torch_scatter import scatter_sum, scatter_max
from torch_scatter.utils import broadcast from torch_scatter.utils import broadcast
@torch.jit.script
def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None, dim_size: Optional[int] = None,
......
...@@ -4,7 +4,6 @@ from torch_scatter import scatter_sum, scatter_max ...@@ -4,7 +4,6 @@ from torch_scatter import scatter_sum, scatter_max
from torch_scatter.utils import broadcast from torch_scatter.utils import broadcast
@torch.jit.script
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
eps: float = 1e-12) -> torch.Tensor: eps: float = 1e-12) -> torch.Tensor:
if not torch.is_floating_point(src): if not torch.is_floating_point(src):
...@@ -25,7 +24,6 @@ def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -25,7 +24,6 @@ def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
return recentered_scores_exp.div(normalizing_constants) return recentered_scores_exp.div(normalizing_constants)
@torch.jit.script
def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
eps: float = 1e-12) -> torch.Tensor: eps: float = 1e-12) -> torch.Tensor:
if not torch.is_floating_point(src): if not torch.is_floating_point(src):
......
...@@ -5,7 +5,6 @@ from torch_scatter import scatter_sum ...@@ -5,7 +5,6 @@ from torch_scatter import scatter_sum
from torch_scatter.utils import broadcast from torch_scatter.utils import broadcast
@torch.jit.script
def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None, dim_size: Optional[int] = None,
......
...@@ -5,7 +5,6 @@ import torch ...@@ -5,7 +5,6 @@ import torch
from .utils import broadcast from .utils import broadcast
@torch.jit.script
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor: dim_size: Optional[int] = None) -> torch.Tensor:
...@@ -24,21 +23,18 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -24,21 +23,18 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
return out.scatter_add_(dim, index, src) return out.scatter_add_(dim, index, src)
@torch.jit.script
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor: dim_size: Optional[int] = None) -> torch.Tensor:
return scatter_sum(src, index, dim, out, dim_size) return scatter_sum(src, index, dim, out, dim_size)
@torch.jit.script
def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor: dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size) return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)
@torch.jit.script
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor: dim_size: Optional[int] = None) -> torch.Tensor:
...@@ -63,7 +59,6 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -63,7 +59,6 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
return out return out
@torch.jit.script
def scatter_min( def scatter_min(
src: torch.Tensor, index: torch.Tensor, dim: int = -1, src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
...@@ -71,7 +66,6 @@ def scatter_min( ...@@ -71,7 +66,6 @@ def scatter_min(
return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
@torch.jit.script
def scatter_max( def scatter_max(
src: torch.Tensor, index: torch.Tensor, dim: int = -1, src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
......
...@@ -3,40 +3,35 @@ from typing import Optional, Tuple ...@@ -3,40 +3,35 @@ from typing import Optional, Tuple
import torch import torch
@torch.jit.script
def segment_sum_coo(src: torch.Tensor, index: torch.Tensor, def segment_sum_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor: dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size) return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)
@torch.jit.script
def segment_add_coo(src: torch.Tensor, index: torch.Tensor, def segment_add_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor: dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size) return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)
@torch.jit.script
def segment_mean_coo(src: torch.Tensor, index: torch.Tensor, def segment_mean_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor: dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size) return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size)
@torch.jit.script def segment_min_coo(
def segment_min_coo(src: torch.Tensor, index: torch.Tensor, src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size) return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size)
@torch.jit.script def segment_max_coo(
def segment_max_coo(src: torch.Tensor, index: torch.Tensor, src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size) return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size)
...@@ -137,7 +132,6 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor, ...@@ -137,7 +132,6 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor,
raise ValueError raise ValueError
@torch.jit.script
def gather_coo(src: torch.Tensor, index: torch.Tensor, def gather_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor: out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.gather_coo(src, index, out) return torch.ops.torch_scatter.gather_coo(src, index, out)
...@@ -3,25 +3,21 @@ from typing import Optional, Tuple ...@@ -3,25 +3,21 @@ from typing import Optional, Tuple
import torch import torch
@torch.jit.script
def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor, def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor: out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out) return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)
@torch.jit.script
def segment_add_csr(src: torch.Tensor, indptr: torch.Tensor, def segment_add_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor: out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out) return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)
@torch.jit.script
def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor, def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor: out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_mean_csr(src, indptr, out) return torch.ops.torch_scatter.segment_mean_csr(src, indptr, out)
@torch.jit.script
def segment_min_csr( def segment_min_csr(
src: torch.Tensor, indptr: torch.Tensor, src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None out: Optional[torch.Tensor] = None
...@@ -29,7 +25,6 @@ def segment_min_csr( ...@@ -29,7 +25,6 @@ def segment_min_csr(
return torch.ops.torch_scatter.segment_min_csr(src, indptr, out) return torch.ops.torch_scatter.segment_min_csr(src, indptr, out)
@torch.jit.script
def segment_max_csr( def segment_max_csr(
src: torch.Tensor, indptr: torch.Tensor, src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None out: Optional[torch.Tensor] = None
...@@ -114,7 +109,6 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor, ...@@ -114,7 +109,6 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
raise ValueError raise ValueError
@torch.jit.script
def gather_csr(src: torch.Tensor, indptr: torch.Tensor, def gather_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor: out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.gather_csr(src, indptr, out) return torch.ops.torch_scatter.gather_csr(src, indptr, out)
import torch import torch
@torch.jit.script
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
if dim < 0: if dim < 0:
dim = other.dim() + dim dim = other.dim() + dim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment