Commit d63eb9c9 authored by Miltos Allamanis

Fix remaining flake8 formatting errors

parent 0c127881
@@ -9,15 +9,20 @@ from .utils import devices, tensor
 SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
 
 
-@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
+@pytest.mark.parametrize('dtype,device',
+                         product(SUPPORTED_FLOAT_DTYPES, devices))
 def test_logsumexp(dtype, device):
     src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
     index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
     out = scatter_logsumexp(src, index)
 
-    idx0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1).tolist()
-    idx1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
+    idx0 = torch.logsumexp(
+        torch.tensor([0.5, 0.5], dtype=dtype),
+        dim=-1).tolist()
+    idx1 = torch.logsumexp(
+        torch.tensor([0, -2.1, 3.2], dtype=dtype),
+        dim=-1).tolist()
     idx2 = 7  # Single element
     idx3 = torch.finfo(dtype).min  # Empty index, returns yield value
     idx4 = -1  # logsumexp with -inf is the identity
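As context for the assertions above, here is a minimal reference sketch of the group-wise logsumexp semantics the test expects. The helper name scatter_logsumexp_ref and the loop-based formulation are illustrative assumptions, not torch_scatter's actual (vectorized) implementation:

import torch


def scatter_logsumexp_ref(src, index, dim_size=None):
    # Group-wise logsumexp of `src` over the groups named by `index`.
    # Sketch only -- not the library's implementation.
    if dim_size is None:
        dim_size = int(index.max()) + 1
    # Empty groups keep the fill value torch.finfo(dtype).min,
    # matching the idx3 expectation in the test above.
    out = src.new_full((dim_size,), torch.finfo(src.dtype).min)
    for group in range(dim_size):
        values = src[index == group]
        if values.numel() > 0:
            out[group] = torch.logsumexp(values, dim=-1)
    return out


src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')])
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
print(scatter_logsumexp_ref(src, index))
# Group 4 holds [-1, -inf]; exp(-inf) contributes 0, so its logsumexp is -1.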
...@@ -10,16 +10,20 @@ from .utils import devices, tensor ...@@ -10,16 +10,20 @@ from .utils import devices, tensor
SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64} SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices)) @pytest.mark.parametrize('dtype,device',
product(SUPPORTED_FLOAT_DTYPES, devices))
def test_log_softmax(dtype, device): def test_log_softmax(dtype, device):
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_log_softmax(src, index) out = scatter_log_softmax(src, index)
# Expected results per index # Expected results per index
idx0 = [np.log(0.5), np.log(0.5)] idx0 = [np.log(0.5), np.log(0.5)]
idx1 = torch.log_softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist() idx1 = torch.log_softmax(
torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
dim=-1).tolist()
idx2 = 0.0 # Single element, has logprob=0 idx2 = 0.0 # Single element, has logprob=0
# index=3 is empty. Should not matter. # index=3 is empty. Should not matter.
idx4 = [0.0, float('-inf')] # log_softmax with -inf preserves the -inf idx4 = [0.0, float('-inf')] # log_softmax with -inf preserves the -inf
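The idx4 expectation above rests on a property of log_softmax worth seeing in isolation: a -inf logit keeps log-probability -inf, while the remaining mass goes to the finite entries. A one-line check, using the group-4 values from the test:

import torch

print(torch.log_softmax(torch.tensor([-1.0, float('-inf')]), dim=-1))
# tensor([0., -inf])  -- the finite entry gets logprob 0, the -inf persists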
@@ -31,16 +35,20 @@ def test_log_softmax(dtype, device):
     )
 
 
-@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
+@pytest.mark.parametrize('dtype,device',
+                         product(SUPPORTED_FLOAT_DTYPES, devices))
 def test_softmax(dtype, device):
-    src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
+                 dtype, device)
     index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
     out = scatter_softmax(src, index)
 
     # Expected results per index
     idx0 = [0.5, 0.5]
-    idx1 = torch.softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
+    idx1 = torch.softmax(
+        torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
+        dim=-1).tolist()
     idx2 = 1  # Single element, has prob=1
     # index=3 is empty. Should not matter.
     idx4 = [1.0, 0.0]  # softmax with -inf yields zero probability
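And a matching sketch of the group-wise softmax under test; as above, scatter_softmax_ref is an invented name for illustration, not the library's vectorized implementation:

import torch


def scatter_softmax_ref(src, index):
    # Ordinary softmax applied independently within each index group;
    # the output keeps the positions of `src`. Sketch only.
    out = torch.empty_like(src)
    for group in index.unique():
        mask = index == group
        out[mask] = torch.softmax(src[mask], dim=-1)
    return out


src = torch.tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')])
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
print(scatter_softmax_ref(src, index))
# Group 0 -> [0.5, 0.5]; group 2 is a singleton -> 1.0;
# group 4 holds [-1, -inf] -> [1.0, 0.0], matching idx4 above.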