Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
ebea77d9
Unverified
Commit
ebea77d9
authored
Sep 18, 2025
by
Lei Wang
Committed by
GitHub
Sep 18, 2025
Browse files
[CI] Test Fix: Handle BufferLoad nodes when T.gemm input has a stride (#843)
* bugfix * fix * test fix
parent
232782dd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
87 additions
and
8 deletions
+87
-8
tilelang/language/gemm.py
tilelang/language/gemm.py
+87
-8
No files found.
tilelang/language/gemm.py
View file @
ebea77d9
...
...
@@ -4,6 +4,7 @@ from tilelang.primitives.gemm.base import GemmWarpPolicy
import
tilelang.language
as
T
from
tvm
import
tir
from
typing
import
Union
,
List
from
tilelang.utils.language
import
get_buffer_region_from_load
def
gemm
(
...
...
@@ -66,8 +67,15 @@ def gemm(
for
r
in
region
:
shape
.
append
(
r
.
extent
)
return
shape
elif
isinstance
(
object
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
object
).
region
shape
=
[]
for
r
in
region
:
shape
.
append
(
r
.
extent
)
return
shape
else
:
raise
ValueError
(
f
"Unsupported argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
raise
ValueError
(
f
"Unsupported retrieve_shape argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
def
retrieve_stride
(
object
:
Union
[
tir
.
Buffer
,
tir
.
BufferRegion
])
->
List
[
int
]:
if
isinstance
(
object
,
tir
.
Buffer
):
...
...
@@ -85,8 +93,17 @@ def gemm(
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
elif
isinstance
(
object
,
tir
.
BufferLoad
):
buffer
=
object
.
buffer
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
else
:
raise
ValueError
(
f
"Unsupported argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
raise
ValueError
(
f
"Unsupported retrieve_stride argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
A_shape
=
retrieve_shape
(
A
)
B_shape
=
retrieve_shape
(
B
)
...
...
@@ -134,8 +151,24 @@ def gemm(
for
i
in
range
(
len
(
indices
)
-
2
):
offset
+=
indices
[
i
]
*
strides
[
i
]
return
buffer
.
access_ptr
(
access_mask
=
access_type
,
offset
=
offset
)
elif
isinstance
(
object
,
tir
.
BufferLoad
):
buffer
=
object
.
buffer
region
=
get_buffer_region_from_load
(
object
).
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
offset
=
0
for
i
in
range
(
len
(
indices
)
-
2
):
offset
+=
indices
[
i
]
*
strides
[
i
]
return
buffer
.
access_ptr
(
access_mask
=
access_type
,
offset
=
offset
)
else
:
raise
ValueError
(
f
"Unsupported argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
raise
ValueError
(
f
"Unsupported retrieve_ptr argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
def
retrieve_offset
(
object
:
Union
[
tir
.
Buffer
,
tir
.
BufferRegion
])
->
tir
.
PrimExpr
:
"""Retrieve the offset of the buffer or buffer region."""
...
...
@@ -147,8 +180,15 @@ def gemm(
for
r
in
region
:
indices
.
append
(
r
.
min
)
return
indices
elif
isinstance
(
object
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
object
).
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
return
indices
else
:
raise
ValueError
(
f
"Unsupported argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
raise
ValueError
(
f
"Unsupported retrieve_offset argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
A_offset
=
retrieve_offset
(
A
)
B_offset
=
retrieve_offset
(
B
)
...
...
@@ -243,8 +283,15 @@ def gemm_v2(
for
r
in
region
:
shape
.
append
(
r
.
extent
)
return
shape
elif
isinstance
(
object
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
object
).
region
shape
=
[]
for
r
in
region
:
shape
.
append
(
r
.
extent
)
return
shape
else
:
raise
ValueError
(
f
"Unsupported argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
raise
ValueError
(
f
"Unsupported retrieve_shape argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
def
retrieve_stride
(
object
:
Union
[
tir
.
Buffer
,
tir
.
BufferRegion
])
->
List
[
int
]:
if
isinstance
(
object
,
tir
.
Buffer
):
...
...
@@ -262,8 +309,17 @@ def gemm_v2(
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
elif
isinstance
(
object
,
tir
.
BufferLoad
):
buffer
=
object
.
buffer
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
else
:
raise
ValueError
(
f
"Unsupported argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
raise
ValueError
(
f
"Unsupported retrieve_stride argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
A_shape
=
retrieve_shape
(
A
)
B_shape
=
retrieve_shape
(
B
)
...
...
@@ -311,8 +367,24 @@ def gemm_v2(
for
i
in
range
(
len
(
indices
)
-
2
):
offset
+=
indices
[
i
]
*
strides
[
i
]
return
buffer
.
access_ptr
(
access_mask
=
access_type
,
offset
=
offset
)
elif
isinstance
(
object
,
tir
.
BufferLoad
):
buffer
=
object
.
buffer
region
=
get_buffer_region_from_load
(
object
).
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
offset
=
0
for
i
in
range
(
len
(
indices
)
-
2
):
offset
+=
indices
[
i
]
*
strides
[
i
]
return
buffer
.
access_ptr
(
access_mask
=
access_type
,
offset
=
offset
)
else
:
raise
ValueError
(
f
"Unsupported argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
raise
ValueError
(
f
"Unsupported retrieve_ptr argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
def
retrieve_offset
(
object
:
Union
[
tir
.
Buffer
,
tir
.
BufferRegion
])
->
tir
.
PrimExpr
:
"""Retrieve the offset of the buffer or buffer region."""
...
...
@@ -324,8 +396,15 @@ def gemm_v2(
for
r
in
region
:
indices
.
append
(
r
.
min
)
return
indices
elif
isinstance
(
object
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
object
).
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
return
indices
else
:
raise
ValueError
(
f
"Unsupported argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
raise
ValueError
(
f
"Unsupported retrieve_offset argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
A_offset
=
retrieve_offset
(
A
)
B_offset
=
retrieve_offset
(
B
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment