Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
4ef94f22
Unverified
Commit
4ef94f22
authored
Nov 04, 2025
by
Kurisu
Committed by
GitHub
Nov 04, 2025
Browse files
[Fix] fix type imcompatible error in #1115 (#1180)
* Fix incompatible floordiv in packed api * fix lint
parent
5f202fe5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
1 deletion
+50
-1
src/transform/make_packed_api.cc
src/transform/make_packed_api.cc
+1
-1
testing/python/issue/test_tilelang_issue_1115.py
testing/python/issue/test_tilelang_issue_1115.py
+49
-0
No files found.
src/transform/make_packed_api.cc
View file @
4ef94f22
...
@@ -433,7 +433,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
...
@@ -433,7 +433,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
auto
shape_vectorize_expr
=
[
&
]()
->
PrimExpr
{
auto
shape_vectorize_expr
=
[
&
]()
->
PrimExpr
{
PrimExpr
result
=
IntImm
(
kv
.
second
->
DefaultIndexType
(),
1
);
PrimExpr
result
=
IntImm
(
kv
.
second
->
DefaultIndexType
(),
1
);
result
=
result
*
vectorize_dim
;
result
=
result
*
vectorize_dim
;
result
=
FloorMod
(
result
,
dynamic_alignment
);
result
=
FloorMod
(
result
,
IntImm
(
result
->
dtype
,
dynamic_alignment
)
)
;
return
result
;
return
result
;
}();
}();
shape_checks
.
emplace_back
(
AssertStmt
(
shape_checks
.
emplace_back
(
AssertStmt
(
...
...
testing/python/issue/test_tilelang_issue_1115.py
0 → 100644
View file @
4ef94f22
import
torch
import
tilelang
import
tilelang.language
as
T
def
test_int64_address
():
@
tilelang
.
jit
def
set_cache_kernel
(
S
,
D
,
pos_ty
=
'int64'
,
dtype
=
"float32"
,
):
@
T
.
prim_func
def
main
(
pos
:
T
.
Tensor
(
[
S
,
],
pos_ty
),
# type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
value
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
cache
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
):
with
T
.
Kernel
(
S
,
threads
=
128
)
as
bx
:
slot
=
pos
[
bx
]
for
i
in
T
.
Parallel
(
D
):
cache
[
slot
,
i
]
=
value
[
bx
,
i
]
return
main
D
=
2
S
=
10
cache
=
torch
.
rand
((
S
,
D
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
value
=
torch
.
rand
((
S
,
D
),
device
=
'cuda'
,
dtype
=
torch
.
float32
)
pos_int64
=
torch
.
arange
(
S
,
device
=
'cuda'
,
dtype
=
torch
.
int64
)
pos_int32
=
torch
.
arange
(
S
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)
kernel_int64
=
set_cache_kernel
(
S
,
D
,
'int64'
)
kernel_int32
=
set_cache_kernel
(
S
,
D
,
'int32'
)
kernel_int64
(
pos_int64
,
value
,
cache
)
torch
.
testing
.
assert_close
(
cache
,
value
)
kernel_int32
(
pos_int32
,
value
,
cache
)
torch
.
testing
.
assert_close
(
cache
,
value
)
if
__name__
==
"__main__"
:
tilelang
.
testing
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment