conv2d.py 2.93 KB
Newer Older
1
2
3
4
5
6
7
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
8
9
10
11
12
13
14
from framework import (
    BaseOperatorTest,
    TensorSpec,
    TestCase,
    GenericTestRunner,
    is_broadcast,
)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

# Test cases: (in_shape, in_strides_or_None, weight_shape, bias_shape_or_None, stride, padding, dilation, groups)

_TEST_CASES_DATA = [
    ((2, 4, 16, 16), None, (8, 4, 3, 3), None, (1, 1), (0, 0), (1, 1), 1),
    ((1, 6, 15, 17), (1530, 255, 17, 1), (4, 6, 5, 3), (4,), (2, 2), (2, 1), (1, 1), 1),
    ((2, 8, 32, 32), None, (8, 8, 1, 1), None, (1, 1), (0, 0), (1, 2), 1),
    ((3, 3, 7, 9), (189, 63, 9, 1), (6, 3, 3, 3), None, 1, (1, 1), (1, 1), 1),
    ((2, 2, 31, 29), None, (4, 2, 4, 3), (4,), (2, 1), (1, 0), (1, 1), 1),
    ((1, 8, 9, 11), (792, 99, 11, 1), (8, 8, 3, 3), None, (1, 1), (1, 1), (1, 1), 1),
]

_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-4, "rtol": 1e-4},
}

_TENSOR_DTYPES = [infinicore.float16, infinicore.float32]


def parse_test_cases():
    tests = []
    for (
        in_shape,
        in_strides,
        w_shape,
        b_shape,
        stride,
        padding,
        dilation,
        groups,
    ) in _TEST_CASES_DATA:
        for dtype in _TENSOR_DTYPES:
            tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-3})
            in_spec = TensorSpec.from_tensor(in_shape, in_strides, dtype)
            weight_spec = TensorSpec.from_tensor(w_shape, None, dtype)
            if b_shape is not None:
                bias_spec = TensorSpec.from_tensor(b_shape, None, dtype)
            else:
                bias_spec = None

            kwargs = {
                "stride": stride,
                "padding": padding,
                "dilation": dilation,
                "groups": groups,
            }
            inputs = [in_spec, weight_spec]
            if bias_spec is not None:
                inputs.append(bias_spec)

            tests.append(
                TestCase(
                    inputs=inputs,
                    kwargs=kwargs,
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tol,
                    description="Conv2d - OUT_OF_PLACE",
                )
            )

    return tests


class OpTest(BaseOperatorTest):
    """Conv2d operator test with simplified implementation"""

    def __init__(self):
        super().__init__("Conv2d")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        return torch.nn.functional.conv2d(*args, **kwargs)

    # def infinicore_operator(self, *args, **kwargs):
    #     """InfiniCore implementation (operator not yet available)."""
    #     return infinicore.nn.functional.conv2d(*args, **kwargs)


def main():
    """Main entry point"""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()