Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
3aecab8f
Unverified
Commit
3aecab8f
authored
Oct 05, 2025
by
Lei Wang
Committed by
GitHub
Oct 05, 2025
Browse files
[Example] Disable TMA and enable FastMath for NSA Examples (#941)
* tma disable * int64 cast fix.
parent
557589ff
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
82 additions
and
12 deletions
+82
-12
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py
+6
-4
examples/deepseek_nsa/example_tilelang_nsa_bwd.py
examples/deepseek_nsa/example_tilelang_nsa_bwd.py
+6
-3
examples/deepseek_nsa/example_tilelang_nsa_fwd.py
examples/deepseek_nsa/example_tilelang_nsa_fwd.py
+4
-1
examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py
examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py
+6
-3
src/transform/flatten_buffer.cc
src/transform/flatten_buffer.cc
+60
-1
No files found.
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py
View file @
3aecab8f
...
...
@@ -38,9 +38,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
v
+=
(
bos
*
H
+
i_h
)
*
V
block_indices
+=
(
bos
+
i_t
)
*
H
*
S
+
i_h
*
S
# if USE_BLOCK_COUNTS:
# NS = tl.load(block_counts + (bos + i_t) * H + i_h)
# else:
NS
=
S
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
+
i_t
)
*
HQ
*
K
,
(
HQ
,
K
),
(
K
,
1
),
(
i_h
*
G
,
0
),
(
G
,
BK
),
...
...
@@ -452,7 +449,12 @@ def get_configs():
@
tilelang
.
autotune
(
configs
=
get_configs
(),)
@
tilelang
.
jit
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
def
tilelang_sparse_attention
(
batch
,
heads
,
seq_len
,
...
...
examples/deepseek_nsa/example_tilelang_nsa_bwd.py
View file @
3aecab8f
...
...
@@ -17,9 +17,12 @@ from einops import rearrange
import
tilelang
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
def
tilelang_kernel_fwd
(
batch
,
heads
,
...
...
examples/deepseek_nsa/example_tilelang_nsa_fwd.py
View file @
3aecab8f
...
...
@@ -9,8 +9,11 @@ tilelang.testing.set_random_seed(0)
@
tilelang
.
jit
(
out_idx
=
[
-
1
],
pass_configs
=
{
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
def
native_sparse_attention
(
batch
,
heads
,
...
...
examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py
View file @
3aecab8f
...
...
@@ -16,9 +16,12 @@ from reference import naive_nsa
from
einops
import
rearrange
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
def
native_sparse_attention_varlen
(
batch
,
heads
,
c_seq_len
,
...
...
src/transform/flatten_buffer.cc
View file @
3aecab8f
...
...
@@ -62,6 +62,43 @@ private:
using
IRMutatorWithAnalyzer
::
VisitStmt
;
using
IRMutatorWithAnalyzer
::
VisitStmt_
;
class
Int64Promoter
:
public
tir
::
IndexDataTypeRewriter
{
public:
using
Parent
=
IndexDataTypeRewriter
;
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
cast
(
DataType
::
Int
(
64
),
GetRef
<
Var
>
(
op
));
}
return
GetRef
<
PrimExpr
>
(
op
);
}
PrimExpr
VisitExpr_
(
const
IntImmNode
*
op
)
final
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
IntImm
(
DataType
::
Int
(
64
),
op
->
value
);
}
return
GetRef
<
PrimExpr
>
(
op
);
}
PrimExpr
VisitExpr_
(
const
CastNode
*
op
)
final
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
cast
(
DataType
::
Int
(
64
),
op
->
value
);
}
return
GetRef
<
PrimExpr
>
(
op
);
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
// Force indices to be int64
auto
node
=
Downcast
<
BufferStore
>
(
Parent
::
VisitStmt_
(
op
));
return
std
::
move
(
node
);
}
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
auto
node
=
Downcast
<
BufferLoad
>
(
Parent
::
VisitExpr_
(
op
));
return
std
::
move
(
node
);
}
};
// Construct a flattener; the supplied analyzer pointer is forwarded to the
// IRMutatorWithAnalyzer base class.
explicit BufferFlattener(arith::Analyzer *ana)
    : IRMutatorWithAnalyzer(ana) {
}
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
...
...
@@ -244,7 +281,29 @@ private:
// Flatten `indices` into the element offset(s) of `buffer` and simplify them.
//
// Before simplification each flattened index is checked against the constant
// integer bound reported by the analyzer. If any index has a signed integer
// dtype narrower than 64 bits and its bound reaches the limits of that dtype,
// the offset arithmetic may overflow, so every flattened index is rewritten
// with Int64Promoter. Promotion is all-or-nothing so that the result always
// has exactly one entry per flattened index.
//
// \param buffer  Buffer whose ElemOffset() defines the flattening.
// \param indices Multi-dimensional access indices into `buffer`.
// \return The flattened indices (possibly promoted to int64) after
//         IterMapSimplifyWithContext.
Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer,
                                        const Array<PrimExpr> &indices) {
  auto flattened_indices = buffer->ElemOffset(indices);

  // True when the analyzer's bound for `index` touches the range of its
  // (narrow, signed) dtype. Indices of any other dtype never need promotion.
  const auto may_overflow = [this](const PrimExpr &index) {
    const DataType dtype = index->dtype;
    if (!dtype.is_int() || dtype.bits() >= 64) {
      return false;
    }
    const auto int_bound = analyzer_->const_int_bound(index);
    // 2^(bits-1): one past the largest value representable in `dtype`, and
    // the magnitude of its smallest value.
    const int64_t half_range = 1LL << (dtype.bits() - 1);
    // NOTE(review): the upper test fires when the bound equals the dtype
    // maximum, the lower test only when it is strictly below the dtype
    // minimum. Kept exactly as the original thresholds; confirm whether the
    // asymmetry is intended.
    return int_bound->max_value >= (half_range - 1) ||
           int_bound->min_value < -half_range;
  };

  bool needs_int64 = false;
  for (const PrimExpr &index : flattened_indices) {
    if (may_overflow(index)) {
      needs_int64 = true;
      break;
    }
  }

  if (!needs_int64) {
    return this->IterMapSimplifyWithContext(flattened_indices, false);
  }

  // Promote every index exactly once, preserving order and arity.
  Int64Promoter promoter;
  Array<PrimExpr> safe_indices;
  for (const PrimExpr &index : flattened_indices) {
    safe_indices.push_back(promoter(index));
  }
  return this->IterMapSimplifyWithContext(safe_indices, false);
}
template
<
typename
Node
>
Node
VisitBufferAccess
(
Node
node
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment