Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9fce7bee
Unverified
Commit
9fce7bee
authored
Oct 20, 2025
by
Jiangyun Zhu
Committed by
GitHub
Oct 20, 2025
Browse files
[Kernel] Accelerate solve_tril with TMA (#26746)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
b63f2143
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
412 additions
and
301 deletions
+412
-301
vllm/model_executor/layers/fla/ops/op.py
vllm/model_executor/layers/fla/ops/op.py
+32
-11
vllm/model_executor/layers/fla/ops/solve_tril.py
vllm/model_executor/layers/fla/ops/solve_tril.py
+375
-290
vllm/model_executor/layers/fla/ops/utils.py
vllm/model_executor/layers/fla/ops/utils.py
+5
-0
No files found.
vllm/model_executor/layers/fla/ops/op.py
View file @
9fce7bee
...
...
@@ -11,29 +11,50 @@ import os
from
vllm.triton_utils
import
tl
,
tldevice
,
triton
from
.utils
import
is_gather_supported
if
os
.
environ
.
get
(
"FLA_USE_FAST_OPS"
,
"0"
)
==
"1"
:
div
=
tldevice
.
fast_dividef
exp
=
tldevice
.
fast_expf
log
=
tldevice
.
fast_logf
log2
=
tldevice
.
fast_log2f
else
:
@
triton
.
jit
def
div_normal
(
x
,
y
):
return
x
/
y
div
=
div_normal
exp
=
tl
.
exp
log
=
tl
.
log
log2
=
tl
.
log2
if
not
hasattr
(
tl
,
"gather"
)
:
if
not
is_gather_supported
:
@
triton
.
jit
def
gather
(
src
,
index
,
axis
,
_builder
=
None
):
# This is a fallback implementation when tl.gather is not supported
# In order to pass triton compiler, there is no actual gather operation
return
src
"""
Gather operation that works when tl.gather is not supported.
This is a fallback implementation that returns None.
Just to make triton compiler happy.
"""
return
None
else
:
gather
=
tl
.
gather
if
hasattr
(
triton
.
language
,
"_experimental_make_tensor_descriptor"
):
# For Triton 3.3.x
make_tensor_descriptor
=
triton
.
language
.
_experimental_make_tensor_descriptor
elif
hasattr
(
triton
.
language
,
"make_tensor_descriptor"
):
# For Triton 3.4.x and later
make_tensor_descriptor
=
triton
.
language
.
make_tensor_descriptor
else
:
"""
Fallback implementation when TMA is not supported.
Returns None to indicate TMA descriptors are unavailable.
Just make triton compiler happy.
"""
@
triton
.
jit
def
make_tensor_descriptor
(
base
,
shape
,
strides
,
block_shape
,
_builder
=
None
,
):
return
None
vllm/model_executor/layers/fla/ops/solve_tril.py
View file @
9fce7bee
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/fla/ops/utils.py
View file @
9fce7bee
...
...
@@ -150,6 +150,11 @@ is_nvidia_hopper = is_nvidia and (
or
torch
.
cuda
.
get_device_capability
()[
0
]
>=
9
)
use_cuda_graph
=
is_nvidia
and
os
.
environ
.
get
(
"FLA_USE_CUDA_GRAPH"
,
"0"
)
==
"1"
is_gather_supported
=
hasattr
(
triton
.
language
,
"gather"
)
is_tma_supported
=
(
is_nvidia
and
torch
.
cuda
.
get_device_capability
(
0
)[
0
]
>=
9
)
and
(
hasattr
(
triton
.
language
,
"_experimental_make_tensor_descriptor"
)
or
hasattr
(
triton
.
language
,
"make_tensor_descriptor"
)
)
def
get_all_max_shared_mem
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment