Commit d63eb9c9 authored by Miltos Allamanis's avatar Miltos Allamanis
Browse files

Remaining flake8 formatting errors

parent 0c127881
......@@ -9,15 +9,20 @@ from .utils import devices, tensor
SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
@pytest.mark.parametrize('dtype,device',
product(SUPPORTED_FLOAT_DTYPES, devices))
def test_logsumexp(dtype, device):
src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_logsumexp(src, index)
idx0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1).tolist()
idx1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
idx0 = torch.logsumexp(
torch.tensor([0.5, 0.5], dtype=dtype),
dim=-1).tolist()
idx1 = torch.logsumexp(
torch.tensor([0, -2.1, 3.2], dtype=dtype),
dim=-1).tolist()
idx2 = 7 # Single element
idx3 = torch.finfo(dtype).min # Empty index, returns yield value
idx4 = -1 # logsumexp with -inf is the identity
......
......@@ -10,16 +10,20 @@ from .utils import devices, tensor
SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
@pytest.mark.parametrize('dtype,device',
product(SUPPORTED_FLOAT_DTYPES, devices))
def test_log_softmax(dtype, device):
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_log_softmax(src, index)
# Expected results per index
idx0 = [np.log(0.5), np.log(0.5)]
idx1 = torch.log_softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
idx1 = torch.log_softmax(
torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
dim=-1).tolist()
idx2 = 0.0 # Single element, has logprob=0
# index=3 is empty. Should not matter.
idx4 = [0.0, float('-inf')] # log_softmax with -inf preserves the -inf
......@@ -31,16 +35,20 @@ def test_log_softmax(dtype, device):
)
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
@pytest.mark.parametrize('dtype,device',
product(SUPPORTED_FLOAT_DTYPES, devices))
def test_softmax(dtype, device):
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_softmax(src, index)
# Expected results per index
idx0 = [0.5, 0.5]
idx1 = torch.softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
idx1 = torch.softmax(
torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
dim=-1).tolist()
idx2 = 1 # Single element, has prob=1
# index=3 is empty. Should not matter.
idx4 = [1.0, 0.0] # softmax with -inf yields zero probability
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment