Unverified Commit 53a525bf authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Eagle] Fix kernel call after updating speculative sampling kernels (#7231)

parent 7ddf8e83
...@@ -20,7 +20,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ ...@@ -20,7 +20,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
RUN pip3 install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128 --break-system-packages RUN pip3 install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128 --break-system-packages
RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.1.8.post2/sgl_kernel-0.1.8.post2+cu128-cp39-abi3-manylinux2014_x86_64.whl --break-system-packages \ RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.1.9/sgl_kernel-0.1.9+cu128-cp39-abi3-manylinux2014_x86_64.whl --break-system-packages \
&& pip3 install setuptools==75.0.0 wheel scikit-build-core --break-system-packages && pip3 install setuptools==75.0.0 wheel scikit-build-core --break-system-packages
RUN git clone --depth=1 https://github.com/sgl-project/sglang.git \ RUN git clone --depth=1 https://github.com/sgl-project/sglang.git \
......
...@@ -49,7 +49,7 @@ runtime_common = [ ...@@ -49,7 +49,7 @@ runtime_common = [
srt = [ srt = [
"sglang[runtime_common]", "sglang[runtime_common]",
"sgl-kernel==0.1.8.post2", "sgl-kernel==0.1.9",
"flashinfer_python==0.2.6.post1", "flashinfer_python==0.2.6.post1",
"torch==2.7.1", "torch==2.7.1",
"torchaudio==2.7.1", "torchaudio==2.7.1",
......
...@@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs): ...@@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if _is_cuda: if _is_cuda:
assert_pkg_version( assert_pkg_version(
"sgl-kernel", "sgl-kernel",
"0.1.8.post2", "0.1.9",
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
) )
......
...@@ -92,7 +92,7 @@ def build_tree_kernel_efficient( ...@@ -92,7 +92,7 @@ def build_tree_kernel_efficient(
sgl_build_tree_kernel_efficient( sgl_build_tree_kernel_efficient(
parent_list, parent_list,
top_scores_index, top_scores_index,
seq_lens.to(torch.int32), seq_lens,
tree_mask, tree_mask,
positions, positions,
retrive_index, retrive_index,
......
...@@ -23,7 +23,7 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -23,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
) )
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
if is_cuda(): if is_cuda():
from sgl_kernel import ( from sgl_kernel import (
...@@ -32,6 +32,7 @@ if is_cuda(): ...@@ -32,6 +32,7 @@ if is_cuda():
tree_speculative_sampling_target_only, tree_speculative_sampling_target_only,
verify_tree_greedy, verify_tree_greedy,
) )
from sgl_kernel.top_k import fast_topk
elif is_hip(): elif is_hip():
from sgl_kernel import verify_tree_greedy from sgl_kernel import verify_tree_greedy
...@@ -327,11 +328,11 @@ class EagleVerifyInput: ...@@ -327,11 +328,11 @@ class EagleVerifyInput:
predicts=predict, # mutable predicts=predict, # mutable
accept_index=accept_index, # mutable accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable accept_token_num=accept_length, # mutable
candidates=candidates.to(torch.int32), candidates=candidates,
retrive_index=self.retrive_index.to(torch.int32), retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token.to(torch.int32), retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), retrive_next_sibling=self.retrive_next_sibling,
target_predict=target_predict.to(torch.int32), target_predict=target_predict,
) )
else: else:
# apply temperature and get target probs # apply temperature and get target probs
...@@ -370,12 +371,12 @@ class EagleVerifyInput: ...@@ -370,12 +371,12 @@ class EagleVerifyInput:
predicts=predict, # mutable predicts=predict, # mutable
accept_index=accept_index, # mutable accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable accept_token_num=accept_length, # mutable
candidates=candidates.to(torch.int32), candidates=candidates,
retrive_index=self.retrive_index.to(torch.int32), retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token.to(torch.int32), retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), retrive_next_sibling=self.retrive_next_sibling,
uniform_samples=coins, uniform_samples=coins,
# uniform_samples_for_final_sampling=coins_for_final_sampling, uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[ threshold_single=global_server_args_dict[
...@@ -1005,16 +1006,6 @@ def select_top_k_tokens( ...@@ -1005,16 +1006,6 @@ def select_top_k_tokens(
return input_ids, hidden_states, scores, tree_info return input_ids, hidden_states, scores, tree_info
def fast_topk_torch(values, topk, dim):
    """Return the top-`topk` (values, indices) of `values` along `dim`.

    Pure-torch fallback for the sgl_kernel `fast_topk` op. For `topk == 1`
    it uses `torch.max`, which is cheaper than a full `torch.topk` sort.

    Args:
        values: input tensor.
        topk: number of elements to keep along `dim`.
        dim: dimension to reduce over (may be negative).

    Returns:
        Tuple `(top_values, top_indices)` with the same shapes
        `torch.topk(values, topk, dim=dim)` would produce.
    """
    if topk == 1:
        # max() drops the reduced dimension; restore it at `dim` (not a
        # hard-coded position) so the result shape matches torch.topk for
        # inputs of any rank and any reduction dim.
        max_value, max_index = torch.max(values, dim=dim)
        return max_value.unsqueeze(dim), max_index.unsqueeze(dim)
    else:
        # General case: a real top-k selection.
        return torch.topk(values, topk, dim=dim)
def _generate_simulated_accept_index( def _generate_simulated_accept_index(
accept_index, accept_index,
predict, predict,
......
...@@ -828,7 +828,7 @@ def load_token_map(token_map_path: str) -> List[int]: ...@@ -828,7 +828,7 @@ def load_token_map(token_map_path: str) -> List[int]:
) )
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path)) token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
hot_token_id = torch.load(token_map_path, weights_only=True) hot_token_id = torch.load(token_map_path, weights_only=True)
return torch.tensor(hot_token_id, dtype=torch.int32) return torch.tensor(hot_token_id, dtype=torch.int64)
@torch.compile(dynamic=True) @torch.compile(dynamic=True)
......
...@@ -143,7 +143,7 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): ...@@ -143,7 +143,7 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
args.extend( args.extend(
[ [
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"2", "4",
"--speculative-algorithm", "--speculative-algorithm",
"EAGLE3", "EAGLE3",
"--speculative-draft", "--speculative-draft",
...@@ -169,7 +169,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest): ...@@ -169,7 +169,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
model = DEFAULT_MODEL_NAME_FOR_TEST model = DEFAULT_MODEL_NAME_FOR_TEST
accuracy_threshold = 0.65 accuracy_threshold = 0.65
speculative_decode = True speculative_decode = True
spec_decode_threshold = 1.5 spec_decode_threshold = 1.6
@classmethod @classmethod
def get_server_args(cls): def get_server_args(cls):
...@@ -177,7 +177,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest): ...@@ -177,7 +177,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
args.extend( args.extend(
[ [
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"2", "4",
"--speculative-algorithm", "--speculative-algorithm",
"EAGLE3", "EAGLE3",
"--speculative-draft", "--speculative-draft",
...@@ -201,7 +201,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): ...@@ -201,7 +201,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
model = DEFAULT_MODEL_NAME_FOR_TEST_MLA model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
accuracy_threshold = 0.60 accuracy_threshold = 0.60
speculative_decode = True speculative_decode = True
spec_decode_threshold = 1.5 spec_decode_threshold = 2.5
@classmethod @classmethod
def get_server_args(cls): def get_server_args(cls):
...@@ -209,7 +209,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): ...@@ -209,7 +209,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
args.extend( args.extend(
[ [
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"2", "4",
"--speculative-algorithm", "--speculative-algorithm",
"EAGLE", "EAGLE",
"--speculative-draft", "--speculative-draft",
...@@ -233,7 +233,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest): ...@@ -233,7 +233,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
model = DEFAULT_MODEL_NAME_FOR_TEST_MLA model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
accuracy_threshold = 0.60 accuracy_threshold = 0.60
speculative_decode = True speculative_decode = True
spec_decode_threshold = 1.5 spec_decode_threshold = 2.95
@classmethod @classmethod
def get_server_args(cls): def get_server_args(cls):
...@@ -241,7 +241,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest): ...@@ -241,7 +241,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
args.extend( args.extend(
[ [
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"2", "4",
"--speculative-algorithm", "--speculative-algorithm",
"EAGLE", "EAGLE",
"--speculative-draft", "--speculative-draft",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment