Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2163586e
Unverified
Commit
2163586e
authored
May 29, 2025
by
JieXin Liang
Committed by
GitHub
May 28, 2025
Browse files
[feat] triton kernel for get_last_loc (#6676)
parent
e06b0761
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
64 additions
and
2 deletions
+64
-2
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+64
-2
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
2163586e
...
@@ -1810,10 +1810,72 @@ def write_req_to_token_pool_triton(
...
@@ -1810,10 +1810,72 @@ def write_req_to_token_pool_triton(
)
)
@
torch
.
compile
(
dynamic
=
True
,
backend
=
get_compiler_backend
())
def
get_last_loc
(
def
get_last_loc
(
req_to_token
,
req_pool_indices_tensor
,
prefix_lens_tensor
):
req_to_token
:
torch
.
Tensor
,
req_pool_indices_tensor
:
torch
.
Tensor
,
prefix_lens_tensor
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
if
global_server_args_dict
[
"attention_backend"
]
!=
"torch_native"
:
impl
=
get_last_loc_triton
else
:
impl
=
get_last_loc_torch
return
impl
(
req_to_token
,
req_pool_indices_tensor
,
prefix_lens_tensor
)
def
get_last_loc_torch
(
req_to_token
:
torch
.
Tensor
,
req_pool_indices_tensor
:
torch
.
Tensor
,
prefix_lens_tensor
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
torch
.
where
(
return
torch
.
where
(
prefix_lens_tensor
>
0
,
prefix_lens_tensor
>
0
,
req_to_token
[
req_pool_indices_tensor
,
prefix_lens_tensor
-
1
],
req_to_token
[
req_pool_indices_tensor
,
prefix_lens_tensor
-
1
],
torch
.
full_like
(
prefix_lens_tensor
,
-
1
),
torch
.
full_like
(
prefix_lens_tensor
,
-
1
),
)
)
@
triton
.
jit
def
get_last_loc_kernel
(
req_to_token
,
req_pool_indices_tensor
,
prefix_lens_tensor
,
result
,
num_tokens
,
req_to_token_stride
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
offset
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
+
pid
*
BLOCK_SIZE
mask
=
offset
<
num_tokens
prefix_lens
=
tl
.
load
(
prefix_lens_tensor
+
offset
,
mask
=
mask
,
other
=
0
)
req_pool_indices
=
tl
.
load
(
req_pool_indices_tensor
+
offset
,
mask
=
mask
,
other
=
0
)
token_mask
=
prefix_lens
>
0
token_index
=
req_pool_indices
*
req_to_token_stride
+
(
prefix_lens
-
1
)
tokens
=
tl
.
load
(
req_to_token
+
token_index
,
mask
=
token_mask
,
other
=-
1
)
tl
.
store
(
result
+
offset
,
tokens
,
mask
=
mask
)
def
get_last_loc_triton
(
req_to_token
:
torch
.
Tensor
,
req_pool_indices_tensor
:
torch
.
Tensor
,
prefix_lens_tensor
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
BLOCK_SIZE
=
256
num_tokens
=
prefix_lens_tensor
.
shape
[
0
]
result
=
torch
.
empty_like
(
prefix_lens_tensor
)
grid
=
(
triton
.
cdiv
(
num_tokens
,
BLOCK_SIZE
),)
get_last_loc_kernel
[
grid
](
req_to_token
,
req_pool_indices_tensor
,
prefix_lens_tensor
,
result
,
num_tokens
,
req_to_token
.
stride
(
0
),
BLOCK_SIZE
,
)
return
result
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment