causal_softmax.py 5.26 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
4
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
PanZezhongQY's avatar
PanZezhongQY committed
5
6
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
xgqdut2016's avatar
xgqdut2016 committed
7
8
9
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
10
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
11
    rearrange_if_needed,
PanZezhongQY's avatar
PanZezhongQY committed
12
    create_workspace,
xgqdut2016's avatar
xgqdut2016 committed
13
14
15
16
17
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
PanZezhongQY's avatar
PanZezhongQY committed
18
)
19
from enum import Enum, auto
PanZezhongQY's avatar
PanZezhongQY committed
20

xgqdut2016's avatar
xgqdut2016 committed
21
22
23
24
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules

# Base test cases. Each entry is (shape, x_stride, y_stride); a stride of None
# is forwarded unchanged to rearrange_if_needed() by test().
_TEST_CASES_ = [
    # shape, x_stride, y_stride
    ((3, 3), None, None),
    ((32, 512), None, None),
    ((32, 512), (1024, 1), (1024, 1)),
    ((32, 5, 5), None, None),
    ((32, 20, 512), None, None),
    ((32, 20, 512), (20480, 512, 1), None),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {"atol": 1e-3, "rtol": 1e-2},
}

43
44
45
46
47
48
49
50

class Inplace(Enum):
    """Selects where the operator under test writes its result.

    OUT_OF_PLACE: the output goes to a separately allocated tensor.
    INPLACE_X: the output tensor is the input tensor ``x`` itself.
    """

    # Explicit values match what ``auto()`` would assign (1, 2, ...).
    OUT_OF_PLACE = 1
    INPLACE_X = 2


# In-place modes every base test case is run with.
_INPLACE = [
    Inplace.INPLACE_X,
    Inplace.OUT_OF_PLACE,
]

# Full test matrix: each base case (shape, x_stride, y_stride) is extended with
# one in-place mode, giving (shape, x_stride, y_stride, inplace) tuples.
_TEST_CASES = [
    test_case + (inplace_item,)
    for test_case in _TEST_CASES_
    for inplace_item in _INPLACE
]

# Defaults for the run-time options; the __main__ section overwrites them
# from the parsed command-line arguments.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
PanZezhongQY's avatar
PanZezhongQY committed
64

xgqdut2016's avatar
xgqdut2016 committed
65

PanZezhongQY's avatar
PanZezhongQY committed
66
67
68
69
70
71
72
73
74
75
class CausalSoftmaxDescriptor(Structure):
    """ctypes structure used as the Python-side type of the operator descriptor.

    Only a single 32-bit ``device`` field is declared here.
    NOTE(review): presumably this mirrors the leading field of the library's
    native descriptor struct -- confirm against the C header.
    """

    _fields_ = [("device", c_int32)]


# Pointer-to-descriptor type; instances of it are handed to the
# infiniop*CausalSoftmax* library calls as the descriptor handle.
infiniopCausalSoftmaxDescriptor_t = POINTER(CausalSoftmaxDescriptor)


