"src/git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "8732ea04fba4672a2ab4098289935f0a1bec7fc7"
Unverified Commit 1e92d11c authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Improve assertion handling in CodeGenCHost and ArgBinder (#1352)

* [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder

This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase.

* [Enhancement] Update matmul kernel and optimize argument binding

This commit enhances the matmul kernel by introducing additional tensor parameters and refining the pipeline stages for improved performance. It also updates the argument binding mechanism to include a flag indicating whether buffers are used, enhancing the efficiency of buffer management. Furthermore, the optimization phase in the engine is improved by adding a simplification step, ensuring better performance and clarity in the generated code.

* lint fix

* [Enhancement] Add tensor checks documentation and improve argument binding assertions

This commit introduces a new documentation page for host-side tensor checks, detailing the automatic validations performed by TileLang on kernel arguments. It enhances the ArgBinder by adding assertions for non-null pointers when arguments are used, improving error handling. Additionally, the optimization phase in the engine is updated to include a simplification step, ensuring better performance and clarity in the generated code.

* [Enhancement] Update .gitignore and refine matmul kernel for improved performance

This commit adds host checks logs to the .gitignore file to prevent unnecessary log files from being tracked. Additionally, it refines the matmul kernel by adjusting pipeline stages, updating tensor parameters, and enhancing argument handling for better performance. The changes also include improved error messages in the argument binding process, ensuring clearer diagnostics for users.

* lint fix

* lint fix

* [Refactor] Simplify tensor_null_test function and remove ptr_null_test

This commit refactors the tensor_null_test function by adding a with_bias parameter and removing the ptr_null_test function, which was previously unused. The run_test function is updated to reflect these changes, streamlining the testing process for tensor operations.

* lint fix

* fix
parent b8240b7a
......@@ -108,3 +108,6 @@ cmake-build-*/
# pre-commit cache
.pre-commit-cache/*
# host checks logs
maint/host_checks/logs/*
# Tensor Checks (Host-Side Auto-Validation)
This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind.
## Why Host-Side Checks
- ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars.
- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches.
- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.
## How To Inspect Host Source
You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:
```python
print(matmul_relu_kernel.get_host_source())
```
---
## What The Host Checks
### 1) Argument count and pointer kind
- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error.
### 2) Tensor checks (per tensor, after nullability decision)
- Nullability
- If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`.
- If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`.
- Rank (`ndim`)
- Runtime `ndim` must equal the compile-time rank.
- Data type (`dtype`)
- Match the triple `(code, bits, lanes)` with tolerance:
- `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`.
- `float8_e5m2`: accept `e5m2`, `e5m2fnuz`.
- `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match).
- For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
- Shape
- Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
- Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints.
- Strides
- If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality.
- Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`).
- `byte_offset`
- Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
- Device info
- Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
- When multiple tensors participate, assert that `device_id` matches across them.
- Data pointer
- Must be non-NULL when the tensor is required to be non-null by the nullability rule.
### 3) Scalar checks
- `T.int*` family: require integer; error: `Expect arg[i] to be int`.
- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`.
---
## Shapes and Symbolic Equations: Linear Solving
When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example:
```python
@T.prim_func
def main(
A: T.Tensor((m,), dtype),
B: T.Tensor((m + n,), dtype),
C: T.Tensor((n * k,), dtype),
):
...
```
This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime.
---
## Nullability Rules and Examples
Which tensors may be NULL?
- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL.
- Examples:
1) Must be non-NULL (used)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
A[0] = 1
```
Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.
2) Still must be non-NULL (constant-true branch)
```python
some_cond: bool = True
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
if some_cond:
A[0] = 1
```
3) Nullable (constant-false branch, statically unreachable)
```python
some_cond: bool = False
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
if some_cond:
A[0] = 1
```
4) Must be non-NULL (runtime condition)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
if some_cond:
A[0] = 1
```
Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable.
---
## Device Type Codes (DLPack)
Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`.
Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors.
---
## Common Error Examples (What you’ll see)
- Argument count mismatch (num_args)
- Trigger: missing/extra argument
- Error: `<kernel>: num_args should be N; expected: <num_args>, got: N`
- Pointer-typed argument expected
- Trigger: scalar passed where a tensor is expected
- Error: `<kernel>: Expect arg[i] to be pointer`
- Rank (ndim) mismatch
- Trigger: runtime rank differs from compile-time rank
- Error: `<kernel>.<name>.ndim is expected to equal R, but got mismatched ndim`
- Dtype mismatch
- Trigger: dtype not equal to the compiled dtype and not within the tolerance set
- Error: `<kernel>.<name>.dtype is expected to be <dtype>, but got incompatible dtype`
- Shape constraint violation
- Trigger: a dimension doesn’t match a constant/symbol binding
- Error: `Argument <kernel>.<name>.shape[i] has an unsatisfied constraint: ... == <expected>`
- Strides check failed (e.g., non-contiguous layout)
- Trigger: transposed/sliced tensors that violate expected strides
- Error: `Argument <kernel>.<name>.strides[j] has an unsatisfied constraint: ... == <expected>`
- Device type mismatch
- Trigger: calling a CUDA kernel with CPU tensors, etc.
- Error: `<kernel>.<name>.device_type mismatch [expected: <code> (<name>)] ...`
- Device id mismatch
- Trigger: mixing tensors from different GPUs
- Error: `Argument <kernel>.<name>.device_id has an unsatisfied constraint: ... == ...`
- NULL data pointer
- Trigger: tensor required to be non-null has a NULL data pointer
- Error: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`
- Scalar type mismatch
- Trigger: passing float to `T.int32`, or non-boolean to `T.bool`
- Error: `<kernel>: Expect arg[i] to be int/boolean`
---
## Troubleshooting Tips
- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields.
- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
- Align dtype: use `.to(<dtype>)` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance.
- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).
---
## FAQ
- Can I disable the checks?
- Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call.
- Is the overhead noticeable?
- The checks are lightweight (branches and field reads). Compared to Python-side checks, it’s faster; the dominating cost remains the Python→C boundary. Overall it’s cheaper than equivalent checks in Python.
---
## Reference Example (Matmul + ReLU)
```python
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
# For debugging, print the host source
print(matmul_relu_kernel.get_host_source())
```
The host will insert all checks described above for this example.
---
## Quick Error Reference (Short List)
- Argument count
- Trigger: missing/extra args; Error: `num_args should be N; expected: <num_args>, got: N`.
- Pointer kind
- Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`.
- Rank (ndim)
- Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`.
- Dtype
- Trigger: mismatch and not tolerated; Error: `dtype ... expected to be <dtype>`.
- Shape
- Trigger: constant/symbol binding violated; Error: `shape[i] ... == <expected>`.
- Strides
- Trigger: layout mismatch; Error: `strides[j] ... == <expected>`.
- Device type
- Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`.
- Device id
- Trigger: tensors on different GPUs; Error: `device_id ... == ...`.
- Data pointer
- Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`.
- Scalar types
- Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`.
---
## Host Error Troubleshooting (Minimal Repros)
Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with:
```python
# Convention:
# A: float16 [M, K]
# B: float16 [K, N]
# C: float16 [M, N]
# Target: CUDA (device_type=2)
fn = matmul_relu_kernel # your compiled function
M = N = K = 1024
```
Adjust dtype/device if your kernel differs.
### 0. Tip: print the host source
```python
print(fn.get_host_source())
```
### 1. num_args mismatch
```python
import torch
A = torch.empty((M, K), device='cuda', dtype=torch.float16)
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
# Missing C
fn(A, B)
```
Expected: `<kernel>: num_args should be 3; expected: <num_args>, got: 3`.
Fix: pass all arguments per the signature.
### 2. Expect pointer (tensor) but got scalar
```python
import torch
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(1, B, C)
```
Expected: `<kernel>: Expect arg[0] to be pointer`.
Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor).
### 3. ndim mismatch
```python
import torch
A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16) # rank=3
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.ndim is expected to equal 2, but got mismatched ndim`.
Fix: ensure runtime rank equals compiled rank.
### 4. dtype mismatch
```python
import torch
A = torch.empty((M, K), device='cuda', dtype=torch.float32) # should be float16
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.dtype is expected to be float16, but got incompatible dtype`.
Fix: `A = A.to(torch.float16)` or create with the correct dtype.
### 5. Shape constant/symbol mismatch
```python
import torch
A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16) # K mismatched
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.A_handle.shape[i] has an unsatisfied constraint: ... == <expected>`.
Fix: satisfy linear constraints and constants across tensors.
### 6. Strides check failure (non-contiguous)
```python
import torch
A = torch.empty((M, K), device='cuda', dtype=torch.float16)
A_nc = A.t() # transpose -> non-contiguous
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A_nc, B, C)
```
Expected: `Argument <kernel>.A_handle.strides[1] has an unsatisfied constraint: ... == 1`.
Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel.
### 7. device_type mismatch
```python
import torch
A = torch.empty((M, K), device='cpu', dtype=torch.float16)
B = torch.empty((K, N), device='cpu', dtype=torch.float16)
C = torch.empty((M, N), device='cpu', dtype=torch.float16)
fn(A, B, C) # CUDA-targeted kernel
```
Expected: `<kernel>.A_handle.device_type mismatch [expected: 2 (cuda)] ...`.
Fix: move tensors to the CUDA device.
### 8. device_id mismatch (multi-GPU)
```python
import torch
A = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
B = torch.empty((K, N), device='cuda:1', dtype=torch.float16)
C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.B_handle.device_id has an unsatisfied constraint: ... == ...`.
Fix: place all tensors on the same GPU (e.g., `cuda:0`).
### 9. NULL data pointer (advanced)
This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this.
Expected: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`.
Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.
### 10. Scalar type mismatch (int / bool)
```python
import tilelang.language as T
@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
T.evaluate(0)
scalar_check(1.0, True) # x is float -> Expect arg[0] to be int
scalar_check(1, 2.5) # flag is float -> Expect arg[1] to be boolean
```
Fix: pass correct scalar types, e.g., `scalar_check(1, True)`.
---
## Closing Notes
- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently.
- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly.
......@@ -42,6 +42,7 @@ deeplearning_operators/deepseek_mla
compiler_internals/letstmt_inline
compiler_internals/inject_fence_proxy
compiler_internals/tensor_checks
:::
:::{toctree}
......
......@@ -77,7 +77,7 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5.Profile latency with kernel
......
"""Reproduce: Argument count mismatch.
Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
Calling with the wrong number of inputs raises a ValueError before host entry.
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 256
fn = build_matmul_kernel(M, N, K, target="cuda")
a = torch.empty((M, K), device="cuda", dtype=torch.float16)
# Missing b
# Expected: ValueError with message about expected vs. actual inputs
fn(a)
if __name__ == "__main__":
main()
"""Reproduce: Pointer-type argument expected but scalar provided.
We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 256
fn = build_matmul_kernel(M, N, K, target="cuda")
# Wrong type for A (int instead of tensor)
a = 1
b = torch.empty((K, N), device="cuda", dtype=torch.float16)
fn(a, b)
if __name__ == "__main__":
main()
"""Reproduce: ndim (rank) mismatch for A.
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 128
fn = build_matmul_kernel(M, N, K, target="cuda")
# A has rank 3 instead of 2
a = torch.empty((M, K, 1), device="cuda", dtype=torch.float16)
b = torch.empty((K, N), device="cuda", dtype=torch.float16)
fn(a, b)
if __name__ == "__main__":
main()
"""Reproduce: dtype mismatch for A (float32 vs expected float16).
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 128
fn = build_matmul_kernel(M, N, K, target="cuda")
print(fn.get_host_source())
a = torch.empty((M, K), device="cuda", dtype=torch.float32) # should be float16
b = torch.empty((K, N), device="cuda", dtype=torch.float16)
fn(a, b)
if __name__ == "__main__":
main()
"""Reproduce: shape constant/symbol mismatch on A.
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 128
fn = build_matmul_kernel(M, N, K, target="cuda")
# A's second dimension is wrong (K+1 instead of K)
a = torch.empty((M, K + 1), device="cuda", dtype=torch.float16)
b = torch.empty((K, N), device="cuda", dtype=torch.float16)
fn(a, b)
if __name__ == "__main__":
main()
"""Reproduce: strides check failure (non-contiguous A via transpose).
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 128
fn = build_matmul_kernel(M, N, K, target="cuda")
a = torch.empty((M, K), device="cuda", dtype=torch.float16)
a_nc = a.t() # non-contiguous after transpose
b = torch.empty((K, N), device="cuda", dtype=torch.float16)
fn(a_nc, b)
if __name__ == "__main__":
main()
"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 64
fn = build_matmul_kernel(M, N, K, target="cuda")
a = torch.empty((M, K), device="cpu", dtype=torch.float16)
b = torch.empty((K, N), device="cpu", dtype=torch.float16)
fn(a, b)
if __name__ == "__main__":
main()
"""Reproduce: device_id mismatch (requires >=2 CUDA devices).
"""
import torch
from common import build_matmul_kernel
def main():
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
if torch.cuda.device_count() < 2:
print("[SKIP] Need at least 2 CUDA devices to reproduce device_id mismatch.")
return
M = N = K = 64
fn = build_matmul_kernel(M, N, K, target="cuda")
a = torch.empty((M, K), device="cuda:0", dtype=torch.float16)
b = torch.empty((K, N), device="cuda:1", dtype=torch.float16)
# Output device is derived by the adapter; mismatch occurs in host checks
fn(a, b)
if __name__ == "__main__":
main()
"""Reproduce: NULL data pointer (advanced).
Passing None for a tensor argument will be forwarded through the adapter. Depending on
FFI handling, this commonly triggers a pointer-type assertion (e.g., "Expect buffer <name> to be pointer or tensor")
or a host-side non-NULL pointer check.
Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script
demonstrates passing None, which still reproduces the intended class of failure.
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 64
fn = build_matmul_kernel(M, N, K, target="cuda")
a = None # attempt to pass a null-like pointer
b = torch.empty((K, N), device="cuda", dtype=torch.float16)
fn(a, b)
if __name__ == "__main__":
main()
"""Reproduce: scalar parameter type mismatch (int/bool).
"""
from common import build_scalar_check_kernel
def main():
fn = build_scalar_check_kernel(target="cuda")
# Wrong types
fn(1.0, True) # x should be int -> Expect arg[0] to be int
fn(1, 2.5) # flag should be bool -> Expect arg[1] to be boolean
if __name__ == "__main__":
main()
# Host-Side Check Repro Scripts
This folder contains standalone scripts that deliberately trigger host-side (and adapter-side) validation errors described in `docs/compiler_internals/tensor_checks.md`. Each script can be run directly and will reproduce the corresponding error with a minimal example.
Prerequisites
- CUDA-capable environment (most scripts compile a CUDA-targeted kernel)
- Python packages: torch, tilelang
Usage
- Run any script, e.g.:
- `python 01_num_args_mismatch.py`
- `python 02_pointer_type_error.py`
- ... up to `10_scalar_type_mismatch.py`
- Or run all at once with a summary:
- `python run_all.py`
- Logs per test are saved under `logs/` as `<script>.out` / `<script>.err`.
Notes
- Scripts assume at least one CUDA device. For the device-id mismatch case (08), two GPUs are required; the script will skip with a note if only one is available.
- The adapter raises some errors before the host stub (e.g., wrong input count). The messages are aligned with the host checks as far as possible.
import tilelang
import tilelang.language as T
import torch
def make_matmul_prim(M,
N,
K,
block_M=128,
block_N=128,
block_K=32,
dtype="float16",
accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def build_matmul_kernel(M=1024, N=1024, K=1024, target="cuda"):
"""Compile and return a callable kernel that takes (A, B) and returns C."""
if target.startswith("cuda") and not torch.cuda.is_available():
raise RuntimeError("CUDA is not available; cannot build CUDA kernel for host-check repros.")
prim = make_matmul_prim(M, N, K)
# out_idx=[2] means the 3rd param C is treated as output; wrapper takes (A,B)
return tilelang.compile(prim, out_idx=[2], target=target)
def build_scalar_check_kernel(target="cuda"):
@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
T.evaluate(0)
return tilelang.compile(scalar_check, target=target)
import sys
import subprocess
from pathlib import Path
def main():
root = Path(__file__).resolve().parent
scripts = [
"01_num_args_mismatch.py",
"02_pointer_type_error.py",
"03_ndim_mismatch.py",
"04_dtype_mismatch.py",
"05_shape_mismatch.py",
"06_strides_mismatch.py",
"07_device_type_mismatch.py",
"08_device_id_mismatch.py",
"09_null_data_pointer.py",
"10_scalar_type_mismatch.py",
]
logs_dir = root / "logs"
logs_dir.mkdir(exist_ok=True)
results = []
for name in scripts:
script_path = root / name
if not script_path.exists():
results.append((name, "MISSING", 0))
print(f"[MISSING] {name}")
continue
print(f"\n=== Running {name} ===")
proc = subprocess.run(
[sys.executable, str(script_path)],
cwd=str(root),
capture_output=True,
text=True,
)
# Save logs
(logs_dir / f"{name}.out").write_text(proc.stdout)
(logs_dir / f"{name}.err").write_text(proc.stderr)
out = (proc.stdout or "") + (proc.stderr or "")
if "[SKIP]" in out:
status = "SKIP"
elif proc.returncode != 0:
status = "PASS" # error reproduced as expected
else:
status = "FAIL" # no error observed
results.append((name, status, proc.returncode))
print(f"[{status}] {name} (rc={proc.returncode})")
# Summary
print("\n=== Summary ===")
counts = {"PASS": 0, "FAIL": 0, "SKIP": 0, "MISSING": 0}
for name, status, _ in results:
counts[status] = counts.get(status, 0) + 1
print(f"{status:7} {name}")
print("\nTotals:")
for k in ("PASS", "FAIL", "SKIP", "MISSING"):
print(f" {k:7}: {counts.get(k, 0)}")
# Exit non-zero if any FAIL
sys.exit(1 if counts.get("FAIL", 0) else 0)
if __name__ == "__main__":
main()
/*
* Helper functions for nicer runtime error messages.
*/
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/data_type.h>
#include <sstream>
#include <string>
namespace tvm {
namespace tl {
// Return non-zero so that tvm_call_packed sites treat it as failure and return
// -1.
static int DTypeMismatch(const tvm::ffi::String &kernel_name,
const tvm::ffi::String &buffer_name,
int64_t actual_code, int64_t actual_bits,
int64_t actual_lanes, int64_t expect_code,
int64_t expect_bits, int64_t expect_lanes) {
tvm::runtime::DataType actual(static_cast<int>(actual_code),
static_cast<int>(actual_bits),
static_cast<int>(actual_lanes));
tvm::runtime::DataType expect(static_cast<int>(expect_code),
static_cast<int>(expect_bits),
static_cast<int>(expect_lanes));
std::ostringstream os;
os << std::string(kernel_name) << ": dtype of " << std::string(buffer_name)
<< " is expected to be " << expect << ", but got " << actual;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
return -1;
}
// Variant without names, to avoid passing extra raw strings through packed
// args.
static int DTypeMismatchNoNames(int64_t actual_code, int64_t actual_bits,
int64_t actual_lanes, int64_t expect_code,
int64_t expect_bits, int64_t expect_lanes) {
tvm::runtime::DataType actual(static_cast<int>(actual_code),
static_cast<int>(actual_bits),
static_cast<int>(actual_lanes));
tvm::runtime::DataType expect(static_cast<int>(expect_code),
static_cast<int>(expect_bits),
static_cast<int>(expect_lanes));
std::ostringstream os;
os << "dtype mismatch: expected " << expect << ", but got " << actual;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
return -1;
}
} // namespace tl
} // namespace tvm
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tilelang_error_dtype_mismatch",
&tvm::tl::DTypeMismatch);
refl::GlobalDef().def("tilelang_error_dtype_mismatch2",
&tvm::tl::DTypeMismatchNoNames);
}
......@@ -348,7 +348,6 @@ void CodeGenCHost::VisitExpr_(const tvm::tir::CallNode *op,
}
void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
using namespace tvm::tir;
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
......@@ -356,88 +355,28 @@ void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
int assert_if_scope = this->BeginScope();
{
// Prepare the base error message
const auto *msg_node = op->message.as<StringImmNode>();
const auto *msg_node = op->message.as<tvm::tir::StringImmNode>();
ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm";
const std::string &raw_msg = msg_node->value;
const std::string esc_msg = tvm::support::StrEscape(
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
/*escape_whitespace_special_chars=*/true);
// If the assertion condition contains any equality checks anywhere
// in a composite boolean expression, append the actual LHS/RHS values
// Collect all EQ nodes within the condition (including inside And/Or/Not)
std::vector<const EQNode *> eq_nodes;
{
std::vector<PrimExpr> stk;
stk.push_back(op->condition);
while (!stk.empty()) {
PrimExpr cur = stk.back();
stk.pop_back();
if (const auto *eq = cur.as<EQNode>()) {
eq_nodes.push_back(eq);
continue;
}
if (const auto *an = cur.as<AndNode>()) {
stk.push_back(an->a);
stk.push_back(an->b);
continue;
}
if (const auto *on = cur.as<OrNode>()) {
stk.push_back(on->a);
stk.push_back(on->b);
continue;
}
if (const auto *nn = cur.as<NotNode>()) {
stk.push_back(nn->a);
continue;
}
}
}
if (!eq_nodes.empty()) {
// Build a single detailed message that includes all LHS/RHS pairs
// If the assertion is an equality check, append the actual LHS/RHS values
if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) {
std::string lhs = PrintExpr(eq->a);
std::string rhs = PrintExpr(eq->b);
PrintIndent();
stream << "char __tvm_assert_msg_buf[1024];\n";
stream << "char __tvm_assert_msg_buf[512];\n";
PrintIndent();
stream << "int __tvm_assert_msg_len = snprintf(__tvm_assert_msg_buf, "
"sizeof(__tvm_assert_msg_buf), \"%s\", \""
<< esc_msg << "\");\n";
auto escape_for_printf_literal = [&](const std::string &s) {
std::string out;
out.reserve(s.size());
for (char c : s) {
if (c == '%') {
out += "%%";
} else if (c == '"') {
out += "\\\"";
} else if (c == '\\') {
out += "\\\\";
} else {
out.push_back(c);
}
}
return out;
};
for (const auto *eq : eq_nodes) {
std::string lhs = PrintExpr(eq->a);
std::string rhs = PrintExpr(eq->b);
std::string lhs_disp = escape_for_printf_literal(lhs);
std::string rhs_disp = escape_for_printf_literal(rhs);
PrintIndent();
stream << "__tvm_assert_msg_len += snprintf(__tvm_assert_msg_buf + "
"__tvm_assert_msg_len, "
"sizeof(__tvm_assert_msg_buf) - __tvm_assert_msg_len, \"; ("
<< lhs_disp << " == " << rhs_disp
<< ") got: %lld, expected: %lld\", (long long)(" << lhs
<< "), (long long)(" << rhs << "));\n";
}
stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, "
"got: %lld\", \""
<< esc_msg << "\", (long long)(" << lhs << "), (long long)("
<< rhs << "));\n";
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
"__tvm_assert_msg_buf);\n";
} else {
// Fallback: just emit the base message
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg
<< "\");\n";
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file arg_binder.cc
* \brief Helper utility to match and bind arguments.
......@@ -24,6 +5,7 @@
#include "arg_binder.h"
#include <tvm/runtime/device_api.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
......@@ -44,16 +26,32 @@ namespace tl {
using namespace tir;
void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond,
const std::string &arg_name, std::vector<Stmt> *asserts) {
const std::string &arg_name, std::vector<Stmt> *asserts,
PrimExpr nullable_guard = PrimExpr()) {
PrimExpr scond = ana->Simplify(cond);
if (is_zero(scond)) {
LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", "
<< " on argument " << arg_name;
}
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint: " << cond;
asserts->emplace_back(AssertStmt(scond, StringImm(os.str()), Evaluate(0)));
// Check if the condition is of the form "is_null || actual_cond"
// If so, generate "if !is_null: assert actual_cond" instead of "assert
// is_null || actual_cond"
if (nullable_guard.defined()) {
// Pattern: nullable_guard || actual_condition
// We want to transform this into: if !nullable_guard: assert
// actual_condition
Stmt check = AssertStmt(scond, StringImm(os.str()), Evaluate(0));
check = IfThenElse(Not(nullable_guard), check);
asserts->emplace_back(SeqStmt({check, Evaluate(0)}));
} else {
asserts->emplace_back(
AssertStmt(scond, StringImm(os.str()), Evaluate(0)));
}
}
}
......@@ -106,8 +104,8 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value,
return true;
} else {
// Second or later binding: add is_null short-circuit
PrimExpr cond = MakeGuarded(it->second == value);
BinderAddAssert(&analyzer_, cond, arg_name, &asserts_);
PrimExpr cond = value == it->second;
BinderAddAssert(&analyzer_, cond, arg_name, &asserts_, nullable_guard);
}
} else {
// 2. complex binding expr = value
......@@ -129,7 +127,7 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value,
auto value_opt = sol->src_to_dst.Get(v);
ICHECK(value_opt->defined())
<< "Unable to solve variable `" << v << "` from expression `"
<< (arg == value) << "`";
<< (value == arg) << "`";
auto value = ffi::GetRef<PrimExpr>(sol->src_to_dst.Get(v)->get());
BindVar(v.as<VarNode>(), value);
}
......@@ -138,9 +136,10 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value,
// because the solved expression may contain floordiv (e.g. 3 * m == n
// ==> m = n // 3) we re-compute the constraint to verify the solution
// is correct
PrimExpr cond = MakeGuarded(arg == value);
BinderAddAssert(&analyzer_, cond, arg_name, &asserts_);
PrimExpr cond = value == arg;
BinderAddAssert(&analyzer_, cond, arg_name, &asserts_, nullable_guard);
}
// ICHECK(false);
return false;
}
......@@ -160,10 +159,10 @@ bool ArgBinder::Bind_(const PrimExpr &arg, const PrimExpr &value,
}
return true;
} else {
BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_);
BinderAddAssert(&analyzer_, value == it->second, arg_name, &asserts_);
}
} else {
BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_);
BinderAddAssert(&analyzer_, value == arg, arg_name, &asserts_);
}
return false;
}
......@@ -236,7 +235,7 @@ void ArgBinder::BindBuffer(const Buffer &arg, const Buffer &value,
PrimExpr offset = value->elem_offset;
PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero,
BinderAddAssert(&analyzer_, zero == truncmod(offset, factor),
arg_name + ".elem_offset", &asserts_);
}
}
......@@ -277,7 +276,7 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr,
void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
const PrimExpr &device_id, const Var &handle,
const std::string &arg_name) {
const std::string &arg_name, bool is_used) {
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = Evaluate(0);
......@@ -286,11 +285,18 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// avoid dereferencing it by using expression-level conditionals and
// short-circuiting guards in asserts. Cache the null check in a Let-bound
// boolean so codegen does not repeat `(handle == NULL)` everywhere.
Var is_null_var(arg_name + "_is_null", DataType::Bool());
init_nest_.emplace_back(
LetStmt(is_null_var,
Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop));
const PrimExpr &is_null = is_null_var;
const PrimExpr &is_null = is_used ? const_false() : is_null_var;
if (is_used) {
init_nest_.emplace_back(AssertStmt(
!is_null_var,
tvm::tir::StringImm(arg_name + " is expected to have non-NULL pointer"),
nop));
}
// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
......@@ -318,9 +324,10 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
ndim_err_msg << arg_name << ".ndim is expected to equal "
<< buffer->shape.size() << ", but got mismatched ndim";
auto msg = StringImm(ndim_err_msg.str());
// Only check ndim when handle is non-NULL (using short-circuit OR)
v_ndim = tvm::if_then_else(Not(is_null), v_ndim, make_zero(tvm_ndim_type));
init_nest_.emplace_back(AssertStmt(Or(is_null, a_ndim == v_ndim), msg, nop));
// Only check ndim when handle is non-NULL (using if statement)
Stmt ndim_check = AssertStmt(a_ndim == v_ndim, msg, nop);
ndim_check = IfThenElse(Not(is_null), ndim_check);
init_nest_.emplace_back(SeqStmt({ndim_check, nop}));
// type checks
std::ostringstream type_err_msg;
// Avoid dumping TIR expressions in error text; just state mismatch.
......@@ -396,8 +403,10 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4))) {
auto type_msg = StringImm(type_err_msg.str());
// Only check dtype when handle is non-NULL (short-circuit)
asserts_.emplace_back(AssertStmt(Or(is_null, cond), type_msg, nop));
// Only check dtype when handle is non-NULL (using if statement)
Stmt dtype_check = AssertStmt(cond, type_msg, nop);
dtype_check = IfThenElse(Not(is_null), dtype_check);
asserts_.emplace_back(SeqStmt({dtype_check, nop}));
}
// shape field
......@@ -427,31 +436,16 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
}
// The "real" runtime shape value read from DLTensor
PrimExpr raw_shape_val =
PrimExpr shape_val =
cast(buffer->shape[k].dtype(),
BufferLoad(buf_shape,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
// Bind to the value of the symbolic dimension (e.g., m) in TIR, with an
// is_null guard:
// handle is NULL → use 0, placeholder but no dereference
// handle non-NULL → actually read from DLTensor's shape array
PrimExpr bound_shape_val = tvm::if_then_else(
is_null, make_zero(buffer->shape[k].dtype()), raw_shape_val);
// When first encountering a Var (e.g., m), this will generate:
// Let(m, bound_shape_val, ...)
// Constant dimensions will only generate consistency assertions.
BindNullable(buffer->shape[k], bound_shape_val, shape_element_name(k), true,
BindNullable(buffer->shape[k], shape_val, shape_element_name(k), true,
is_null);
// Keep an explicit "consistency check": when non-NULL, the symbolic
// dimension must equal the DLTensor's shape.
Stmt shape_check = AssertStmt(
Or(is_null, buffer->shape[k] == raw_shape_val),
StringImm(shape_element_name(k) + " mismatch with DLTensor shape"),
Evaluate(0));
asserts_.emplace_back(shape_check);
}
// strides field
......@@ -499,7 +493,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1);
PrimExpr stride_from_shape = 1;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
DataType stride_dtype = buffer->strides[k].dtype();
......@@ -507,31 +501,15 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
cast(stride_dtype,
BufferLoad(buf_strides,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape);
PrimExpr core_value = tvm::if_then_else(
v_strides_is_null, stride_from_shape_cast, explicit_stride);
core_value = tvm::if_then_else(buffer->shape[k] == 1,
make_zero(stride_dtype), core_value);
// Bind like shape: define var when needed, and only assert when non-NULL
PrimExpr bound_stride_val =
tvm::if_then_else(is_null, make_zero(stride_dtype), core_value);
BindNullable(buffer->strides[k], bound_stride_val, stride_element_name(k),
true, is_null);
PrimExpr stride_val = tvm::if_then_else(
v_strides_is_null, stride_from_shape, explicit_stride);
Stmt stride_check = AssertStmt(
Or(is_null, buffer->strides[k] == core_value),
StringImm(stride_element_name(k) + " mismatch with DLTensor strides"),
Evaluate(0));
asserts_.emplace_back(stride_check);
PrimExpr shape_extent = cast(stride_dtype, buffer->shape[k]);
stride_from_shape =
analyzer_.Simplify(stride_from_shape_cast * shape_extent);
BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true,
is_null);
}
} else {
PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1);
PrimExpr stride_from_shape = 1;
for (int k = static_cast<int>(buffer->strides.size()) - 1; k >= 0; --k) {
DataType stride_dtype = buffer->strides[k].dtype();
......@@ -540,24 +518,12 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
PrimExpr shape_stride = cast(
stride_dtype, BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape);
PrimExpr core_value = tvm::if_then_else(
v_strides_is_null, stride_from_shape_cast, explicit_stride);
PrimExpr bound_stride_val =
tvm::if_then_else(is_null, make_zero(stride_dtype), core_value);
BindNullable(buffer->strides[k], bound_stride_val, stride_element_name(k),
true, is_null);
Stmt stride_check = AssertStmt(
Or(is_null, buffer->strides[k] == core_value),
StringImm(stride_element_name(k) + " mismatch with DLTensor strides"),
Evaluate(0));
asserts_.emplace_back(stride_check);
PrimExpr stride_val = tvm::if_then_else(
v_strides_is_null, stride_from_shape, explicit_stride);
stride_from_shape =
analyzer_.Simplify(stride_from_shape_cast * shape_stride);
BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true,
is_null);
}
}
......@@ -574,9 +540,10 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
PrimExpr expect_byte_offset =
make_const(DataType::UInt(64), const_offset->value * data_bytes);
Stmt byte_off_check =
AssertStmt(Or(is_null, expect_byte_offset == actual_byte_offset),
AssertStmt(expect_byte_offset == actual_byte_offset,
StringImm(arg_name + ".byte_offset mismatch"), nop);
asserts_.emplace_back(byte_off_check);
byte_off_check = IfThenElse(Not(is_null), byte_off_check);
asserts_.emplace_back(SeqStmt({byte_off_check, nop}));
} else {
PrimExpr actual_byte_offset = tvm::if_then_else(
Not(is_null),
......@@ -586,28 +553,15 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
cast(buffer->elem_offset.dtype(),
(actual_byte_offset / make_const(DataType::UInt(64), data_bytes)));
// Like shape/stride, do NULL-safe binding for elem_offset:
// handle is NULL → 0
// handle non-NULL → actual_byte_offset / data_bytes
PrimExpr bound_elem_off = tvm::if_then_else(
is_null, make_zero(buffer->elem_offset.dtype()), expect_elem_off);
BindNullable(buffer->elem_offset, bound_elem_off, arg_name + ".elem_offset",
true, is_null);
// Strict consistency check for non-NULL case
Stmt elem_off_check =
AssertStmt(Or(is_null, buffer->elem_offset == expect_elem_off),
StringImm(arg_name + ".elem_offset mismatch"), nop);
asserts_.emplace_back(elem_off_check);
BindNullable(buffer->elem_offset, expect_elem_off,
arg_name + ".elem_offset", true, is_null);
if (buffer->offset_factor > 1) {
PrimExpr offset = buffer->elem_offset;
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
Stmt off_factor_check =
AssertStmt(Or(is_null, truncmod(offset, factor) == zero),
StringImm(arg_name + ".elem_offset factor mismatch"), nop);
asserts_.emplace_back(off_factor_check);
BindNullable(offset, truncmod(offset, factor), arg_name + ".elem_offset",
true, is_null);
}
}
......@@ -621,14 +575,29 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
Not(is_null),
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
make_zero(DataType::Int(32)));
// Bind device_id to a safe expression (0 when NULL handle)
BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true,
is_null);
// Check device_type consistency (device_id equality is implicitly ensured by
// binding above)
init_nest_.emplace_back(
AssertStmt(Or(is_null, device_type == actual_dev_type),
StringImm(arg_name + ".device_type mismatch"), nop));
{
std::ostringstream dev_msg;
dev_msg << arg_name << ".device_type mismatch";
if (const auto *imm = device_type.as<IntImmNode>()) {
dev_msg << " [expected: " << imm->value << " ("
<< tvm::runtime::DLDeviceType2Str(static_cast<int>(imm->value))
<< ")]";
}
// Give a short legend so users can interpret numeric codes in the
// appended "got/expected" part printed by the runtime.
dev_msg << "; DLPack codes: 1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, "
"14=OneAPI, 15=WebGPU";
auto device_type_check =
IfThenElse(Not(is_null), AssertStmt(device_type == actual_dev_type,
StringImm(dev_msg.str()), nop));
asserts_.emplace_back(SeqStmt({device_type_check, Evaluate(0)}));
}
// Data field. Because the validation of the data field may depend
// on a dynamic size defined by the other DLTensor* parameters, this
......@@ -650,12 +619,14 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
product *= dim;
return product;
}();
asserts_.emplace_back(AssertStmt(
Or(is_null, (alloc_size == 0) ||
!Call(DataType::Bool(), builtin::isnullptr(), {vptr})),
Stmt data_null_check = AssertStmt(
(alloc_size == 0) ||
!Call(DataType::Bool(), builtin::isnullptr(), {vptr}),
StringImm(arg_name +
" is expected to have non-NULL data pointer, but got NULL"),
nop));
nop);
data_null_check = IfThenElse(Not(is_null), data_null_check);
asserts_.emplace_back(SeqStmt({data_null_check, nop}));
// mark alignment of external bufs
init_nest_.emplace_back(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment