# test_jit.py
from typing import Optional

import torch
import torch_scatter


@torch.jit.script
def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
                out: Optional[torch.Tensor] = None, reduce: str = "sum"):
    """TorchScript wrapper around the ``segment_sum_csr`` custom op.

    Forwards ``src``, ``indptr`` and ``out`` unchanged to
    ``torch.ops.torch_scatter_cpu.segment_sum_csr`` and returns its result.

    Args:
        src: Tensor handed to the op as the values to reduce.
        indptr: Tensor handed to the op; presumably CSR-style segment
            boundaries into ``src`` -- TODO confirm against the op.
        out: Optional tensor handed to the op; presumably a destination
            buffer -- TODO confirm against the op.
        reduce: Name of the requested reduction. Only ``"sum"`` is accepted,
            because this wrapper only dispatches to the sum kernel.

    Returns:
        Whatever ``segment_sum_csr`` returns for the given arguments.

    Raises:
        An error when ``reduce`` is not ``"sum"``. NOTE(review): inside a
        scripted function the ``ValueError`` below is expected to reach
        Python as a TorchScript error type -- confirm for the torch version
        in use before catching it by type.
    """
    # Only the sum kernel is wired up; refuse other reductions instead of
    # silently returning a sum under a different name.
    if reduce != "sum":
        raise ValueError("segment_csr: only reduce='sum' is supported")
    return torch.ops.torch_scatter_cpu.segment_sum_csr(src, indptr, out)


def test_jit():
    """Check the scripted ``segment_csr`` wrapper's forward result and gradient.

    Uses an ``(8, 4)`` random input and ``indptr = [0, 2, 4, 6, 8]``. The
    expected values assume the op sums ``src[indptr[i]:indptr[i + 1]]`` along
    the first dimension, as described in torch_scatter's ``segment_csr``
    documentation.
    """
    # TODO: an earlier draft also called
    # torch.ops.torch_scatter_cpu.segment_csr(src, indptr, None, "sum") and
    # tried torch.jit.script on the op directly; add coverage for that path
    # once it is settled.
    src = torch.randn(8, 4)
    src.requires_grad_()
    indptr = torch.tensor([0, 2, 4, 6, 8])

    out = segment_csr(src, indptr)

    # Four consecutive segments of two rows each, so folding the rows into
    # pairs and summing each pair gives the reference result.
    expected = src.detach().view(4, 2, 4).sum(dim=1)
    assert out.shape == expected.shape
    assert torch.allclose(out, expected)

    # No backward pass has run yet, so the leaf tensor has no gradient.
    assert src.grad is None

    grad_out = torch.randn_like(out)
    out.backward(grad_out)

    # For a sum, every source row receives the gradient of the segment it
    # belongs to; each segment here covers exactly two rows.
    assert src.grad is not None
    assert src.grad.shape == src.shape
    assert torch.allclose(src.grad, grad_out.repeat_interleave(2, dim=0))