Unverified commit 86a876d8 authored by fzyzcjy, committed by GitHub

Optimize topk operation in llama4 (#5128)

parent 92823069
@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers

 logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
         topk: int,
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+        router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
         router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
             hidden_states.dtype
         )
......
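Note (illustrative sketch, not part of the commit): for the k == 1 routing case, fast_topk falls back to torch.max(..., keepdim=True), which returns the same values, indices, and shapes as torch.topk(..., k=1) while skipping the general top-k kernel. A minimal sanity check with made-up tensor sizes:

import torch

# Hypothetical router logits: [num_tokens, num_experts], sizes invented for the example.
gating_output = torch.randn(8, 16)

topk_vals, topk_idx = torch.topk(gating_output, 1, dim=-1)
max_vals, max_idx = torch.max(gating_output, dim=-1, keepdim=True)

# With continuous random values (no ties) the two paths agree exactly.
assert torch.equal(topk_vals, max_vals)
assert torch.equal(topk_idx, max_idx)
print(topk_vals.shape, topk_idx.shape)  # torch.Size([8, 1]) torch.Size([8, 1])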
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
+from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2

 if is_cuda_available():
     from sgl_kernel import (
@@ -772,16 +772,6 @@ def select_top_k_tokens(
     return input_ids, hidden_states, scores, tree_info


-def fast_topk(values, topk, dim):
-    if topk == 1:
-        # Use max along the specified dimension to get both value and index
-        max_value, max_index = torch.max(values, dim=dim)
-        return max_value.unsqueeze(1), max_index.unsqueeze(1)
-    else:
-        # Use topk for efficiency with larger k values
-        return torch.topk(values, topk, dim=dim)
-
-
 def _generate_simulated_accept_index(
     accept_index,
     predict,
......
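Note (illustrative comparison, not part of the commit): the copy removed here built the k == 1 result with unsqueeze(1), which pins the new axis at position 1 and only matches torch.topk for a 2-D input reduced over its last dimension, whereas the relocated helper uses keepdim=True, which matches torch.topk for any dim. The call sites in this diff appear to pass dim=-1 on 2-D tensors, so their behavior should be unchanged; the difference only shows up for other dims:

import torch

x = torch.randn(4, 6)  # made-up 2-D example

# Last dim: old and new layouts both reproduce torch.topk's [4, 1] shape.
ref = torch.topk(x, 1, dim=-1)
old_vals = torch.max(x, dim=-1).values.unsqueeze(1)    # layout of the removed helper
new_vals = torch.max(x, dim=-1, keepdim=True).values   # layout of the relocated helper
assert old_vals.shape == new_vals.shape == ref.values.shape  # all [4, 1]

# dim=0: only keepdim=True matches torch.topk's [1, 6]; unsqueeze(1) would give [6, 1].
assert torch.max(x, dim=0, keepdim=True).values.shape == torch.topk(x, 1, dim=0).values.shape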
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
-    fast_topk,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+from sglang.srt.utils import (
+    empty_context,
+    fast_topk,
+    get_available_gpu_memory,
+    is_cuda_available,
+)

 if is_cuda_available():
     from sgl_kernel import segment_packbits
......
@@ -1819,3 +1819,12 @@ class DeepEPMode(Enum):
             return DeepEPMode.low_latency
         else:
             return DeepEPMode.normal
+
+
+def fast_topk(values, topk, dim):
+    if topk == 1:
+        # Use max along the specified dimension to get both value and index
+        return torch.max(values, dim=dim, keepdim=True)
+    else:
+        # Use topk for efficiency with larger k values
+        return torch.topk(values, topk, dim=dim)
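Note (hypothetical usage, not part of the commit): a rough CPU micro-benchmark showing how the relocated helper would be exercised for k == 1 versus larger k. The tensor sizes are invented, the script assumes an sglang checkout is importable, and no specific speedup is claimed here; timings depend entirely on hardware and shapes.

import time

import torch

from sglang.srt.utils import fast_topk  # requires sglang to be installed/importable


def bench(fn, *args, iters=200):
    # Warm up once, then report the mean wall-clock time per call.
    fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters


scores = torch.randn(4096, 128)  # [num_tokens, num_experts]; made-up CPU sizes
print("k=1 fast_topk  :", bench(fast_topk, scores, 1, -1))
print("k=1 torch.topk :", bench(torch.topk, scores, 1, -1))
print("k=8 fast_topk  :", bench(fast_topk, scores, 8, -1))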