def causal_softmax(x):
    """Compute the PyTorch reference result for causal softmax.

    A mask is built over the last two dimensions of ``x``: a strictly lower
    triangle is flipped along both of those dimensions, so that for a
    ``rows x cols`` slice the entry ``(i, j)`` is masked when
    ``j > i + (cols - rows)``. Masked entries are replaced by ``-inf`` and
    therefore get zero weight in the softmax taken over the last dimension.

    Args:
        x: Input tensor with at least two dimensions.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    # Remember the caller's dtype; the name avoids shadowing the builtin `type`.
    out_dtype = x.dtype
    mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
    masked = torch.where(mask == 1, -torch.inf, x.to(torch.float32))
    return torch.nn.functional.softmax(masked, dim=-1, dtype=out_dtype)
PanZezhongQY's avatar
PanZezhongQY committed
78
79


80
81
82
83
84
85
86
87
88
def test(
    lib,
    handle,
    torch_device,
    shape,
    x_stride=None,
    y_stride=None,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=torch.float16,
    sync=None,
):
    """Run one CausalSoftmax test case against the library and check the result.

    The input is random data of the given shape and dtype; the expected output
    is produced by the PyTorch reference ``causal_softmax``. The library
    operator is then created, executed once, and compared with the reference
    using the tolerances from ``_TOLERANCE_MAP``. When ``PROFILE`` is set, both
    implementations are additionally timed with ``profile_operation``.

    Args:
        lib: Loaded operator library exposing the ``infiniop*CausalSoftmax*``
            functions.
        handle: Library handle passed to descriptor creation.
        torch_device: Torch device the tensors are moved to.
        shape: Shape of the input (and output) tensor.
        x_stride: Stride passed to ``rearrange_if_needed`` for the input.
        y_stride: Stride passed to ``rearrange_if_needed`` for the output;
            only used for out-of-place runs.
        inplace: ``Inplace`` mode; with ``INPLACE_X`` the output tensor is the
            input tensor.
        dtype: Torch dtype of the tensors.
        sync: Optional callable invoked before descriptor creation and after
            the operator call (presumably a device synchronisation hook).

    Raises:
        AssertionError: If the library output is not close to the reference.
    """
    print(
        f"Testing CausalSoftmax on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype} inplace:{inplace}"
    )

    x = torch.rand(shape, dtype=dtype).to(torch_device)

    # The reference must be computed before the library may overwrite x in place.
    ans = causal_softmax(x)

    x = rearrange_if_needed(x, x_stride)

    x_tensor = to_tensor(x, lib)

    if inplace == Inplace.INPLACE_X:
        y = x
        y_tensor = x_tensor
    else:
        y = torch.zeros(shape, dtype=dtype).to(torch_device)
        y = rearrange_if_needed(y, y_stride)
        y_tensor = to_tensor(y, lib)

    if sync is not None:
        sync()

    descriptor = infiniopCausalSoftmaxDescriptor_t()
    check_error(
        lib.infiniopCreateCausalSoftmaxDescriptor(
            handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    x_tensor.destroyDesc(lib)
    # For out-of-place runs y has its own descriptor, which must be invalidated
    # as well; for in-place runs it is the same object and is already destroyed.
    if y_tensor is not x_tensor:
        y_tensor.destroyDesc(lib)

    workspace_size = c_uint64(0)
    check_error(
        lib.infiniopGetCausalSoftmaxWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = create_workspace(workspace_size.value, x.device)

    def lib_causal_softmax():
        # Single invocation of the library operator; reused by the profiler.
        check_error(
            lib.infiniopCausalSoftmax(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                y_tensor.data,
                x_tensor.data,
                None,
            )
        )

    lib_causal_softmax()

    if sync is not None:
        sync()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y, ans, atol=atol, rtol=rtol)
    assert torch.allclose(y, ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: causal_softmax(x), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lib_causal_softmax, torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
162

PanZezhongQY's avatar
PanZezhongQY committed
163
164
165
166

if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # Declare the C signatures of the library entry points. Every argument
    # that test() passes is listed, so ctypes converts each one according to
    # its declared type.
    lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
    lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopCausalSoftmaxDescriptor_t),
        infiniopTensorDescriptor_t,  # y (output) tensor descriptor
        infiniopTensorDescriptor_t,  # x (input) tensor descriptor
    ]

    lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
    lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
        infiniopCausalSoftmaxDescriptor_t,
        POINTER(c_uint64),
    ]

    lib.infiniopCausalSoftmax.restype = c_int32
    lib.infiniopCausalSoftmax.argtypes = [
        infiniopCausalSoftmaxDescriptor_t,
        c_void_p,  # workspace pointer
        c_uint64,  # workspace size
        c_void_p,  # y data
        c_void_p,  # x data
        c_void_p,  # trailing argument; test() passes None (presumably the stream)
    ]

    lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
    lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
        infiniopCausalSoftmaxDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")