Unverified Commit 4ef94f22 authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[Fix] fix type incompatible error in #1115 (#1180)

* Fix incompatible floordiv in packed api

* fix lint
parent 5f202fe5
......@@ -433,7 +433,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
auto shape_vectorize_expr = [&]() -> PrimExpr {
PrimExpr result = IntImm(kv.second->DefaultIndexType(), 1);
result = result * vectorize_dim;
result = FloorMod(result, dynamic_alignment);
result = FloorMod(result, IntImm(result->dtype, dynamic_alignment));
return result;
}();
shape_checks.emplace_back(AssertStmt(
......
import torch
import tilelang
import tilelang.language as T
def test_int64_address():
    """Run the scatter-by-position kernel with int64 and int32 position tensors.

    The kernel copies row ``bx`` of ``value`` into row ``pos[bx]`` of ``cache``.
    With ``pos = arange(S)`` every row maps to itself, so after a run ``cache``
    must equal ``value``. Each index dtype gets its own freshly allocated
    ``cache`` so that a kernel which writes nothing cannot pass by inheriting
    the previous kernel's output.
    """

    @tilelang.jit
    def set_cache_kernel(
        S,
        D,
        pos_ty='int64',
        dtype="float32",
    ):

        @T.prim_func
        def main(
                # An int64 `pos` used to fail with:
                # `TypeError: Check failed: (a.dtype() == b.dtype()) is false:
                # mismatched types. int64 vs. int32`
                pos: T.Tensor([S], pos_ty),  # type: ignore
                value: T.Tensor([S, D], dtype),  # type: ignore
                cache: T.Tensor([S, D], dtype),  # type: ignore
        ):
            with T.Kernel(S, threads=128) as bx:
                # Destination row is read from the position tensor.
                slot = pos[bx]
                for i in T.Parallel(D):
                    cache[slot, i] = value[bx, i]

        return main

    D = 2
    S = 10
    value = torch.rand((S, D), device='cuda', dtype=torch.float32)
    pos_int64 = torch.arange(S, device='cuda', dtype=torch.int64)
    pos_int32 = torch.arange(S, device='cuda', dtype=torch.int32)
    kernel_int64 = set_cache_kernel(S, D, 'int64')
    kernel_int32 = set_cache_kernel(S, D, 'int32')

    cache = torch.rand((S, D), device="cuda", dtype=torch.float32)
    kernel_int64(pos_int64, value, cache)
    torch.testing.assert_close(cache, value)

    # Re-allocate the cache: reusing the one above would make this check
    # vacuous, because it already equals `value` after the int64 run.
    cache = torch.rand((S, D), device="cuda", dtype=torch.float32)
    kernel_int32(pos_int32, value, cache)
    torch.testing.assert_close(cache, value)
if __name__ == "__main__":
    # Import the submodule explicitly so the attribute lookup below does not
    # rely on the package's __init__ having loaded it as a side effect.
    import tilelang.testing

    # Hand control to tilelang's test runner when executed as a script.
    tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment