Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
d4b6d094
Unverified
Commit
d4b6d094
authored
Nov 20, 2025
by
Lei Wang
Committed by
GitHub
Nov 20, 2025
Browse files
[Enhancement] Shared Memory Size Can be Dynamic (#1294)
* bugfix * lint fix * test * lint fix * increase procs * recover
parent
dd7fdb8e
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
58 additions
and
8 deletions
+58
-8
.github/workflows/ci.yml
.github/workflows/ci.yml
+1
-1
3rdparty/tvm
3rdparty/tvm
+1
-1
src/tl_templates/cuda/atomic.h
src/tl_templates/cuda/atomic.h
+1
-2
testing/python/language/test_tilelang_language_atomic_add.py
testing/python/language/test_tilelang_language_atomic_add.py
+3
-4
testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py
...on/runtime/test_tilelang_runtime_dynamic_shared_memory.py
+52
-0
No files found.
.github/workflows/ci.yml
View file @
d4b6d094
...
...
@@ -352,7 +352,7 @@ jobs:
uv run --no-project -m --
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
"${PYTEST[@]}" --maxfail=3 --numprocesses=
1
\
"${PYTEST[@]}" --maxfail=3 --numprocesses=
4
\
../examples
# NVIDIA CUDA tests
...
...
tvm
@
713e6ade
Compare
f4affc7f
...
713e6ade
Subproject commit
f4affc7f31e36e7f88c0fe1c715b03215c6a0c62
Subproject commit
713e6ade56eaa72cc85d58d9228dd9f34cc2d03e
src/tl_templates/cuda/atomic.h
View file @
d4b6d094
...
...
@@ -131,8 +131,7 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val,
}
else
{
#if CUDART_VERSION >= 11080
cuda
::
atomic_ref
<
NT1
,
cuda
::
thread_scope_device
>
aref
(
*
address
);
return
static_cast
<
T1
>
(
aref
.
fetch_min
(
cuda_cast
<
NT1
>
(
val
),
cuda
::
memory_order
(
memory_order
)));
aref
.
fetch_min
(
cuda_cast
<
NT1
>
(
val
),
cuda
::
memory_order
(
memory_order
));
#else
TL_NOT_IMPLEMENTED
();
#endif
...
...
testing/python/language/test_tilelang_language_atomic_add.py
View file @
d4b6d094
...
...
@@ -374,10 +374,9 @@ def test_atomic_return_prev():
run_atomic_return_prev
(
32
,
32
,
8
,
8
)
# TODO(lei): test failed and this is experimental
# CC @dyq
# def test_tile_atomic_add():
# run_tile_atomic_add(8, 128, 128, 32, 32)
def test_tile_atomic_add():
    # Re-enabled by this commit (previously commented out with a TODO noting it
    # was failing/experimental). Arguments are presumably
    # (batch, M, N, block_M, block_N) — confirm against run_tile_atomic_add.
    run_tile_atomic_add(8, 128, 128, 32, 32)
# Allow running this test module directly, outside of a pytest invocation.
if __name__ == "__main__":
    tilelang.testing.main()
testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py
0 → 100644
View file @
d4b6d094
import
pytest
import
torch
import
tilelang
import
tilelang.language
as
T
import
tilelang.testing
@tilelang.jit
def dynamic_smem_kernel():
    """Build a tilelang kernel whose shared-memory allocation extent is a
    runtime symbol, exercising dynamically sized shared memory.

    Returns the inner ``T.prim_func`` for tilelang's JIT to compile.
    """
    # Symbolic length to drive dynamic shared memory allocation
    length = T.symbolic("len", dtype="int32")  # noqa: F821

    @T.prim_func
    def main(global_tensor: T.Tensor[(length,), "int32"]):  # noqa: F821
        # Launch a simple kernel that copies from global memory into shared memory
        # using a dynamically-sized allocation. No writes back to global_tensor.
        with T.Kernel(1, threads=32) as _:
            # The shared buffer's extent is the symbolic `length`, so its size
            # is only fixed at launch time rather than at compile time.
            buffer_shared = T.alloc_shared((length,), dtype="int32")  # noqa: F821
            T.copy(buffer_shared, global_tensor)

    return main
def _require_cuda_tensor(shape, dtype):
    """Return a random integer tensor (values in [0, 100)) on the CUDA device.

    Skips the calling test when CUDA is unavailable or the runtime refuses
    the allocation.
    """
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        pytest.skip("CUDA not available")
    try:
        tensor = torch.randint(0, 100, shape, dtype=dtype, device="cuda")
    except RuntimeError as err:
        pytest.skip(f"CUDA runtime unavailable: {err}")
    else:
        return tensor
def _run_and_check(kernel, n):
    """Invoke *kernel* on a fresh length-``n`` int32 CUDA tensor and block
    until the device has finished executing."""
    input_tensor = _require_cuda_tensor((n,), torch.int32)
    kernel(input_tensor)
    torch.cuda.synchronize()
def test_dynamic_shared_memory_varies_across_calls():
    """Launch one compiled kernel with varying input lengths to verify the
    dynamic shared-memory size may change — and repeat — across calls."""
    kernel = dynamic_smem_kernel()
    # 100 then 200 vary the dynamic shared memory size across invocations;
    # repeating 200 and 100 exercises the attribute caching path.
    for size in (100, 200, 200, 100):
        _run_and_check(kernel, size)
# Allow running this test module directly, outside of a pytest invocation.
if __name__ == "__main__":
    tilelang.testing.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment