Unverified Commit fad315cb authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

fix EAGLE 2 non greedy case (#3407)


Co-authored-by: default avatarYing Sheng <sqy1415@gmail.com>
parent f90db8bc
...@@ -54,7 +54,9 @@ def get_model_config(model_name: str, tp_size: int): ...@@ -54,7 +54,9 @@ def get_model_config(model_name: str, tp_size: int):
): ):
block_shape = config.quantization_config["weight_block_size"] block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2 assert len(block_shape) == 2
assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1" assert (
vllm_version_num >= 66
), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
shape_configs = { shape_configs = {
"num_experts": E, "num_experts": E,
......
...@@ -462,8 +462,11 @@ class CudaGraphRunner: ...@@ -462,8 +462,11 @@ class CudaGraphRunner:
), ),
positions=None, positions=None,
retrive_index=None, retrive_index=None,
retrive_next_token=None,
retrive_next_sibling=None,
retrive_cum_len=None, retrive_cum_len=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
spec_steps=self.model_runner.server_args.speculative_num_steps,
capture_hidden_mode=CaptureHiddenMode.FULL, capture_hidden_mode=CaptureHiddenMode.FULL,
) )
......
...@@ -4,6 +4,7 @@ import dataclasses ...@@ -4,6 +4,7 @@ import dataclasses
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
import torch import torch
import torch.nn.functional as F
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import ( ...@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton, create_flashinfer_kv_indices_triton,
) )
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel from sglang.srt.speculative.build_eagle_tree import (
build_tree_kernel,
build_tree_kernel_efficient,
)
from sglang.srt.utils import is_cuda_available
if is_cuda_available():
from sgl_kernel import tree_speculative_sampling_target_only
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
...@@ -160,8 +168,11 @@ class EagleVerifyInput: ...@@ -160,8 +168,11 @@ class EagleVerifyInput:
custom_mask: torch.Tensor custom_mask: torch.Tensor
positions: torch.Tensor positions: torch.Tensor
retrive_index: torch.Tensor retrive_index: torch.Tensor
retrive_next_token: torch.Tensor
retrive_next_sibling: torch.Tensor
retrive_cum_len: torch.Tensor retrive_cum_len: torch.Tensor
draft_token_num: int draft_token_num: int
spec_steps: int
capture_hidden_mode: CaptureHiddenMode capture_hidden_mode: CaptureHiddenMode
@classmethod @classmethod
...@@ -175,10 +186,45 @@ class EagleVerifyInput: ...@@ -175,10 +186,45 @@ class EagleVerifyInput:
seq_lens_sum: int, seq_lens_sum: int,
topk: int, topk: int,
spec_steps: int, spec_steps: int,
num_verify_token: int, num_verify_tokens: int,
is_all_greedy: bool,
): ):
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = ( if is_all_greedy:
build_tree_kernel( tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
build_tree_kernel(
verified_id,
score_list, # b, n, topk; n= 1 + (num_steps-1) * self.topk
token_list,
parents_list,
seq_lens,
seq_lens_sum,
topk,
spec_steps,
num_verify_tokens,
)
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
None,
None,
retrive_cum_len,
num_verify_tokens,
spec_steps,
CaptureHiddenMode.FULL,
)
else:
(
tree_mask,
position,
retrive_index,
retrive_next_token,
retrive_next_sibling,
draft_tokens,
) = build_tree_kernel_efficient(
verified_id, verified_id,
score_list, score_list,
token_list, token_list,
...@@ -187,18 +233,21 @@ class EagleVerifyInput: ...@@ -187,18 +233,21 @@ class EagleVerifyInput:
seq_lens_sum, seq_lens_sum,
topk, topk,
spec_steps, spec_steps,
num_verify_token, num_verify_tokens,
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
retrive_next_token,
retrive_next_sibling,
None,
num_verify_tokens,
spec_steps,
CaptureHiddenMode.FULL,
) )
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
retrive_cum_len,
num_verify_token,
CaptureHiddenMode.FULL,
)
def prepare_for_verify(self, batch: ScheduleBatch): def prepare_for_verify(self, batch: ScheduleBatch):
batch.input_ids = self.draft_token batch.input_ids = self.draft_token
...@@ -313,12 +362,6 @@ class EagleVerifyInput: ...@@ -313,12 +362,6 @@ class EagleVerifyInput:
uniform_samples=coins, uniform_samples=coins,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True, deterministic=True,
) )
......
...@@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker):
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens, self.server_args.speculative_num_draft_tokens,
batch.sampling_info.is_all_greedy,
) )
# Free cache locations # Free cache locations
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment