# Tensor Checks (Host-Side Auto-Validation)

This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind.

## Why Host-Side Checks
- ABI stability: the entry point is based on TVM FFI + DLPack and accepts tensors and scalars in a consistent form.
- Lower overhead: shifting checks from Python into C reduces interpreter and attribute-access costs; the call overhead is lower than that of pybind-based approaches.
- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.

## How To Inspect Host Source
You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:

```python
print(matmul_relu_kernel.get_host_source())
```

---

## What The Host Checks

### 1) Argument count and pointer kind
- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error.

### 2) Tensor checks (per tensor, after nullability decision)
- Nullability
  - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; passing NULL raises `xxx is expected to have non-NULL pointer`.
  - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`.
- Rank (`ndim`)
  - Runtime `ndim` must equal the compile-time rank.
- Data type (`dtype`)
  - Match the triple `(code, bits, lanes)` with tolerance:
    - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`.
    - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`.
    - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match).
  - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
- Shape
  - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
  - Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints.
- Strides
  - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality.
  - Otherwise: check per-dimension; if `strides == NULL`, derive strides from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`); see the sketch after this list.
- `byte_offset`
  - Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
- Device info
  - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
  - When multiple tensors participate, assert that `device_id` matches across them.
- Data pointer
  - Must be non-NULL when the tensor is required to be non-null by the nullability rule.
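
To make the contiguity rule concrete, here is a minimal Python sketch of how the expected strides are derived from `shape` (the real check is emitted as C in the host stub; the helper name is illustrative):

```python
import torch

def expected_contiguous_strides(shape):
    # Row-major contiguity: the innermost stride is 1 and each outer stride is
    # the product of the inner extents (so strides[-2] == shape[-1]).
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

x = torch.empty(64, 128)
assert list(x.stride()) == expected_contiguous_strides(list(x.shape))          # contiguous: passes
assert list(x.t().stride()) != expected_contiguous_strides(list(x.t().shape))  # transposed view: would fail
```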

### 3) Scalar checks
- `T.int*` family: require integer; error: `Expect arg[i] to be int`.
- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`.

---

## Shapes and Symbolic Equations: Linear Solving
When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example:

```python
# m, n, k are symbolic shape variables and dtype is defined in the enclosing scope.
@T.prim_func
def main(
    A: T.Tensor((m,), dtype),
    B: T.Tensor((m + n,), dtype),
    C: T.Tensor((n * k,), dtype),
):
    ...
```

This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime.
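
A hedged usage sketch, assuming `main` above was compiled for a CUDA float16 target and that `m`, `n`, `k` were declared as symbolic dimensions:

```python
import torch

m, n, k = 64, 32, 8
A = torch.empty(m, device='cuda', dtype=torch.float16)      # binds m = 64
B = torch.empty(m + n, device='cuda', dtype=torch.float16)  # one unknown left, so n is solved from len(B) - m
C = torch.empty(n * k, device='cuda', dtype=torch.float16)  # with n known, k is the single unknown here

main(A, B, C)  # all constraints are mutually consistent; checks pass

C_bad = torch.empty(n * k + 1, device='cuda', dtype=torch.float16)
main(A, B, C_bad)  # raises an unsatisfied shape constraint on C (n * k cannot equal len(C_bad))
```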

---

## Nullability Rules and Examples
Which tensors may be NULL?

- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL.
- Examples:

1) Must be non-NULL (used)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    A[0] = 1
```
Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.

2) Still must be non-NULL (constant-true branch)
```python
some_cond: bool = True
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

3) Nullable (constant-false branch, statically unreachable)
```python
some_cond: bool = False
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

4) Must be non-NULL (runtime condition)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
    if some_cond:
        A[0] = 1
```
Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable.
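
A hedged usage sketch of the rule, assuming examples 1 and 3 above were compiled as `main_used` and `main_dead` respectively (both names are illustrative):

```python
main_used(None)  # raises: main.A_handle is expected to have non-NULL pointer
main_dead(None)  # accepted: the write to A is statically unreachable, so NULL is allowed
```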

---

## Device Type Codes (DLPack)
Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`.
Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors.
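
To see which DLPack device a given `torch.Tensor` will report, you can query the standard DLPack protocol as a quick sanity check before calling a kernel (requires a CUDA-capable machine for this example):

```python
import torch

A = torch.empty(4, device='cuda:0')
print(A.__dlpack_device__())  # reports device_type 2 (CUDA) and device_id 0
```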

---

## Common Error Examples (What you’ll see)
- Argument count mismatch (num_args)
  - Trigger: missing/extra argument
  - Error: `<kernel>: num_args should be N; expected: N, got: <actual>`

- Pointer-typed argument expected
  - Trigger: scalar passed where a tensor is expected
  - Error: `<kernel>: Expect arg[i] to be pointer`

- Rank (ndim) mismatch
  - Trigger: runtime rank differs from compile-time rank
  - Error: `<kernel>.<name>.ndim is expected to equal R, but got mismatched ndim`

- Dtype mismatch
  - Trigger: dtype not equal to the compiled dtype and not within the tolerance set
  - Error: `<kernel>.<name>.dtype is expected to be <dtype>, but got incompatible dtype`

- Shape constraint violation
  - Trigger: a dimension doesn’t match a constant/symbol binding
  - Error: `Argument <kernel>.<name>.shape[i] has an unsatisfied constraint: ... == <expected>`

- Strides check failed (e.g., non-contiguous layout)
  - Trigger: transposed/sliced tensors that violate expected strides
  - Error: `Argument <kernel>.<name>.strides[j] has an unsatisfied constraint: ... == <expected>`

- Device type mismatch
  - Trigger: calling a CUDA kernel with CPU tensors, etc.
  - Error: `<kernel>.<name>.device_type mismatch [expected: <code> (<device>)] ...`

- Device id mismatch
  - Trigger: mixing tensors from different GPUs
  - Error: `Argument <kernel>.<name>.device_id has an unsatisfied constraint: ... == ...`

- NULL data pointer
  - Trigger: tensor required to be non-null has a NULL data pointer
  - Error: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`

- Scalar type mismatch
  - Trigger: passing float to `T.int32`, or non-boolean to `T.bool`
  - Error: `<kernel>: Expect arg[i] to be int/boolean`

---

## Troubleshooting Tips
- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields.
- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
- Align dtype: use `.to(<dtype>)` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance.
- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).
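
A minimal helper sketch that applies the device/dtype/contiguity tips above before calling a kernel (the helper name and defaults are illustrative, not part of TileLang):

```python
import torch

def normalize_args(*tensors, device='cuda', dtype=torch.float16):
    # Align device, dtype, and contiguity so the host-side checks pass.
    return tuple(t.to(device=device, dtype=dtype).contiguous() for t in tensors)

# fn(*normalize_args(A, B, C))
```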

---

## FAQ
- Can I disable the checks?
  - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call.
- Is the overhead noticeable?
  - The checks are lightweight (branches and field reads), so they are cheaper than equivalent Python-side checks; the dominant cost remains crossing the Python→C boundary.

---

## Reference Example (Matmul + ReLU)

```python
# M, N, K, block_M, block_N, block_K, dtype and accum_dtype are assumed to be
# defined at compile time (constants or symbolic vars captured from the
# enclosing scope).
@T.prim_func
def matmul_relu_kernel(
    A: T.Tensor((M, K), dtype),
    B: T.Tensor((K, N), dtype),
    C: T.Tensor((M, N), dtype),
):
    # Initialize Kernel Context
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_K, block_N), dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        T.clear(C_local)
        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
            T.copy(A[by * block_M, ko * block_K], A_shared)
            T.copy(B[ko * block_K, bx * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        # Apply ReLU on the accumulator before writing back
        for i, j in T.Parallel(block_M, block_N):
            C_local[i, j] = T.max(C_local[i, j], 0)
        T.copy(C_local, C[by * block_M, bx * block_N])

# For debugging, print the host source
print(matmul_relu_kernel.get_host_source())
```

The host will insert all checks described above for this example.
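
A hedged calling sketch, assuming the kernel was compiled with `M = N = K = 1024`, `dtype = "float16"`, and a CUDA target:

```python
import torch

a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
c = torch.empty(1024, 1024, device='cuda', dtype=torch.float16)

matmul_relu_kernel(a, b, c)  # host checks pass, then the device kernel is launched
```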

---

## Quick Error Reference (Short List)
- Argument count
  - Trigger: missing/extra args; Error: `num_args should be N; expected: N, got: <actual>`.
- Pointer kind
  - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`.
- Rank (ndim)
  - Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`.
- Dtype
  - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be <dtype>`.
- Shape
  - Trigger: constant/symbol binding violated; Error: `shape[i] ... == <expected>`.
- Strides
  - Trigger: layout mismatch; Error: `strides[j] ... == <expected>`.
- Device type
  - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`.
- Device id
  - Trigger: tensors on different GPUs; Error: `device_id ... == ...`.
- Data pointer
  - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`.
- Scalar types
  - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`.

---

## Host Error Troubleshooting (Minimal Repros)

Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with:

```python
# Convention:
# A: float16 [M, K]
# B: float16 [K, N]
# C: float16 [M, N]
# Target: CUDA (device_type=2)
fn = matmul_relu_kernel  # your compiled function
M = N = K = 1024
```

Adjust dtype/device if your kernel differs.

### 0. Tip: print the host source
```python
print(fn.get_host_source())
```

### 1. num_args mismatch
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
# Missing C
fn(A, B)
```
Expected: `<kernel>: num_args should be 3; expected: 3, got: 2`.

Fix: pass all arguments per the signature.

### 2. Expect pointer (tensor) but got scalar
```python
import torch

B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(1, B, C)
```
Expected: `<kernel>: Expect arg[0] to be pointer`.

Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor).

### 3. ndim mismatch
```python
import torch

A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16)  # rank=3
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.ndim is expected to equal 2, but got mismatched ndim`.

Fix: ensure runtime rank equals compiled rank.

### 4. dtype mismatch
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float32)  # should be float16
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.dtype is expected to be float16, but got incompatible dtype`.

Fix: `A = A.to(torch.float16)` or create with the correct dtype.

### 5. Shape constant/symbol mismatch
```python
import torch

A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16)  # K mismatched
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.A_handle.shape[i] has an unsatisfied constraint: ... == <expected>`.

Fix: satisfy linear constraints and constants across tensors.

### 6. Strides check failure (non-contiguous)
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
A_nc = A.t()  # transpose -> non-contiguous
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A_nc, B, C)
```
Expected: `Argument <kernel>.A_handle.strides[1] has an unsatisfied constraint: ... == 1`.

Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel.

### 7. device_type mismatch
```python
import torch

A = torch.empty((M, K), device='cpu', dtype=torch.float16)
B = torch.empty((K, N), device='cpu', dtype=torch.float16)
C = torch.empty((M, N), device='cpu', dtype=torch.float16)
fn(A, B, C)  # CUDA-targeted kernel
```
Expected: `<kernel>.A_handle.device_type mismatch [expected: 2 (cuda)] ...`.

Fix: move tensors to the CUDA device.

### 8. device_id mismatch (multi-GPU)
```python
import torch

A = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
B = torch.empty((K, N), device='cuda:1', dtype=torch.float16)
C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.B_handle.device_id has an unsatisfied constraint: ... == ...`.

Fix: place all tensors on the same GPU (e.g., `cuda:0`).

### 9. NULL data pointer (advanced)
This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this.

Expected: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`.

Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.

### 10. Scalar type mismatch (int / bool)
```python
import tilelang.language as T

@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
    T.evaluate(0)

# Each call below fails on its own (run them separately; the first already raises):
scalar_check(1.0, True)  # x is float -> Expect arg[0] to be int
scalar_check(1, 2.5)     # flag is float -> Expect arg[1] to be boolean
```

Fix: pass correct scalar types, e.g., `scalar_check(1, True)`.

---

## Closing Notes
- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently.
- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly.