Unverified Commit fad315cb authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

fix EAGLE 2 non greedy case (#3407)


Co-authored-by: default avatarYing Sheng <sqy1415@gmail.com>
parent f90db8bc
......@@ -54,7 +54,9 @@ def get_model_config(model_name: str, tp_size: int):
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
assert (
vllm_version_num >= 66
), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
shape_configs = {
"num_experts": E,
......
......@@ -462,8 +462,11 @@ class CudaGraphRunner:
),
positions=None,
retrive_index=None,
retrive_next_token=None,
retrive_next_sibling=None,
retrive_cum_len=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
spec_steps=self.model_runner.server_args.speculative_num_steps,
capture_hidden_mode=CaptureHiddenMode.FULL,
)
......
......@@ -4,6 +4,7 @@ import dataclasses
from typing import TYPE_CHECKING, List
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
......@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
from sglang.srt.speculative.build_eagle_tree import (
build_tree_kernel,
build_tree_kernel_efficient,
)
from sglang.srt.utils import is_cuda_available
if is_cuda_available():
from sgl_kernel import tree_speculative_sampling_target_only
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
......@@ -160,8 +168,11 @@ class EagleVerifyInput:
custom_mask: torch.Tensor
positions: torch.Tensor
retrive_index: torch.Tensor
retrive_next_token: torch.Tensor
retrive_next_sibling: torch.Tensor
retrive_cum_len: torch.Tensor
draft_token_num: int
spec_steps: int
capture_hidden_mode: CaptureHiddenMode
@classmethod
......@@ -175,10 +186,45 @@ class EagleVerifyInput:
seq_lens_sum: int,
topk: int,
spec_steps: int,
num_verify_token: int,
num_verify_tokens: int,
is_all_greedy: bool,
):
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
build_tree_kernel(
if is_all_greedy:
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
build_tree_kernel(
verified_id,
score_list, # b, n, topk; n= 1 + (num_steps-1) * self.topk
token_list,
parents_list,
seq_lens,
seq_lens_sum,
topk,
spec_steps,
num_verify_tokens,
)
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
None,
None,
retrive_cum_len,
num_verify_tokens,
spec_steps,
CaptureHiddenMode.FULL,
)
else:
(
tree_mask,
position,
retrive_index,
retrive_next_token,
retrive_next_sibling,
draft_tokens,
) = build_tree_kernel_efficient(
verified_id,
score_list,
token_list,
......@@ -187,18 +233,21 @@ class EagleVerifyInput:
seq_lens_sum,
topk,
spec_steps,
num_verify_token,
num_verify_tokens,
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
retrive_next_token,
retrive_next_sibling,
None,
num_verify_tokens,
spec_steps,
CaptureHiddenMode.FULL,
)
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
retrive_cum_len,
num_verify_token,
CaptureHiddenMode.FULL,
)
def prepare_for_verify(self, batch: ScheduleBatch):
batch.input_ids = self.draft_token
......@@ -313,12 +362,6 @@ class EagleVerifyInput:
uniform_samples=coins,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True,
)
......
......@@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
batch.sampling_info.is_all_greedy,
)
# Free cache locations
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment