Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
86a876d8
Unverified
Commit
86a876d8
authored
Apr 09, 2025
by
fzyzcjy
Committed by
GitHub
Apr 09, 2025
Browse files
Optimize topk operation in llama4 (#5128)
parent
92823069
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
15 deletions
+18
-15
python/sglang/srt/models/llama4.py
python/sglang/srt/models/llama4.py
+2
-2
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+1
-11
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+6
-2
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+9
-0
No files found.
python/sglang/srt/models/llama4.py
View file @
86a876d8
...
@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
...
@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.models.llama
import
LlamaForCausalLM
,
LlamaMLP
from
sglang.srt.models.llama
import
LlamaForCausalLM
,
LlamaMLP
from
sglang.srt.utils
import
add_prefix
,
get_compiler_backend
,
make_layers
from
sglang.srt.utils
import
add_prefix
,
fast_topk
,
get_compiler_backend
,
make_layers
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
...
@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
router_scores_aK
,
router_indices_aK
=
torch
.
topk
(
gating_output
,
topk
,
dim
=-
1
)
router_scores_aK
,
router_indices_aK
=
fast_
topk
(
gating_output
,
topk
,
dim
=-
1
)
router_scores_aK
=
torch
.
sigmoid
(
router_scores_aK
.
float
()).
to
(
router_scores_aK
=
torch
.
sigmoid
(
router_scores_aK
.
float
()).
to
(
hidden_states
.
dtype
hidden_states
.
dtype
)
)
...
...
python/sglang/srt/speculative/eagle_utils.py
View file @
86a876d8
...
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.utils
import
is_cuda_available
,
is_hip
,
next_power_of_2
from
sglang.srt.utils
import
fast_topk
,
is_cuda_available
,
is_hip
,
next_power_of_2
if
is_cuda_available
():
if
is_cuda_available
():
from
sgl_kernel
import
(
from
sgl_kernel
import
(
...
@@ -772,16 +772,6 @@ def select_top_k_tokens(
...
@@ -772,16 +772,6 @@ def select_top_k_tokens(
return
input_ids
,
hidden_states
,
scores
,
tree_info
return
input_ids
,
hidden_states
,
scores
,
tree_info
def
fast_topk
(
values
,
topk
,
dim
):
if
topk
==
1
:
# Use max along the specified dimension to get both value and index
max_value
,
max_index
=
torch
.
max
(
values
,
dim
=
dim
)
return
max_value
.
unsqueeze
(
1
),
max_index
.
unsqueeze
(
1
)
else
:
# Use topk for efficiency with larger k values
return
torch
.
topk
(
values
,
topk
,
dim
=
dim
)
def
_generate_simulated_accept_index
(
def
_generate_simulated_accept_index
(
accept_index
,
accept_index
,
predict
,
predict
,
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
86a876d8
...
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
...
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
EagleVerifyInput
,
EagleVerifyInput
,
EagleVerifyOutput
,
EagleVerifyOutput
,
assign_draft_cache_locs
,
assign_draft_cache_locs
,
fast_topk
,
select_top_k_tokens
,
select_top_k_tokens
,
)
)
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.utils
import
empty_context
,
get_available_gpu_memory
,
is_cuda_available
from
sglang.srt.utils
import
(
empty_context
,
fast_topk
,
get_available_gpu_memory
,
is_cuda_available
,
)
if
is_cuda_available
():
if
is_cuda_available
():
from
sgl_kernel
import
segment_packbits
from
sgl_kernel
import
segment_packbits
...
...
python/sglang/srt/utils.py
View file @
86a876d8
...
@@ -1819,3 +1819,12 @@ class DeepEPMode(Enum):
...
@@ -1819,3 +1819,12 @@ class DeepEPMode(Enum):
return
DeepEPMode
.
low_latency
return
DeepEPMode
.
low_latency
else
:
else
:
return
DeepEPMode
.
normal
return
DeepEPMode
.
normal
def
fast_topk
(
values
,
topk
,
dim
):
if
topk
==
1
:
# Use max along the specified dimension to get both value and index
return
torch
.
max
(
values
,
dim
=
dim
,
keepdim
=
True
)
else
:
# Use topk for efficiency with larger k values
return
torch
.
topk
(
values
,
topk
,
dim
=
dim
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment