cross_entropy.py 3.27 KB
Newer Older
PanZezhong1725's avatar
PanZezhong1725 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)

# ------------------------------------------------------------
# 用例配置
# ------------------------------------------------------------
_TEST_CASES_ = [
    ((2, 4, 10), None, None),        # logits shape, x_stride, y_stride
    ((1, 128, 32000), None, None),
    ((4, 512, 1000), None, None),
]

_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 2e-2},
    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
}

# ------------------------------------------------------------
# PyTorch 参考实现
# ------------------------------------------------------------
def cross_entropy_ref(logits, target):
    vocab = logits.shape[-1]
    logits_flat = logits.reshape(-1, vocab).float()
    target_flat = target.reshape(-1).long()
    loss = torch.nn.functional.cross_entropy(logits_flat, target_flat, reduction="none")
    return loss.view(target.shape).to(logits.dtype)


def test(handle, device, shape, x_stride=None, y_stride=None, dtype=InfiniDtype.F16, sync=None):
    logits_shape = shape
    label_shape = shape[:-1]
    vocab = shape[-1]

    print(f"Testing CrossEntropy on {InfiniDeviceNames[device]} logits:{logits_shape} dtype:{InfiniDtypeNames[dtype]}")

    x = TestTensor(logits_shape, x_stride, dtype, device)
    target = TestTensor(label_shape, None, InfiniDtype.I64, device)

    # 生成有效标签
    tgt = target.torch_tensor()
    tgt.copy_(torch.randint(0, vocab, label_shape, dtype=torch.int64, device=tgt.device))
    target.actual_tensor().copy_(tgt)

    reference = cross_entropy_ref(x.torch_tensor(), target.torch_tensor())
    y = TestTensor(label_shape, y_stride, dtype, device)

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateCrossEntropyDescriptor(
            handle, ctypes.byref(descriptor), y.descriptor, x.descriptor, target.descriptor
        )
    )

    for tensor in [x, y, target]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(LIBINFINIOP.infiniopGetCrossEntropyWorkspaceSize(descriptor, ctypes.byref(workspace_size)))
    workspace = TestWorkspace(workspace_size.value, x.device)

    def run():
        check_error(
            LIBINFINIOP.infiniopCrossEntropy(
                descriptor,
                workspace.data(),
                workspace.size(),
                y.data(),
                x.data(),
                target.data(),
                None,
            )
        )

    run()
    if sync:
        sync()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    assert torch.allclose(y.actual_tensor(), reference, atol=atol, rtol=rtol)

    check_error(LIBINFINIOP.infiniopDestroyCrossEntropyDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
    print("\033[92mTest passed!\033[0m")