Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
5cb5c068
Unverified
Commit
5cb5c068
authored
Oct 22, 2025
by
Yu Cheng
Committed by
GitHub
Oct 22, 2025
Browse files
[Bugfix] Fix missing host cuTensorMapEncodeIm2col call (#1094)
parent
bddb125e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
113 additions
and
40 deletions
+113
-40
examples/convolution/example_convolution.py
examples/convolution/example_convolution.py
+1
-0
src/transform/inject_tma_barrier.cc
src/transform/inject_tma_barrier.cc
+5
-4
tilelang/jit/adapter/wrapper.py
tilelang/jit/adapter/wrapper.py
+107
-36
No files found.
examples/convolution/example_convolution.py
View file @
5cb5c068
...
...
@@ -122,6 +122,7 @@ def main(argv=None):
out_c
=
kernel
(
a
,
b
)
ref_c
=
ref_program
(
S
,
P
,
D
)(
a
,
b
)
torch
.
testing
.
assert_close
(
out_c
,
ref_c
,
rtol
=
1e-2
,
atol
=
1e-2
)
print
(
"All checks passed.✅"
)
if
__name__
==
"__main__"
:
...
...
src/transform/inject_tma_barrier.cc
View file @
5cb5c068
...
...
@@ -163,7 +163,7 @@ private:
}
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
{
if
(
op
->
op
.
same_as
(
tma_load
()))
{
if
(
op
->
op
.
same_as
(
tma_load
())
||
op
->
op
.
same_as
(
tma_load_im2col
())
)
{
auto
arg0
=
op
->
args
[
0
].
as
<
Call
>
();
bool
is_1d_tma_load
=
arg0
&&
!
arg0
.
value
()
->
op
.
same_as
(
create_tma_descriptor
())
&&
...
...
@@ -203,7 +203,7 @@ private:
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
if
(
const
auto
*
call
=
op
->
value
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
tma_load
()))
{
if
(
call
->
op
.
same_as
(
tma_load
())
||
call
->
op
.
same_as
(
tma_load_im2col
())
)
{
pending_tma_ops_
.
push_back
(
GetRef
<
Call
>
(
call
));
}
else
if
(
call
->
op
.
same_as
(
mbarrier_expect_tx
()))
{
pending_tma_ops_
.
push_back
(
GetRef
<
Call
>
(
call
));
...
...
@@ -451,7 +451,7 @@ private:
}
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
{
if
(
op
->
op
.
same_as
(
tma_load
()))
{
if
(
op
->
op
.
same_as
(
tma_load
())
||
op
->
op
.
same_as
(
tma_load_im2col
())
)
{
// check this must be in the tma_op_to_barrier_id_
ICHECK
(
tma_op_to_barrier_id_
.
count
(
GetRef
<
Call
>
(
op
)))
<<
"tma_load must be in the tma_op_to_barrier_id_"
;
...
...
@@ -459,7 +459,8 @@ private:
auto
new_args
=
op
->
args
;
auto
arg0
=
op
->
args
[
0
].
as
<
Call
>
();
auto
is_1d_tma_load
=
arg0
&&
!
arg0
.
value
()
->
op
.
same_as
(
create_tma_descriptor
());
arg0
&&
!
arg0
.
value
()
->
op
.
same_as
(
create_tma_descriptor
())
&&
!
arg0
.
value
()
->
op
.
same_as
(
create_tma_im2col_descriptor
());
if
(
is_1d_tma_load
)
{
new_args
.
Set
(
2
,
barrier_id
);
}
else
{
...
...
tilelang/jit/adapter/wrapper.py
View file @
5cb5c068
...
...
@@ -106,6 +106,35 @@ TMA_DESC_INIT_FUNC = """
\t
}}
"""
TMA_IM2COL_DESC_INIT_FUNC
=
"""
\t
CUtensorMap {0};
\t
CUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
\t
cuuint32_t {0}_tensorRank= {2};
\t
void *{0}_globalAddress= {3};
\t
cuuint64_t {0}_globalDim[{2}]= {{{4}}};
\t
cuuint64_t {0}_globalStride[{2}]= {{{5}}};
\t
cuuint32_t {0}_elementStrides[{2}]= {{{6}}};
\t
int {0}_lowerCorner[{2} - 2]= {{{7}}};
\t
int {0}_upperCorner[{2} - 2]= {{{8}}};
\t
cuuint32_t {0}_channelsPerPixel= {9};
\t
cuuint32_t {0}_pixelsPerColumn= {10};
\t
CUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){11};
\t
CUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){12};
\t
CUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){13};
\t
CUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){14};
\t
CUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
&{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1,
{0}_lowerCorner, {0}_upperCorner, {0}_channelsPerPixel, {0}_pixelsPerColumn, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);
\t
if ({0}_result != CUDA_SUCCESS) {{
\t\t
std::stringstream ss;
\t\t
ss << "Error: Failed to initialize the TMA descriptor {0}";
\t\t
snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
\t\t
return -1;
\t
}}
"""
TMA_DESC_INIT_FUNC_PY
=
"""
\t
{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
\t
{0}_tensorRank = {2}
...
...
@@ -401,7 +430,10 @@ class TLCUDASourceWrapper(object):
if
len
(
args
)
<
3
:
raise
ValueError
(
f
"TMA descriptor args too short:
{
len
(
args
)
}
elements, expected at least 3"
)
_
,
dtype
,
tensor_rank
,
globalAddress
,
*
remaining_args
=
args
[
1
:]
tma_create_str
,
_
,
dtype
,
tensor_rank
,
globalAddress
,
*
remaining_args
=
args
is_img2col
=
(
tma_create_str
.
value
==
"__tvm_tensormap_create_im2col"
)
dtype
=
self
.
_pythonic_expr
(
dtype
)
tensor_rank
=
int
(
self
.
_pythonic_expr
(
tensor_rank
))
...
...
@@ -409,42 +441,81 @@ class TLCUDASourceWrapper(object):
if
not
isinstance
(
tensor_rank
,
int
)
or
tensor_rank
<=
0
:
raise
ValueError
(
f
"Invalid tensor_rank:
{
tensor_rank
}
. Must be a positive integer"
)
# Calculate required length for remaining_args
expected_args_len
=
4
*
tensor_rank
+
4
# 4 groups of tensor_rank size + 4 parameters
if
len
(
remaining_args
)
<
expected_args_len
:
raise
ValueError
(
f
"Insufficient remaining args: got
{
len
(
remaining_args
)
}
, "
f
"expected
{
expected_args_len
}
for tensor_rank
{
tensor_rank
}
"
)
# Extract dimensions and strides using list slicing
global_dim
=
remaining_args
[:
tensor_rank
]
global_stride
=
remaining_args
[
tensor_rank
:
2
*
tensor_rank
]
box_dim
=
remaining_args
[
2
*
tensor_rank
:
3
*
tensor_rank
]
element_strides
=
remaining_args
[
3
*
tensor_rank
:
4
*
tensor_rank
]
global_dim
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
global_dim
]
global_stride
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
global_stride
]
box_dim
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
box_dim
]
element_strides
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
element_strides
]
# Extract remaining parameters
try
:
interleave
,
swizzle
,
l2Promotion
,
oobFill
=
remaining_args
[
4
*
tensor_rank
:
4
*
tensor_rank
+
4
]
interleave
=
self
.
_pythonic_expr
(
interleave
)
swizzle
=
self
.
_pythonic_expr
(
swizzle
)
l2Promotion
=
self
.
_pythonic_expr
(
l2Promotion
)
oobFill
=
self
.
_pythonic_expr
(
oobFill
)
except
ValueError
as
e
:
raise
ValueError
(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
)
from
e
if
not
is_img2col
:
# Calculate required length for remaining_args
expected_args_len
=
4
*
tensor_rank
+
4
# 4 groups of tensor_rank size + 4 parameters
if
len
(
remaining_args
)
<
expected_args_len
:
raise
ValueError
(
f
"Insufficient remaining args: got
{
len
(
remaining_args
)
}
, "
f
"expected
{
expected_args_len
}
for tensor_rank
{
tensor_rank
}
"
)
# Extract dimensions and strides using list slicing
global_dim
=
remaining_args
[:
tensor_rank
]
global_stride
=
remaining_args
[
tensor_rank
:
2
*
tensor_rank
]
box_dim
=
remaining_args
[
2
*
tensor_rank
:
3
*
tensor_rank
]
element_strides
=
remaining_args
[
3
*
tensor_rank
:
4
*
tensor_rank
]
global_dim
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
global_dim
]
global_stride
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
global_stride
]
box_dim
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
box_dim
]
element_strides
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
element_strides
]
# Extract remaining parameters
try
:
interleave
,
swizzle
,
l2Promotion
,
oobFill
=
remaining_args
[
4
*
tensor_rank
:
4
*
tensor_rank
+
4
]
interleave
=
self
.
_pythonic_expr
(
interleave
)
swizzle
=
self
.
_pythonic_expr
(
swizzle
)
l2Promotion
=
self
.
_pythonic_expr
(
l2Promotion
)
oobFill
=
self
.
_pythonic_expr
(
oobFill
)
except
ValueError
as
e
:
raise
ValueError
(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
)
from
e
tma_descripter_init
+=
TMA_DESC_INIT_FUNC
.
format
(
handle_name
,
dtype
,
tensor_rank
,
globalAddress
,
","
.
join
(
global_dim
),
","
.
join
(
global_stride
),
","
.
join
(
box_dim
),
","
.
join
(
element_strides
),
interleave
,
swizzle
,
l2Promotion
,
oobFill
)
else
:
# Calculate required length for remaining_args
expected_args_len
=
5
*
tensor_rank
+
2
if
len
(
remaining_args
)
<
expected_args_len
:
raise
ValueError
(
f
"Insufficient remaining args: got
{
len
(
remaining_args
)
}
, "
f
"expected
{
expected_args_len
}
for tensor_rank
{
tensor_rank
}
"
)
# Extract dimensions and strides using list slicing
global_dim
=
remaining_args
[:
tensor_rank
]
global_stride
=
remaining_args
[
tensor_rank
:
2
*
tensor_rank
]
element_strides
=
remaining_args
[
2
*
tensor_rank
:
3
*
tensor_rank
]
lower_corner
=
remaining_args
[
3
*
tensor_rank
:
4
*
tensor_rank
-
2
]
upper_corner
=
remaining_args
[
4
*
tensor_rank
-
2
:
5
*
tensor_rank
-
4
]
global_dim
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
global_dim
]
global_stride
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
global_stride
]
element_strides
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
element_strides
]
lower_corner
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
lower_corner
]
upper_corner
=
[
self
.
_pythonic_expr
(
i
)
for
i
in
upper_corner
]
# Extract remaining parameters
try
:
smem_box_pixel
,
smem_box_channel
,
interleave
,
swizzle
,
l2Promotion
,
oobFill
=
remaining_args
[
5
*
tensor_rank
-
4
:
5
*
tensor_rank
+
2
]
smem_box_pixel
=
self
.
_pythonic_expr
(
smem_box_pixel
)
smem_box_channel
=
self
.
_pythonic_expr
(
smem_box_channel
)
interleave
=
self
.
_pythonic_expr
(
interleave
)
swizzle
=
self
.
_pythonic_expr
(
swizzle
)
l2Promotion
=
self
.
_pythonic_expr
(
l2Promotion
)
oobFill
=
self
.
_pythonic_expr
(
oobFill
)
except
ValueError
as
e
:
raise
ValueError
(
"Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
)
from
e
tma_descripter_init
+=
TMA_IM2COL_DESC_INIT_FUNC
.
format
(
handle_name
,
dtype
,
tensor_rank
,
globalAddress
,
","
.
join
(
global_dim
),
","
.
join
(
global_stride
),
","
.
join
(
element_strides
),
","
.
join
(
lower_corner
),
","
.
join
(
upper_corner
),
smem_box_channel
,
smem_box_pixel
,
interleave
,
swizzle
,
l2Promotion
,
oobFill
)
tma_descripter_init
+=
TMA_DESC_INIT_FUNC
.
format
(
handle_name
,
dtype
,
tensor_rank
,
globalAddress
,
","
.
join
(
global_dim
),
","
.
join
(
global_stride
),
","
.
join
(
box_dim
),
","
.
join
(
element_strides
),
interleave
,
swizzle
,
l2Promotion
,
oobFill
)
return
tma_descripter_init
def
parse_source_information
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment