Commit 0d9699b5 authored by rusty1s's avatar rusty1s
Browse files

doctests

parent ea62ccd2
...@@ -2,14 +2,16 @@ import os ...@@ -2,14 +2,16 @@ import os
import sys import sys
import datetime import datetime
import sphinx_rtd_theme import sphinx_rtd_theme
import doctest
sys.path.insert(0, os.path.abspath('../..')) sys.path.insert(0, os.path.abspath('../..'))
from torch_scatter import __version__ from torch_scatter import __version__ # noqa
extensions = [ extensions = [
'sphinx.ext.autodoc', 'sphinx.ext.autodoc',
'sphinx.ext.doctest', 'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax', 'sphinx.ext.mathjax',
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinx.ext.viewcode', 'sphinx.ext.viewcode',
...@@ -19,12 +21,13 @@ extensions = [ ...@@ -19,12 +21,13 @@ extensions = [
source_suffix = '.rst' source_suffix = '.rst'
master_doc = 'index' master_doc = 'index'
project = 'pytorch_scatter'
copyright = '{}, Matthias Fey'.format(datetime.datetime.now().year)
author = 'Matthias Fey' author = 'Matthias Fey'
project = 'pytorch_scatter'
version = __version__ copyright = '{}, {}'.format(datetime.datetime.now().year, author)
release = __version__ version = release = __version__
html_theme = 'sphinx_rtd_theme' html_theme = 'sphinx_rtd_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
doctest_default_flags = doctest.NORMALIZE_WHITESPACE
intersphinx_mapping = {'python': ('https://docs.python.org/', None)}
from os import path as osp from os import path as osp
from setuptools import setup, find_packages from setuptools import setup, find_packages
import build # noqa
from torch_scatter import __version__ from torch_scatter import __version__
import build # noqa
install_requires = ['cffi'] install_requires = ['cffi']
setup_requires = ['pytest-runner', 'cffi'] setup_requires = ['pytest-runner', 'cffi']
tests_require = ['pytest'] tests_require = ['pytest']
......
...@@ -30,18 +30,32 @@ def scatter_add_(output, index, input, dim=0): ...@@ -30,18 +30,32 @@ def scatter_add_(output, index, input, dim=0):
input (Tensor): The source tensor input (Tensor): The source tensor
dim (int, optional): The axis along which to index dim (int, optional): The axis along which to index
Example:: .. testsetup::
>> input = torch.Tensor([[2, 0, 1, 4, 3], [0,2, 1, 3, 4]]) import torch
>> index = torch.LongTensor([[4, 5, 2, 3], [0, 0, 2, 2, 1]]) from torch_scatter import scatter_add_
>> output = torch.zeros(2, 6)
>> scatter_add_(output, index, input, dim=1) .. testcode::
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.zeros(2, 6)
scatter_add_(output, index, input, dim=1)
print(output)
.. testoutput::
0 0 4 3 3 0 0 0 4 3 3 0
2 4 4 0 0 0 2 4 4 0 0 0
[torch.FloatTensor of size 2x6] [torch.FloatTensor of size 2x6]
""" """
return output.scatter_add_(dim, index, input) return output.scatter_add_(dim, index, input)
# .. testoutput::
# 0 0 4 3 3 0
# 2 4 4 0 0 0
# [torch.FloatTensor of size 2x6]
def scatter_add(index, input, dim=0, size=None, fill_value=0): def scatter_add(index, input, dim=0, size=None, fill_value=0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment