Commit 0d9699b5 authored by rusty1s's avatar rusty1s
Browse files

doctests

parent ea62ccd2
......@@ -2,14 +2,16 @@ import os
import sys
import datetime
import sphinx_rtd_theme
import doctest
sys.path.insert(0, os.path.abspath('../..'))
from torch_scatter import __version__
from torch_scatter import __version__ # noqa
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
......@@ -19,12 +21,13 @@ extensions = [
source_suffix = '.rst'
master_doc = 'index'
project = 'pytorch_scatter'
copyright = '{}, Matthias Fey'.format(datetime.datetime.now().year)
author = 'Matthias Fey'
version = __version__
release = __version__
project = 'pytorch_scatter'
copyright = '{}, {}'.format(datetime.datetime.now().year, author)
version = release = __version__
html_theme = 'sphinx_rtd_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
doctest_default_flags = doctest.NORMALIZE_WHITESPACE
intersphinx_mapping = {'python': ('https://docs.python.org/', None)}
from os import path as osp
from setuptools import setup, find_packages
import build # noqa
from torch_scatter import __version__
import build # noqa
install_requires = ['cffi']
setup_requires = ['pytest-runner', 'cffi']
tests_require = ['pytest']
......
......@@ -30,18 +30,32 @@ def scatter_add_(output, index, input, dim=0):
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
Example::
.. testsetup::
>> input = torch.Tensor([[2, 0, 1, 4, 3], [0,2, 1, 3, 4]])
>> index = torch.LongTensor([[4, 5, 2, 3], [0, 0, 2, 2, 1]])
>> output = torch.zeros(2, 6)
>> scatter_add_(output, index, input, dim=1)
import torch
from torch_scatter import scatter_add_
.. testcode::
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.zeros(2, 6)
scatter_add_(output, index, input, dim=1)
print(output)
.. testoutput::
0 0 4 3 3 0
2 4 4 0 0 0
[torch.FloatTensor of size 2x6]
"""
return output.scatter_add_(dim, index, input)
# .. testoutput::
# 0 0 4 3 3 0
# 2 4 4 0 0 0
# [torch.FloatTensor of size 2x6]
def scatter_add(index, input, dim=0, size=None, fill_value=0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment