Commit a95ac7f5 authored by rusty1s's avatar rusty1s
Browse files

rename

parent 8ef7174b
......@@ -12,7 +12,7 @@ extra_objects = []
with_cuda = False
ffi = create_extension(
name='torch_scatter._ext.scatter',
name='torch_scatter._ext.ffi',
package=True,
verbose=True,
headers=headers,
......
from nose.tools import assert_equal
import torch
from torch_scatter._ext import scatter
from torch_scatter._ext import ffi
def test_scatter_add():
......@@ -12,15 +12,5 @@ def test_scatter_add():
output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
scatter.scatter_add_Float(output, index, input, 1)
ffi.scatter_add_Float(output, index, input, 1)
assert_equal(output.tolist(), expected_output)
n = 10000
input = torch.rand(torch.Size([n]))
index = (torch.rand(torch.Size([n])) * n).long()
output = input.new(n).fill_(0)
expected_output = input.new(n).fill_(0)
scatter.scatter_add_Float(output, index, input, 0)
expected_output.scatter_add_(0, index, input)
assert_equal(output.tolist(), expected_output.tolist())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment