causal_softmax.py 4.48 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
4
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
PanZezhongQY's avatar
PanZezhongQY committed
5
6
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
xgqdut2016's avatar
xgqdut2016 committed
7
8
9
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
10
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
11
    rearrange_if_needed,
PanZezhongQY's avatar
PanZezhongQY committed
12
    create_workspace,
xgqdut2016's avatar
xgqdut2016 committed
13
14
15
16
17
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
PanZezhongQY's avatar
PanZezhongQY committed
18
19
)

xgqdut2016's avatar
xgqdut2016 committed
20
21
22
23
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules

# Each case is (x_shape, x_stride). A stride of None presumably leaves the
# freshly created (contiguous) tensor untouched -- see rearrange_if_needed.
_TEST_CASES = [
    # x_shape, x_stride
    ((32, 512), None),
    ((32, 512), (1024, 1)),
    ((32, 5, 5), None),
    ((32, 20, 512), None),
    ((32, 20, 512), (20480, 512, 1)),  # Ascend does not support non-contiguous inputs yet
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16]

# Tolerance map for different data types. With atol=0 the comparison is purely
# relative (assuming get_tolerance forwards these values unchanged).
_TOLERANCE_MAP = {
    torch.float16: {"atol": 0, "rtol": 1e-2},
}

# Defaults; the __main__ block overwrites these from the command-line arguments.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
PanZezhongQY's avatar
PanZezhongQY committed
46

xgqdut2016's avatar
xgqdut2016 committed
47

PanZezhongQY's avatar
PanZezhongQY committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class CausalSoftmaxDescriptor(Structure):
    """ctypes mirror of the CausalSoftmax operator descriptor.

    Declares a single ``int32`` field named ``device``. This file never reads
    the field; it only hands pointers to this struct back to the library.
    NOTE(review): confirm the layout against the C header before relying on it.
    """

    _fields_ = [
        ("device", c_int32),
    ]


# Pointer type used wherever the C API takes or fills a descriptor handle.
infiniopCausalSoftmaxDescriptor_t = POINTER(CausalSoftmaxDescriptor)


def causal_softmax(x):
    """Reference causal softmax computed with PyTorch.

    For the trailing ``(m, n)`` matrix, element ``x[..., i, j]`` is masked out
    when ``j > i + (n - m)``. The mask is therefore aligned to the bottom-right
    corner, so the last row sees every column. Softmax is taken over the last
    dimension in float32 and the result is cast back to ``x.dtype``. If
    ``n < m``, the first ``m - n`` rows are fully masked and come out as NaN.

    Args:
        x: tensor with at least two dimensions. It is not modified.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    out_dtype = x.dtype
    # A strict lower triangle flipped along both trailing axes has ones exactly
    # at the positions a row must not attend to (bottom-right aligned).
    mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
    # torch.where allocates a new tensor, so no defensive clone of x is needed.
    masked = torch.where(mask == 1, -torch.inf, x.to(torch.float32))
    return torch.nn.functional.softmax(masked, dim=-1).to(out_dtype)


xgqdut2016's avatar
xgqdut2016 committed
63
def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
    """Check the library's CausalSoftmax against the PyTorch reference for one case.

    Creates a random input on ``torch_device`` and computes the reference with
    ``causal_softmax``. Then runs the library kernel on ``x``, optionally
    re-laid-out to ``x_stride``. The kernel's result is read back from ``x``
    itself, so the operation is exercised in place. The descriptor is always
    destroyed, even if a later step fails.

    Args:
        lib: ctypes-loaded operator library (signatures declared in ``__main__``).
        handle: library handle passed to descriptor creation.
        torch_device: target device, passed to ``Tensor.to`` and the profiler.
        x_shape: shape of the input tensor.
        x_stride: optional strides applied through ``rearrange_if_needed``.
        dtype: element type of the input.

    Raises:
        AssertionError: if the library output differs from the reference beyond
            the tolerance looked up in ``_TOLERANCE_MAP``.
    """
    print(
        f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
    )

    x = torch.rand(x_shape, dtype=dtype).to(torch_device)
    # The reference must be computed before the library overwrites x.
    ans = causal_softmax(x)

    x = rearrange_if_needed(x, x_stride)
    x_tensor = to_tensor(x, lib)

    descriptor = infiniopCausalSoftmaxDescriptor_t()
    check_error(
        lib.infiniopCreateCausalSoftmaxDescriptor(
            handle, ctypes.byref(descriptor), x_tensor.descriptor
        )
    )
    # From here on this function owns the descriptor. Release it even if the
    # workspace query, the kernel or the comparison below raises.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
        x_tensor.descriptor.contents.invalidate()

        workspace_size = c_uint64(0)
        check_error(
            lib.infiniopGetCausalSoftmaxWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = create_workspace(workspace_size.value, x.device)

        def lib_causal_softmax():
            # Result lands in x's buffer (x is compared against ans below).
            # The trailing None is presumably the default stream -- confirm
            # against the C API.
            check_error(
                lib.infiniopCausalSoftmax(
                    descriptor,
                    workspace.data_ptr() if workspace is not None else None,
                    workspace_size.value,
                    x_tensor.data,
                    None,
                )
            )

        lib_causal_softmax()

        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(x, ans, atol=atol, rtol=rtol)
        assert torch.allclose(x, ans, atol=atol, rtol=rtol)

        # Profiling workflow
        if PROFILE:
            # fmt: off
            profile_operation("PyTorch", lambda: causal_softmax(x), torch_device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation("    lib", lambda: lib_causal_softmax(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
120

PanZezhongQY's avatar
PanZezhongQY committed
121
122
123
124

if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # Declare the C signature of every entry point `test` calls, so ctypes
    # converts arguments correctly. Each returns an int32 result that the
    # callers pass to `check_error`.
    for _fn_name, _fn_argtypes in (
        (
            "infiniopCreateCausalSoftmaxDescriptor",
            [
                infiniopHandle_t,
                POINTER(infiniopCausalSoftmaxDescriptor_t),
                infiniopTensorDescriptor_t,
            ],
        ),
        (
            "infiniopGetCausalSoftmaxWorkspaceSize",
            [infiniopCausalSoftmaxDescriptor_t, POINTER(c_uint64)],
        ),
        (
            "infiniopCausalSoftmax",
            [
                infiniopCausalSoftmaxDescriptor_t,
                c_void_p,  # workspace pointer
                c_uint64,  # workspace size
                c_void_p,  # x data (written in place)
                c_void_p,  # presumably the stream; `test` passes None
            ],
        ),
        (
            "infiniopDestroyCausalSoftmaxDescriptor",
            [infiniopCausalSoftmaxDescriptor_t],
        ),
    ):
        _fn = getattr(lib, _fn_name)
        _fn.restype = c_int32
        _fn.argtypes = _fn_argtypes

    # Command-line options replace the module-level defaults read by `test`.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for dev in get_test_devices(args):
        test_operator(lib, dev, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")