05_shape_mismatch.py 436 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Reproduce: shape constant/symbol mismatch on A.
"""
import torch
from common import build_matmul_kernel


def main():
    """Call a 128x128x128 matmul kernel with a deliberately mis-shaped A.

    The kernel is built for M = N = K = 128, but the first operand is
    allocated as (M, K + 1) so that the call exercises the shape
    mismatch this script is meant to reproduce. How the mismatch is
    reported depends on ``build_matmul_kernel``, which is not defined
    in this file.
    """
    size = 128
    rows, cols, inner = size, size, size
    kernel = build_matmul_kernel(rows, cols, inner, target="cuda")

    # Both operands share the same device and dtype; only the shapes differ.
    tensor_opts = {"device": "cuda", "dtype": torch.float16}

    # Intentional error: A gets one extra column (K + 1 rather than K).
    lhs = torch.empty((rows, inner + 1), **tensor_opts)
    rhs = torch.empty((inner, cols), **tensor_opts)

    kernel(lhs, rhs)


# Run the reproduction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()