"openwebtext/README.md" did not exist on "0399d32c75b4719c89b91c18a173d05936112036"
integrated_conv_test.py 769 Bytes
Newer Older
Daniel Povey's avatar
Daniel Povey committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from torch_integrated_conv import integrated_conv


def test_integrated_conv_zeros():
    """Check that integrated_conv maps all-zero inputs to an all-zero output.

    Builds an all-zero input with 2 * C channels and all-zero ``pos_add`` /
    ``pos_mul`` tensors of shape (C, kH, kW), and expects an all-zero result
    of shape (N, C, H, W).  Runs on CPU and, when CUDA is available, on
    ``cuda:0``, for both float32 and float64.
    """
    N = 1
    C = 2
    H = 3
    W = 4
    kH = 5
    kW = 5

    devices = [torch.device('cpu')]
    # Only exercise the GPU path when CUDA is present, so the test does not
    # crash on CPU-only machines.
    if torch.cuda.is_available():
        devices.append(torch.device('cuda:0'))

    for device in devices:
        for dtype in [torch.float32, torch.float64]:
            print("device=", device, ", dtype=", dtype)
            x = torch.zeros(N, 2 * C, H, W, device=device, dtype=dtype)
            pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype)
            pos_mul = torch.zeros(C, kH, kW, device=device, dtype=dtype)

            output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
            output = integrated_conv(x, pos_add, pos_mul)
            # allclose compares elementwise with broadcasting, so a
            # wrongly-shaped (but broadcastable) result could slip through;
            # check the shape explicitly as well.
            assert output.shape == output_ref.shape
            assert torch.allclose(output, output_ref)