Commit 0c127881 authored by Miltos Allamanis

Address most flake8, pycodestyle errors.

parent 7b14c671
@@ -2,12 +2,13 @@ from itertools import product
 import torch
 import pytest
-from torch_scatter import scatter_max, scatter_logsumexp
+from torch_scatter import scatter_logsumexp
 from .utils import devices, tensor

+SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}

 @pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
 def test_logsumexp(dtype, device):
     src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
......
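A note on the fixture above: it deliberately includes float('-inf'), which a numerically safe logsumexp must tolerate, since exp(-inf) contributes exactly 0 to its group. A minimal sketch of the expected per-group behavior in plain PyTorch (the group layout here is made up for illustration, not taken from the test):

import torch

src = torch.tensor([0.5, float('-inf'), 7.0, -1.0])
index = torch.tensor([0, 0, 1, 1])

# Reference behavior: apply torch.logsumexp to each group separately.
for g in (0, 1):
    print(g, torch.logsumexp(src[index == g], dim=0))
# group 0 -> 0.5000  (the -inf entry drops out of the sum)
# group 1 -> 7.0003  (log(exp(7) + exp(-1)))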
@@ -9,6 +9,7 @@ from .utils import devices, tensor

+SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}

 @pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
 def test_log_softmax(dtype, device):
     src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
......
@@ -2,9 +2,11 @@ import torch
 from torch_scatter import scatter_add, scatter_max


 def scatter_log_softmax(src, index, dim=-1, dim_size=None):
     r"""
-    Numerically safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
+    Numerically safe log-softmax of all values from
+    the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
     :attr:`dim`. If multiple indices reference the same location, their
     **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
@@ -42,9 +44,12 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('log_softmax can be computed only over tensors with floating point data types.')
+        raise ValueError('log_softmax can be computed only over '
+                         'tensors with floating point data types.')

-    max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
+    max_value_per_index, _ = scatter_max(src, index,
+                                         dim=dim,
+                                         dim_size=dim_size)
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element
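The scatter_max / gather / subtract sequence being rewrapped here is the standard max-shift trick: subtracting each group's maximum before exponentiating keeps exp() from overflowing. A self-contained sketch of the idea in plain PyTorch, assuming a 1-D src; the helper name is illustrative, not library API:

import torch

def naive_scatter_log_softmax(src, index, dim_size):
    # Per-group maximum: the quantity scatter_max computes above.
    max_per_group = torch.full((dim_size,), float('-inf'))
    for value, group in zip(src.tolist(), index.tolist()):
        max_per_group[group] = max(max_per_group[group].item(), value)
    # Recenter so exp() cannot overflow, then normalize per group.
    recentered = src - max_per_group[index]
    sum_exp = torch.zeros(dim_size).index_add_(0, index, recentered.exp())
    return recentered - sum_exp.log()[index]

src = torch.tensor([100.0, 100.5, 0.5, 0.0])
index = torch.tensor([0, 0, 1, 1])
print(naive_scatter_log_softmax(src, index, dim_size=2))
# Without the shift, exp(100.0) already overflows float32 to inf.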
@@ -62,7 +67,8 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
 def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
     r"""
-    Numerically safe softmax of all values from the :attr:`src` tensor into :attr:`out` at the
+    Numerically safe softmax of all values from
+    the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
     :attr:`dim`. If multiple indices reference the same location, their
     **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
@@ -100,9 +106,12 @@ def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('softmax can be computed only over tensors with floating point data types.')
+        raise ValueError('softmax can be computed only over '
+                         'tensors with floating point data types.')

-    max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
+    max_value_per_index, _ = scatter_max(src, index,
+                                         dim=dim,
+                                         dim_size=dim_size)
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element
......
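For the softmax variant, the recentered scores are exponentiated and divided by a per-group sum; the epsilon in the signature above presumably guards against a vanishing denominator in edge cases. A hedged sketch of that normalization step in plain PyTorch (requires a recent PyTorch for scatter_reduce; the helper name is illustrative, and this mirrors the idea rather than the library code):

import torch

def naive_scatter_softmax(src, index, dim_size, epsilon=1e-16):
    # Per-group max via scatter_reduce (PyTorch >= 1.12), then the same
    # recentering as in scatter_log_softmax above.
    max_per_group = torch.zeros(dim_size).scatter_reduce(
        0, index, src, reduce='amax', include_self=False)
    exp_recentered = (src - max_per_group[index]).exp()
    denom = torch.zeros(dim_size).index_add_(0, index, exp_recentered)
    # epsilon keeps the division finite if the denominator underflows.
    return exp_recentered / (denom[index] + epsilon)

src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 1, 1])
print(naive_scatter_softmax(src, index, dim_size=2))
# tensor([0.2689, 0.7311, 0.2689, 0.7311]); each group sums to 1.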
@@ -3,9 +3,11 @@ import torch
 from . import scatter_add, scatter_max


-def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
+def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
+                      fill_value=None, epsilon=1e-16):
     r"""
-    Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
+    Numerically safe logsumexp of all values from
+    the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
     :attr:`dim`. If multiple indices reference the same location, their
     **contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`).
@@ -13,7 +15,8 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
     For one-dimensional tensors, the operation computes

     .. math::
-        \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + \sum_j \exp(\mathrm{src}_j) \right)
+        \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i)
+        + \sum_j \exp(\mathrm{src}_j) \right)

     Compute a numerically safe logsumexp operation
     from the :attr:`src` tensor into :attr:`out` at the indices
@@ -40,13 +43,18 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('logsumexp can be computed over tensors with floating point data types.')
+        raise ValueError('logsumexp can only be computed over '
+                         'tensors with floating point data types.')

     if fill_value is None:
         fill_value = torch.finfo(src.dtype).min
-    dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size
-    max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
+    dim_size = out.shape[dim] \
+        if dim_size is None and out is not None else dim_size
+    max_value_per_index, _ = scatter_max(src, index, dim=dim,
+                                         out=out, dim_size=dim_size,
+                                         fill_value=fill_value)
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element
......
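Finally, a hedged usage sketch of the function being rewrapped above: scatter_logsumexp should agree, up to the epsilon guard, with applying torch.logsumexp to each group independently. The inputs below are arbitrary illustration values, not taken from the tests:

import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0.0, 3.2, 7.0])
index = torch.tensor([0, 0, 1, 1])

out = scatter_logsumexp(src, index)
# Reference: per-group logsumexp computed directly.
reference = torch.stack([torch.logsumexp(src[index == g], dim=0)
                         for g in (0, 1)])
print(torch.allclose(out, reference, atol=1e-6))  # expected: True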