01_num_args_mismatch.py 556 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""Reproduce: argument-count mismatch.

Note: the adapter-level wrapper expects only the inputs (A, B), because C is
marked as an output. Calling it with the wrong number of inputs raises a
ValueError before host entry.
"""
import torch
from common import build_matmul_kernel


def main():
    """Call the matmul kernel with one input instead of the two it expects.

    The kernel is built for 256 x 256 x 256 operands on the "cuda" target,
    but only the first operand (A) is passed. Per the module docstring, the
    wrapper should reject the call with a ValueError whose message reports
    the expected vs. actual number of inputs.
    """
    m = n = k = 256
    kernel = build_matmul_kernel(m, n, k, target="cuda")

    lhs = torch.empty((m, k), device="cuda", dtype=torch.float16)

    # The second operand (b) is intentionally omitted; the ValueError is
    # expected to surface here and is deliberately not caught.
    kernel(lhs)


if __name__ == "__main__":
    # Run the reproduction only when executed as a script, not on import.
    main()