Unverified commit 53a525bf, authored by Lianmin Zheng, committed by GitHub

[Eagle] Fix kernel call after updating speculative sampling kernels (#7231)

parent 7ddf8e83
@@ -20,7 +20,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
 RUN pip3 install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128 --break-system-packages
-RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.1.8.post2/sgl_kernel-0.1.8.post2+cu128-cp39-abi3-manylinux2014_x86_64.whl --break-system-packages \
+RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.1.9/sgl_kernel-0.1.9+cu128-cp39-abi3-manylinux2014_x86_64.whl --break-system-packages \
     && pip3 install setuptools==75.0.0 wheel scikit-build-core --break-system-packages
 RUN git clone --depth=1 https://github.com/sgl-project/sglang.git \
......
@@ -49,7 +49,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.1.8.post2",
+    "sgl-kernel==0.1.9",
     "flashinfer_python==0.2.6.post1",
     "torch==2.7.1",
     "torchaudio==2.7.1",
......
@@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.8.post2",
+            "0.1.9",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
......
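
Note: the bumped assert_pkg_version guard above enforces a minimum sgl-kernel version at startup. As a rough illustration only (this is not sglang's actual implementation), an equivalent minimum-version check can be written with importlib.metadata plus the packaging library:

    # Illustrative sketch of a minimum-version guard; not sglang's assert_pkg_version.
    from importlib.metadata import PackageNotFoundError, version
    from packaging.version import Version

    def check_min_version(pkg: str, min_version: str, hint: str) -> None:
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            raise RuntimeError(f"{pkg} is not installed. {hint}")
        if Version(installed) < Version(min_version):
            raise RuntimeError(f"{pkg}=={installed} is older than required {min_version}. {hint}")

    check_min_version(
        "sgl-kernel",
        "0.1.9",
        "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
    )
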
@@ -92,7 +92,7 @@ def build_tree_kernel_efficient(
     sgl_build_tree_kernel_efficient(
         parent_list,
         top_scores_index,
-        seq_lens.to(torch.int32),
+        seq_lens,
         tree_mask,
         positions,
         retrive_index,
......
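
Note: dropping the .to(torch.int32) suggests the 0.1.9 kernel now accepts seq_lens in its native dtype (int64 is what sglang's schedulers typically carry) and performs any narrowing internally; that reading is inferred from this diff, not from kernel documentation. A hypothetical migration-time check a caller could add:

    import torch

    # seq_lens as produced upstream: int64 ("long") per-request lengths.
    seq_lens = torch.tensor([5, 9, 3], dtype=torch.int64)

    # Hypothetical sanity check while migrating call sites: values must still
    # fit in int32 even though the explicit host-side cast is gone.
    assert seq_lens.dtype == torch.int64
    assert int(seq_lens.max()) <= torch.iinfo(torch.int32).max
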
@@ -23,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
-from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2

 if is_cuda():
     from sgl_kernel import (
@@ -32,6 +32,7 @@ if is_cuda():
         tree_speculative_sampling_target_only,
         verify_tree_greedy,
     )
+    from sgl_kernel.top_k import fast_topk
 elif is_hip():
     from sgl_kernel import verify_tree_greedy
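
Note: fast_topk now comes from sgl_kernel.top_k on CUDA instead of sglang.srt.utils. Judging by the pure-torch fallback deleted later in this diff, it follows a torch.topk-like (values, topk, dim) signature; a hedged usage sketch with a CPU fallback (the sgl_kernel import assumes a CUDA install of sgl-kernel >= 0.1.9):

    import torch

    try:
        from sgl_kernel.top_k import fast_topk  # CUDA fast path per this diff
    except ImportError:
        # Fallback mirroring the fast_topk_torch helper removed below.
        def fast_topk(values, topk, dim):
            if topk == 1:
                max_value, max_index = torch.max(values, dim=dim)
                return max_value.unsqueeze(1), max_index.unsqueeze(1)
            return torch.topk(values, topk, dim=dim)

    scores = torch.randn(4, 32)           # e.g. per-branch draft scores
    top_vals, top_idx = fast_topk(scores, 8, -1)
    print(top_vals.shape, top_idx.shape)  # torch.Size([4, 8]) twice
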
@@ -327,11 +328,11 @@ class EagleVerifyInput:
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
-                candidates=candidates.to(torch.int32),
-                retrive_index=self.retrive_index.to(torch.int32),
-                retrive_next_token=self.retrive_next_token.to(torch.int32),
-                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
-                target_predict=target_predict.to(torch.int32),
+                candidates=candidates,
+                retrive_index=self.retrive_index,
+                retrive_next_token=self.retrive_next_token,
+                retrive_next_sibling=self.retrive_next_sibling,
+                target_predict=target_predict,
             )
         else:
             # apply temperature and get target probs
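
Note: with sgl-kernel 0.1.9, verify_tree_greedy takes these index tensors as-is, so the host-side int32 casts are dropped. Conceptually, greedy verification accepts draft tokens while the target model's argmax agrees with the draft. A deliberately simplified, chain-only sketch (the real kernel also walks sibling branches via retrive_next_token and retrive_next_sibling):

    import torch

    def greedy_verify_chain(candidates: torch.Tensor, target_predict: torch.Tensor) -> int:
        # Accept draft tokens until the target model's greedy pick disagrees.
        accepted = 0
        for draft, target in zip(candidates.tolist(), target_predict.tolist()):
            if draft != target:
                break
            accepted += 1
        return accepted

    # The target agrees with the first two draft tokens only -> prints 2.
    print(greedy_verify_chain(torch.tensor([5, 9, 2]), torch.tensor([5, 9, 7])))
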
@@ -370,12 +371,12 @@ class EagleVerifyInput:
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
-                candidates=candidates.to(torch.int32),
-                retrive_index=self.retrive_index.to(torch.int32),
-                retrive_next_token=self.retrive_next_token.to(torch.int32),
-                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
+                candidates=candidates,
+                retrive_index=self.retrive_index,
+                retrive_next_token=self.retrive_next_token,
+                retrive_next_sibling=self.retrive_next_sibling,
                 uniform_samples=coins,
-                # uniform_samples_for_final_sampling=coins_for_final_sampling,
+                uniform_samples_for_final_sampling=coins_for_final_sampling,
                 target_probs=target_probs,
                 draft_probs=draft_probs,
                 threshold_single=global_server_args_dict[
@@ -1005,16 +1006,6 @@ def select_top_k_tokens(
     return input_ids, hidden_states, scores, tree_info


-def fast_topk_torch(values, topk, dim):
-    if topk == 1:
-        # Use max along the specified dimension to get both value and index
-        max_value, max_index = torch.max(values, dim=dim)
-        return max_value.unsqueeze(1), max_index.unsqueeze(1)
-    else:
-        # Use topk for efficiency with larger k values
-        return torch.topk(values, topk, dim=dim)
-
-
 def _generate_simulated_accept_index(
     accept_index,
     predict,
......
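
Note: the deleted fast_topk_torch fallback special-cased topk == 1 with a single max reduction, avoiding torch.topk's extra sort machinery; the kernel-backed fast_topk imported above replaces it on CUDA. The two paths agree on values, and on indices absent exact ties:

    import torch

    x = torch.randn(8, 512)

    # Path the deleted helper used for topk == 1: one max reduction.
    max_val, max_idx = torch.max(x, dim=-1)
    fast_vals, fast_idx = max_val.unsqueeze(1), max_idx.unsqueeze(1)

    # General torch.topk path.
    ref_vals, ref_idx = torch.topk(x, 1, dim=-1)

    assert torch.equal(fast_vals, ref_vals)
    assert torch.equal(fast_idx, ref_idx)  # can differ only on exact ties
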
@@ -828,7 +828,7 @@ def load_token_map(token_map_path: str) -> List[int]:
         )
         token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
     hot_token_id = torch.load(token_map_path, weights_only=True)
-    return torch.tensor(hot_token_id, dtype=torch.int32)
+    return torch.tensor(hot_token_id, dtype=torch.int64)


 @torch.compile(dynamic=True)
......
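
Note: returning the token map as int64 matters because several torch index ops insist on Long indices; torch.gather, for example, rejects int32. A brief demonstration:

    import torch

    logits = torch.randn(2, 10)
    idx32 = torch.tensor([[1, 4], [0, 9]], dtype=torch.int32)

    try:
        torch.gather(logits, 1, idx32)  # gather requires int64 indices
    except RuntimeError as e:
        print("int32 rejected:", e)

    print(torch.gather(logits, 1, idx32.to(torch.int64)))  # works
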
@@ -143,7 +143,7 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
         args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--speculative-algorithm",
                 "EAGLE3",
                 "--speculative-draft",
@@ -169,7 +169,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
     model = DEFAULT_MODEL_NAME_FOR_TEST
     accuracy_threshold = 0.65
     speculative_decode = True
-    spec_decode_threshold = 1.5
+    spec_decode_threshold = 1.6

     @classmethod
     def get_server_args(cls):
@@ -177,7 +177,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
         args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--speculative-algorithm",
                 "EAGLE3",
                 "--speculative-draft",
@@ -201,7 +201,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
     model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
     accuracy_threshold = 0.60
     speculative_decode = True
-    spec_decode_threshold = 1.5
+    spec_decode_threshold = 2.5

     @classmethod
     def get_server_args(cls):
@@ -209,7 +209,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
         args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--speculative-algorithm",
                 "EAGLE",
                 "--speculative-draft",
@@ -233,7 +233,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
     model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
     accuracy_threshold = 0.60
     speculative_decode = True
-    spec_decode_threshold = 1.5
+    spec_decode_threshold = 2.95

     @classmethod
     def get_server_args(cls):
@@ -241,7 +241,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
         args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--speculative-algorithm",
                 "EAGLE",
                 "--speculative-draft",
......
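
Note: the test updates raise --cuda-graph-max-bs from 2 to 4 and retune spec_decode_threshold per configuration; presumably that threshold is a lower bound on the measured average accept length of speculative decoding. A hypothetical shape for such an assertion (the real logic lives in BaseFlashAttentionTest, which this diff does not show):

    # Hypothetical sketch; names mirror the test attributes above.
    def check_speculative_metrics(accuracy, avg_accept_length,
                                  accuracy_threshold, spec_decode_threshold):
        assert accuracy >= accuracy_threshold, "generation quality regressed"
        assert avg_accept_length >= spec_decode_threshold, \
            "speculative decoding accepted too few draft tokens per step"

    check_speculative_metrics(accuracy=0.66, avg_accept_length=1.7,
                              accuracy_threshold=0.65, spec_decode_threshold=1.6)
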