"""Operator test definitions for nn.Embedding (PyTorch reference and InfiniCore implementation)."""

import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
# Test framework helpers

from framework import (
    BaseOperatorTest,
    GenericTestRunner,
    TensorInitializer,
    TensorSpec,
    TestCase,
    convert_infinicore_to_torch,
)
# InfiniCore package (library under test)

import infinicore

# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Each entry is (x_shape, weight_shape):
#   x_shape      - shape of the index tensor passed to the embedding
#   weight_shape - (num_embeddings, embedding_dim) of the embedding weight
# All basic cases share the same weight shape and differ only in x_shape.
_TEST_CASES_DATA = [
    (index_shape, (32000, 2048))
    for index_shape in ((1, 5), (2, 5), (2, 2, 5))
]

# Per-dtype comparison tolerances; the absolute tolerance is 0 for every dtype
_TOLERANCE_MAP = {
    dtype: {"atol": 0, "rtol": rtol}
    for dtype, rtol in (
        (infinicore.float16, 1e-2),
        (infinicore.float32, 1e-3),
        (infinicore.bfloat16, 5e-2),
    )
}

# Weight dtypes to test, in execution order
_TENSOR_DTYPES = [
    infinicore.float16,
    infinicore.bfloat16,
    infinicore.float32,
]


def parse_test_cases():
    """
    Build the list of TestCase objects for the nn.Embedding operator test.

    One out-of-place TestCase is created for every combination of an
    (x_shape, weight_shape) pair from ``_TEST_CASES_DATA`` and a dtype from
    ``_TENSOR_DTYPES``. Shape pairs form the outer loop and dtypes the inner
    loop, so cases are ordered shape-major.

    Returns:
        list: the generated TestCase objects.
    """
    # No explicit strides are passed for any tensor; hoisted because it never changes.
    strides = None
    test_cases = []

    for x_shape, weight_shape in _TEST_CASES_DATA:
        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            # Fall back to the float32-level tolerance for dtypes missing from the map
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            # Index tensor: always int64, filled via RANDINT with bounds 1 and 10000.
            # Both bounds are below the 32000 rows of the weight shapes listed in
            # _TEST_CASES_DATA.
            x_spec = TensorSpec.from_tensor(
                x_shape,
                strides,
                infinicore.int64,
                init_mode=TensorInitializer.RANDINT,
                low=1,
                high=10000,
                name="x",
            )

            # Embedding weight in the dtype under test
            weight_spec = TensorSpec.from_tensor(
                weight_shape, strides, dtype, name="weight"
            )

            # Out-of-place case: the result is the operator's return value
            test_cases.append(
                TestCase(
                    inputs=[x_spec, weight_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    # Plain string: the original f-string had no placeholders
                    description="nn.Embedding - OUT_OF_PLACE",
                )
            )

    return test_cases


class OpTest(BaseOperatorTest):
    """nn.Embedding operator test providing a PyTorch and an InfiniCore implementation."""

    def __init__(self):
        super().__init__("nn.Embedding")

    def get_test_cases(self):
        """Return the TestCase list built by parse_test_cases()."""
        return parse_test_cases()

    def torch_operator(self, x, weight):
        """PyTorch nn.Embedding implementation.

        Args:
            x: index tensor passed to the embedding module.
            weight: 2-D weight whose shape is unpacked as
                (num_embeddings, embedding_dim).

        Returns:
            The output of ``torch.nn.Embedding`` applied to ``x``.
        """
        num_embeddings, embedding_dim = weight.shape

        # Build the module on the weight's own device/dtype, then load the weight into it
        model = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=weight.device,
            dtype=weight.dtype,
        )

        params_dict = {"weight": weight}
        model.load_state_dict(params_dict)

        # Forward pass only; gradient tracking is disabled
        with torch.no_grad():
            y = model(x)
        return y

    def infinicore_operator(self, x, weight):
        """InfiniCore nn.Embedding implementation.

        Args:
            x: index tensor passed to the embedding module.
            weight: 2-D weight whose shape is unpacked as
                (num_embeddings, embedding_dim).

        Returns:
            The output of ``infinicore.nn.Embedding`` applied to ``x``.
        """
        # Note: embedding now supports device-side input for graph recording
        # No need to convert to CPU anymore - the implementation handles both CPU and device inputs

        num_embeddings, embedding_dim = weight.shape

        # Build the module on the weight's own device/dtype, then load the weight into it
        model = infinicore.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=weight.device,
            dtype=weight.dtype,
        )

        params_dict = {"weight": weight}
        model.load_state_dict(params_dict)

        y = model(x)
        return y


def main():
    """Create a GenericTestRunner for OpTest and call its run_and_exit()."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()