Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
86a876d8
"vscode:/vscode.git/clone" did not exist on "b4b9376aab7ed20a9c515fee0559878a55d6c1d4"
Unverified
Commit
86a876d8
authored
Apr 09, 2025
by
fzyzcjy
Committed by
GitHub
Apr 09, 2025
Browse files
Optimize topk operation in llama4 (#5128)
parent
92823069
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
15 deletions
+18
-15
python/sglang/srt/models/llama4.py
python/sglang/srt/models/llama4.py
+2
-2
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+1
-11
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+6
-2
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+9
-0
No files found.
python/sglang/srt/models/llama4.py
View file @
86a876d8
...
@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
...
@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.models.llama
import
LlamaForCausalLM
,
LlamaMLP
from
sglang.srt.models.llama
import
LlamaForCausalLM
,
LlamaMLP
from
sglang.srt.utils
import
add_prefix
,
get_compiler_backend
,
make_layers
from
sglang.srt.utils
import
add_prefix
,
fast_topk
,
get_compiler_backend
,
make_layers
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
...
@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
router_scores_aK
,
router_indices_aK
=
torch
.
topk
(
gating_output
,
topk
,
dim
=-
1
)
router_scores_aK
,
router_indices_aK
=
fast_
topk
(
gating_output
,
topk
,
dim
=-
1
)
router_scores_aK
=
torch
.
sigmoid
(
router_scores_aK
.
float
()).
to
(
router_scores_aK
=
torch
.
sigmoid
(
router_scores_aK
.
float
()).
to
(
hidden_states
.
dtype
hidden_states
.
dtype
)
)
...
...
python/sglang/srt/speculative/eagle_utils.py
View file @
86a876d8
...
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.utils
import
is_cuda_available
,
is_hip
,
next_power_of_2
from
sglang.srt.utils
import
fast_topk
,
is_cuda_available
,
is_hip
,
next_power_of_2
if
is_cuda_available
():
if
is_cuda_available
():
from
sgl_kernel
import
(
from
sgl_kernel
import
(
...
@@ -772,16 +772,6 @@ def select_top_k_tokens(
...
@@ -772,16 +772,6 @@ def select_top_k_tokens(
return
input_ids
,
hidden_states
,
scores
,
tree_info
return
input_ids
,
hidden_states
,
scores
,
tree_info
def fast_topk(values, topk, dim):
    """Top-k (values, indices) of ``values`` along ``dim``, with a k == 1 fast path.

    Args:
        values: Input tensor.
        topk: Number of elements to select (k).
        dim: Dimension along which to select.

    Returns:
        A ``(values, indices)`` pair shaped like ``torch.topk``'s output:
        same rank as ``values`` with size ``topk`` along ``dim``.
    """
    if topk == 1:
        # Use max along the specified dimension to get both value and index.
        # keepdim=True keeps the reduced axis in place so the output shape
        # matches torch.topk for ANY `dim`; the previous
        # `.unsqueeze(1)` was only correct when `dim` was the second axis
        # (e.g. dim=-1 on 2-D input).
        return torch.max(values, dim=dim, keepdim=True)
    else:
        # Use topk for efficiency with larger k values
        return torch.topk(values, topk, dim=dim)
def
_generate_simulated_accept_index
(
def
_generate_simulated_accept_index
(
accept_index
,
accept_index
,
predict
,
predict
,
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
86a876d8
...
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
...
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
EagleVerifyInput
,
EagleVerifyInput
,
EagleVerifyOutput
,
EagleVerifyOutput
,
assign_draft_cache_locs
,
assign_draft_cache_locs
,
fast_topk
,
select_top_k_tokens
,
select_top_k_tokens
,
)
)
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.utils
import
empty_context
,
get_available_gpu_memory
,
is_cuda_available
from
sglang.srt.utils
import
(
empty_context
,
fast_topk
,
get_available_gpu_memory
,
is_cuda_available
,
)
if
is_cuda_available
():
if
is_cuda_available
():
from
sgl_kernel
import
segment_packbits
from
sgl_kernel
import
segment_packbits
...
...
python/sglang/srt/utils.py
View file @
86a876d8
...
@@ -1819,3 +1819,12 @@ class DeepEPMode(Enum):
...
@@ -1819,3 +1819,12 @@ class DeepEPMode(Enum):
return
DeepEPMode
.
low_latency
return
DeepEPMode
.
low_latency
else
:
else
:
return
DeepEPMode
.
normal
return
DeepEPMode
.
normal
def fast_topk(values, topk, dim):
    """Return the top-k ``(values, indices)`` of ``values`` along ``dim``.

    Equivalent to ``torch.topk(values, topk, dim=dim)``, but when k == 1 it
    takes a cheaper single-pass ``torch.max`` instead of a full top-k
    selection, while preserving the same output shapes via ``keepdim=True``.
    """
    if topk != 1:
        # General case: delegate to torch.topk for k > 1.
        return torch.topk(values, topk, dim=dim)
    # Fast path: one max reduction is cheaper than a top-k selection;
    # keepdim=True retains the reduced axis so shapes match torch.topk.
    return torch.max(values, dim=dim, keepdim=True)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment