"vscode:/vscode.git/clone" did not exist on "f06d522c750e2bc09ee1a5393736f03d26a82480"
Commit 2ec4b08e authored by Daniel Povey's avatar Daniel Povey
Browse files

Testing of the dim parameter and reshaping capabilities

parent c288e95b
......@@ -68,7 +68,7 @@ setup(
install_requires=requirements,
python_requires='>=3.6',
packages=find_packages(),
author='Anton Obukhov',
author='Dan Povey',
license='BSD',
url='https://www.github.com/toshas/torch-discounted-cumsum',
ext_modules=configure_extensions(),
......
......@@ -11,6 +11,7 @@ def test_learned_nonlin_basic():
x = -2.0 + 0.4 * torch.arange(10, dtype=dtype)
x = x.reshape(1, 1, 10).repeat(B, C, 1)
K = 4
N = K * 2
params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1) - 3
......@@ -19,8 +20,20 @@ def test_learned_nonlin_basic():
print("x = ", x)
print("params = ", params)
print("x.shape = ", x.shape)
y = learned_nonlin(x, params, dim = 1)
if True:
# Check
x2 = x.reshape(B, C, 5, 2)
assert torch.allclose(learned_nonlin(x, params, dim = 1), learned_nonlin(x2, params, dim = 1).reshape(x.shape))
x2 = x.reshape(B, 1, C, 10)
assert torch.allclose(learned_nonlin(x, params, dim = 1), learned_nonlin(x2, params, dim = 2).reshape(x.shape))
print("y = ", y)
y.sum().backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment