Commit 8099c537 authored by rusty1s's avatar rusty1s
Browse files

added new tests

parent f0fdfe20
...@@ -2,7 +2,7 @@ from os import path as osp ...@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages from setuptools import setup, find_packages
__version__ = '0.2.3' __version__ = '0.3.0'
url = 'https://github.com/rusty1s/pytorch_scatter' url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = ['cffi'] install_requires = ['cffi']
......
[ [
{ {
"name": "add", "name": "add",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], "index": [2, 0, 1, 1, 0],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], "input": [1, 2, 3, 4, 5],
"dim": 1, "dim": 0,
"fill_value": 0, "fill_value": 0,
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]], "grad": [4, 8, 6],
"expected": [[50, 60, 50, 30, 40], [15, 15, 35, 35, 25]] "expected": [6, 4, 8, 8, 4]
}, },
{ {
"name": "add", "name": "sub",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]], "index": [2, 0, 1, 1, 0],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]], "input": [1, 2, 3, 4, 5],
"dim": 0, "dim": 0,
"fill_value": 0, "fill_value": 0,
"grad": [[10, 20], [15, 25]], "grad": [4, 8, 6],
"expected": [[10, 20], [15, 25], [15, 25], [10, 20]] "expected": [-6, -4, -8, -8, -4]
}, },
{ {
"name": "mean", "name": "mean",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], "index": [2, 0, 1, 1, 0],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], "input": [1, 2, 3, 4, 5],
"dim": 1, "dim": 0,
"fill_value": 0, "fill_value": 0,
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]], "grad": [4, 8, 6],
"expected": [[50, 60, 50, 30, 40], [15, 15, 35, 35, 25]] "expected": [6, 2, 4, 4, 2]
}, },
{ {
"name": "mean", "name": "max",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]], "index": [2, 0, 1, 1, 0],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]], "input": [1, 2, 3, 4, 5],
"dim": 0, "dim": 0,
"fill_value": 0, "fill_value": 0,
"grad": [[10, 20], [15, 25]], "grad": [4, 8, 6],
"expected": [[10, 20], [15, 25], [15, 25], [10, 20]] "expected": [6, 0, 0, 8, 4]
}, },
{ {
"name": "max", "name": "min",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], "index": [2, 0, 1, 1, 0],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], "input": [1, 2, 3, 4, 5],
"dim": 1, "dim": 0,
"fill_value": 0, "fill_value": 3,
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]], "grad": [4, 8, 6],
"expected": [[50, 60, 0, 30, 40], [0, 15, 0, 35, 25]] "expected": [6, 4, 8, 0, 0]
}, },
{ {
"name": "max", "name": "mul",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]], "index": [2, 0, 1, 1, 0],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]], "input": [1, 2, 3, 4, 5],
"dim": 0, "dim": 0,
"fill_value": 0, "fill_value": 2,
"grad": [[10, 20], [15, 25]], "grad": [4, 8, 6],
"expected": [[10, 0], [0, 25], [15, 0], [0, 20]] "expected": [12, 40, 64, 48, 16]
} }
] ]
...@@ -3,6 +3,8 @@ from torch._tensor_docs import tensor_classes ...@@ -3,6 +3,8 @@ from torch._tensor_docs import tensor_classes
tensors = [t[:-4] for t in tensor_classes] tensors = [t[:-4] for t in tensor_classes]
tensors.remove('ShortTensor') # TODO: PyTorch `atomicAdd` bug with short type. tensors.remove('ShortTensor') # TODO: PyTorch `atomicAdd` bug with short type.
tensors.remove('ByteTensor') # We cannot properly test unsigned values.
tensors.remove('CharTensor') # Overflow on gradient computations :(
def Tensor(str, x): def Tensor(str, x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment