Unverified Commit 86a876d8 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Optimize topk operation in llama4 (#5128)

parent 92823069
...@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding ...@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module): ...@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
topk: int, topk: int,
renormalize: bool, renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1) router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
router_scores_aK = torch.sigmoid(router_scores_aK.float()).to( router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
hidden_states.dtype hidden_states.dtype
) )
......
...@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2 from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
if is_cuda_available(): if is_cuda_available():
from sgl_kernel import ( from sgl_kernel import (
...@@ -772,16 +772,6 @@ def select_top_k_tokens( ...@@ -772,16 +772,6 @@ def select_top_k_tokens(
return input_ids, hidden_states, scores, tree_info return input_ids, hidden_states, scores, tree_info
def fast_topk(values, topk, dim):
    """Return ``(values, indices)`` of the top-`topk` entries of `values` along `dim`.

    For ``topk == 1`` this uses ``torch.max``, which is cheaper than a full
    top-k selection; for larger k it falls back to ``torch.topk``.
    """
    if topk == 1:
        # keepdim=True preserves the reduced dimension as size 1, so the
        # result shape matches torch.topk's for any `dim`. (The previous
        # unsqueeze(1) was only correct when dim was 1, or -1 on 2-D input.)
        return torch.max(values, dim=dim, keepdim=True)
    else:
        # Use topk for efficiency with larger k values
        return torch.topk(values, topk, dim=dim)
def _generate_simulated_accept_index( def _generate_simulated_accept_index(
accept_index, accept_index,
predict, predict,
......
...@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import ( ...@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
EagleVerifyInput, EagleVerifyInput,
EagleVerifyOutput, EagleVerifyOutput,
assign_draft_cache_locs, assign_draft_cache_locs,
fast_topk,
select_top_k_tokens, select_top_k_tokens,
) )
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available from sglang.srt.utils import (
empty_context,
fast_topk,
get_available_gpu_memory,
is_cuda_available,
)
if is_cuda_available(): if is_cuda_available():
from sgl_kernel import segment_packbits from sgl_kernel import segment_packbits
......
...@@ -1819,3 +1819,12 @@ class DeepEPMode(Enum): ...@@ -1819,3 +1819,12 @@ class DeepEPMode(Enum):
return DeepEPMode.low_latency return DeepEPMode.low_latency
else: else:
return DeepEPMode.normal return DeepEPMode.normal
def fast_topk(values, topk, dim):
    """Top-k selection along `dim`, specialized for the common ``k == 1`` case.

    Returns a ``(values, indices)`` pair shaped exactly as ``torch.topk``
    would produce. When ``topk`` is 1, ``torch.max`` with ``keepdim=True``
    is used instead of a full top-k, since a single max reduction is cheaper.
    """
    if topk != 1:
        # General case: defer to the full top-k kernel.
        return torch.topk(values, topk, dim=dim)
    # k == 1 fast path: max gives both value and argmax in one pass;
    # keepdim retains the reduced axis so shapes match torch.topk's.
    return torch.max(values, dim=dim, keepdim=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment