Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
8a5eb569
Unverified
Commit
8a5eb569
authored
Oct 22, 2025
by
Yu Cheng
Committed by
GitHub
Oct 22, 2025
Browse files
[Refactor] Use forceinline in `ldmatrix` and update mamba scan kernel (#1104)
parent
5683e6a6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
20 deletions
+30
-20
examples/linear_attention/example_mamba_chunk_scan.py
examples/linear_attention/example_mamba_chunk_scan.py
+18
-8
src/tl_templates/cuda/ldsm.h
src/tl_templates/cuda/ldsm.h
+12
-12
No files found.
examples/linear_attention/example_mamba_chunk_scan.py
View file @
8a5eb569
...
...
@@ -71,7 +71,12 @@ def get_configs():
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
7
])
@
tilelang
.
jit
(
out_idx
=
[
7
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
},
)
def
chunk_scan_fwd
(
batch
,
seqlen
,
chunk_size
,
...
...
@@ -91,13 +96,16 @@ def chunk_scan_fwd(batch,
p
=
1.44269504
@
T
.
prim_func
def
main
(
cb
:
T
.
Tensor
((
batch
,
nchunks
,
ngroups
,
chunk_size
,
chunk_size
),
dtype
),
x
:
T
.
Tensor
(
(
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
dt
:
T
.
Tensor
(
(
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
dA_cumsum
:
T
.
Tensor
(
(
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
C
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
prev_states
:
T
.
Tensor
(
(
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
),
D
:
T
.
Tensor
(
(
nheads
),
dtype
),
Output
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
)):
def
main
(
cb
:
T
.
Tensor
((
batch
,
nchunks
,
ngroups
,
chunk_size
,
chunk_size
),
dtype
),
# type: ignore
x
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
),
# type: ignore
dt
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
# type: ignore
dA_cumsum
:
T
.
Tensor
((
batch
,
nheads
,
nchunks
,
chunk_size
),
dtype
),
# type: ignore
C
:
T
.
Tensor
((
batch
,
seqlen
,
ngroups
,
dstate
),
dtype
),
# type: ignore
prev_states
:
T
.
Tensor
((
batch
,
nchunks
,
nheads
,
headdim
,
dstate
),
dtype
),
# type: ignore
D
:
T
.
Tensor
((
nheads
),
dtype
),
# type: ignore
Output
:
T
.
Tensor
((
batch
,
seqlen
,
nheads
,
headdim
),
dtype
)
# type: ignore
):
with
T
.
Kernel
(
nheads
,
T
.
ceildiv
(
chunk_size
,
block_M
)
*
T
.
ceildiv
(
headdim
,
block_N
),
...
...
@@ -134,6 +142,8 @@ def chunk_scan_fwd(batch,
x_residual_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
x_residual_shared
)
})
T
.
no_set_max_nreg
()
T
.
copy
(
dA_cumsum
[
batch_idx
,
bz
,
chunk_idx
,
m_idx
*
block_M
:(
m_idx
+
1
)
*
block_M
],
dA_cs_m_shared
)
T
.
copy
(
dA_cs_m_shared
,
dA_cs_m_local
)
...
...
src/tl_templates/cuda/ldsm.h
View file @
8a5eb569
...
...
@@ -4,7 +4,7 @@
namespace
tl
{
TL_DEVICE
_NOINLINE
void
ptx_ldmatrix_x1
(
void
const
*
const
smem_ptr
,
TL_DEVICE
void
ptx_ldmatrix_x1
(
void
const
*
const
smem_ptr
,
void
*
const
local_ptr
)
{
uint32_t
smem_int_ptr
=
smem_ptr_to_uint
(
smem_ptr
);
int32_t
*
value
=
reinterpret_cast
<
int32_t
*>
(
local_ptr
);
...
...
@@ -13,7 +13,7 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr,
:
"r"
(
smem_int_ptr
));
}
TL_DEVICE
_NOINLINE
void
ptx_ldmatrix_x2
(
void
const
*
const
smem_ptr
,
TL_DEVICE
void
ptx_ldmatrix_x2
(
void
const
*
const
smem_ptr
,
void
*
const
local_ptr
)
{
uint32_t
smem_int_ptr
=
smem_ptr_to_uint
(
smem_ptr
);
int32_t
*
value
=
reinterpret_cast
<
int32_t
*>
(
local_ptr
);
...
...
@@ -22,7 +22,7 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr,
:
"r"
(
smem_int_ptr
));
}
TL_DEVICE
_NOINLINE
void
ptx_ldmatrix_x4
(
void
const
*
const
smem_ptr
,
TL_DEVICE
void
ptx_ldmatrix_x4
(
void
const
*
const
smem_ptr
,
void
*
const
local_ptr
)
{
uint32_t
smem_int_ptr
=
smem_ptr_to_uint
(
smem_ptr
);
int32_t
*
value
=
reinterpret_cast
<
int32_t
*>
(
local_ptr
);
...
...
@@ -32,7 +32,7 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr,
:
"r"
(
smem_int_ptr
));
}
TL_DEVICE
_NOINLINE
void
ptx_ldmatrix_x1_trans
(
void
const
*
const
smem_ptr
,
TL_DEVICE
void
ptx_ldmatrix_x1_trans
(
void
const
*
const
smem_ptr
,
void
*
const
local_ptr
)
{
uint32_t
smem_int_ptr
=
smem_ptr_to_uint
(
smem_ptr
);
int32_t
*
value
=
reinterpret_cast
<
int32_t
*>
(
local_ptr
);
...
...
@@ -41,7 +41,7 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
:
"r"
(
smem_int_ptr
));
}
TL_DEVICE
_NOINLINE
void
ptx_ldmatrix_x2_trans
(
void
const
*
const
smem_ptr
,
TL_DEVICE
void
ptx_ldmatrix_x2_trans
(
void
const
*
const
smem_ptr
,
void
*
const
local_ptr
)
{
uint32_t
smem_int_ptr
=
smem_ptr_to_uint
(
smem_ptr
);
int32_t
*
value
=
reinterpret_cast
<
int32_t
*>
(
local_ptr
);
...
...
@@ -51,7 +51,7 @@ TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
:
"r"
(
smem_int_ptr
));
}
TL_DEVICE
_NOINLINE
void
ptx_ldmatrix_x4_trans
(
void
const
*
const
smem_ptr
,
TL_DEVICE
void
ptx_ldmatrix_x4_trans
(
void
const
*
const
smem_ptr
,
void
*
const
local_ptr
)
{
uint32_t
smem_int_ptr
=
smem_ptr_to_uint
(
smem_ptr
);
int32_t
*
value
=
reinterpret_cast
<
int32_t
*>
(
local_ptr
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment