Unverified Commit b8574f69 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Clean up eagle code (#2756)

parent 2855caa4
......@@ -74,11 +74,6 @@ class LogitsMetadata:
@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
if forward_batch.spec_info:
capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
else:
capture_hidden_mode = CaptureHiddenMode.NULL
if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
extend_return_logprob = True
extend_return_top_logprob = any(
......@@ -98,7 +93,7 @@ class LogitsMetadata:
return cls(
forward_mode=forward_batch.forward_mode,
capture_hidden_mode=capture_hidden_mode,
capture_hidden_mode=forward_batch.capture_hidden_mode,
extend_return_logprob=extend_return_logprob,
extend_return_top_logprob=extend_return_top_logprob,
extend_seq_lens=forward_batch.extend_seq_lens,
......
......@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
......@@ -1163,6 +1163,11 @@ class ScheduleBatch:
input_embeds=self.input_embeds,
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
capture_hidden_mode=(
getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
if self.spec_info
else CaptureHiddenMode.NULL
),
)
def copy(self):
......@@ -1237,6 +1242,7 @@ class ModelWorkerBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
capture_hidden_mode: CaptureHiddenMode = None
@triton.jit
......
......@@ -962,10 +962,13 @@ class Scheduler:
self.tp_worker.forward_batch_generation(model_worker_batch)
)
else:
logits_output, next_token_ids, model_worker_batch, spec_info = (
self.draft_worker.forward_batch_speculative_generation(batch)
)
batch.spec_info = spec_info
(
logits_output,
next_token_ids,
model_worker_batch,
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.num_generated_tokens += num_accepted_tokens
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
self.tp_worker.forward_batch_idle(model_worker_batch)
......
......@@ -322,6 +322,8 @@ class CudaGraphRunner:
global_num_tokens = None
gathered_buffer = None
spec_info = self.get_spec_info(num_tokens, positions)
forward_batch = ForwardBatch(
forward_mode=self.capture_forward_mode,
batch_size=bs,
......@@ -341,7 +343,10 @@ class CudaGraphRunner:
mrope_positions=mrope_positions,
gathered_buffer=gathered_buffer,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=self.get_spec_info(num_tokens, positions),
spec_info=spec_info,
capture_hidden_mode=(
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
),
)
# Attention backend
......@@ -446,10 +451,10 @@ class CudaGraphRunner:
if self.model_runner.is_draft_worker:
spec_info = EAGLEDraftInput()
spec_info.load_server_args(self.model_runner.server_args)
spec_info.hidden_states = self.hidden_states[:num_tokens]
spec_info.positions = positions
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
spec_info.init(self.model_runner.server_args)
else:
spec_info = EagleVerifyInput(
None,
......
......@@ -107,6 +107,21 @@ class ForwardMode(IntEnum):
return self == ForwardMode.DUMMY_FIRST
class CaptureHiddenMode(IntEnum):
NULL = auto()
FULL = auto()
LAST = auto()
def need_capture(self):
return self != CaptureHiddenMode.NULL
def is_full(self):
return self == CaptureHiddenMode.FULL
def is_last(self):
return self == CaptureHiddenMode.LAST
@dataclass
class ForwardBatch:
"""Store all inputs of a forward pass."""
......@@ -174,6 +189,7 @@ class ForwardBatch:
# Speculative decoding
spec_info: SpecInfo = None
spec_algorithm: SpeculativeAlgorithm = None
capture_hidden_mode: CaptureHiddenMode = None
# For Qwen2-VL
mrope_positions: torch.Tensor = None
......@@ -265,6 +281,7 @@ class ForwardBatch:
sampling_info=batch.sampling_info,
spec_algorithm=batch.spec_algorithm,
spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds,
)
......@@ -400,18 +417,3 @@ def compute_position_torch(
@maybe_torch_compile(dynamic=True)
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
class CaptureHiddenMode(IntEnum):
NULL = auto()
FULL = auto()
LAST = auto()
def need_capture(self):
return self != CaptureHiddenMode.NULL
def is_full(self):
return self == CaptureHiddenMode.FULL
def is_last(self):
return self == CaptureHiddenMode.LAST
......@@ -9,12 +9,11 @@ import triton.language as tl
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
from sglang.srt.speculative.spec_info import SpecInfo
if TYPE_CHECKING:
from python.sglang.srt.layers.sampler import SampleOutput
from python.sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.server_args import ServerArgs
......@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
class EAGLEDraftInput(SpecInfo):
hidden_states: torch.Tensor = None
verified_id: torch.Tensor = None
positions: torch.Tensor = None
accept_length: torch.Tensor = None
has_finished: bool = False
unfinished_index: List[int] = None
def init(self, server_args: ServerArgs):
def __init__(self):
self.prev_mode = ForwardMode.DECODE
self.sample_output = None
self.topk: int = server_args.speculative_eagle_topk
self.num_verify_token: int = server_args.speculative_num_draft_tokens
self.spec_steps = server_args.speculative_num_steps
self.scores: torch.Tensor = None
self.score_list: List[torch.Tensor] = []
......@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
self.parents_list: List[torch.Tensor] = []
self.cache_list: List[torch.Tenor] = []
self.iter = 0
self.root_token: int = None
assert self.topk <= 10, "topk should <= 10"
self.hidden_states: torch.Tensor = None
self.verified_id: torch.Tensor = None
self.positions: torch.Tensor = None
self.accept_length: torch.Tensor = None
self.has_finished: bool = False
self.unfinished_index: List[int] = None
def load_server_args(self, server_args: ServerArgs):
self.topk: int = server_args.speculative_eagle_topk
self.num_verify_token: int = server_args.speculative_num_draft_tokens
self.spec_steps = server_args.speculative_num_steps
def prepare_for_extend(self, batch: ForwardBatch):
def prepare_for_extend(self, batch: ScheduleBatch):
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
batch.out_cache_loc = out_cache_loc
......@@ -226,81 +224,72 @@ class EAGLEDraftInput(SpecInfo):
pt += req.extend_input_len
seq_lens = [0] + batch.extend_lens
input_ids = batch.input_ids.tolist()
verified_id = batch.spec_info.verified_id.tolist()
model_input_ids = []
for i in range(len(seq_lens) - 1):
model_input_ids.extend(
input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
)
batch.input_ids = torch.tensor(
model_input_ids, dtype=torch.int32, device="cuda"
)
def capture_for_decode(
self,
sample_output: SampleOutput,
hidden_states: torch.Tensor,
prev_mode: ForwardMode,
):
self.sample_output = sample_output
self.prev_mode = prev_mode
self.hidden_states = hidden_states
# TODO: support batching inputs
assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
def prepare_for_decode(self, batch: ScheduleBatch):
prob = self.sample_output # b * (1/topk), vocab
prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
top = torch.topk(prob, self.topk, dim=-1)
topk_index, topk_p = top.indices, top.values # b * (1/topk), topk
if self.prev_mode == ForwardMode.DECODE:
topk_index, topk_p = (
top.indices,
top.values,
) # shape: (b * top_k, top_k) or (b, top_k)
if self.prev_mode.is_decode():
scores = torch.mul(
self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
) # (b, topk) mul (b * topk ,topk) -> b, topk, topk
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
topk_cs = torch.topk(
scores.flatten(start_dim=1), self.topk, dim=-1
) # (b, topk)
topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
self.scores = topk_cs_p
selected_input_index = topk_cs_index.flatten() // self.topk # b* topk
selected_input_index = (
topk_cs_index.flatten() // self.topk
) # shape: (b * topk)
batch.spec_info.hidden_states = batch.spec_info.hidden_states[
selected_input_index, :
]
topk_index = topk_index.reshape(-1, self.topk**2)
batch.input_ids = torch.gather(
topk_index, index=topk_cs_index, dim=1
).flatten()
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
self.score_list.append(scores) # b, topk, topk
self.token_list.append(topk_index) # b, topk*topk
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
self.scores = topk_cs_p
self.score_list.append(scores) # (b, topk, topk)
self.token_list.append(topk_index) # (b, topk * topk)
self.origin_score_list.append(topk_p.reshape(topk_index.shape))
self.parents_list.append(
topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
) # b, topk
elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
self.scores = topk_p # b, top_k
self.score_list.append(topk_p.unsqueeze(1))
self.token_list.append(topk_index)
self.origin_score_list.append(topk_p)
) # shape: (b, topk)
else:
# ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
batch.spec_info.hidden_states = (
batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
)
batch.input_ids = topk_index.flatten()
batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
self.scores = topk_p # shape: (b, topk)
self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk)
self.token_list.append(topk_index) # shape: (b, topk)
self.origin_score_list.append(topk_p)
self.parents_list.append(
torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
.unsqueeze(0)
.repeat(self.scores.shape[0], 1)
) # b, topk+1
) # shape: (b, topk + 1)
self.cache_list.append(batch.out_cache_loc)
self.positions = (
batch.seq_lens[:, None]
+ torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
).flatten()
bs = batch.seq_lens.numel()
bs = len(batch.seq_lens)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
......@@ -419,11 +408,6 @@ class EAGLEDraftInput(SpecInfo):
)
return bs, kv_indices, cum_kv_seq_len
def clear(self):
self.iter = 0
self.score_list.clear()
self.positions = None
def clear_draft_cache(self, batch):
draft_cache = torch.cat(self.cache_list, dim=0)
batch.token_to_kv_pool.free(draft_cache)
......@@ -460,7 +444,6 @@ class EAGLEDraftInput(SpecInfo):
[self.hidden_states, spec_info.hidden_states], axis=0
)
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
# self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
......@@ -568,9 +551,6 @@ class EagleVerifyInput(SpecInfo):
)
accept_index = accept_index[accept_index != -1]
# extract_index = extract_index[extract_index != 0]
draft_input = EAGLEDraftInput()
accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index]
......@@ -596,6 +576,7 @@ class EagleVerifyInput(SpecInfo):
# retracted_reqs, new_token_ratio = batch.retract_decode()
low = 0
draft_input = EAGLEDraftInput()
for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
req.check_finished()
......@@ -615,4 +596,10 @@ class EagleVerifyInput(SpecInfo):
draft_input.unfinished_index = unfinished_index
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return draft_input, logits_output, verified_id, finished_extend_len
return (
draft_input,
logits_output,
verified_id,
finished_extend_len,
accept_length_cpu,
)
......@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info.prepare_for_decode(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
def forward_draft_extend(self, batch: ScheduleBatch):
self._swap_mem_pool(batch, self.model_runner)
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._swap_mem_pool(batch, self.target_worker.model_runner)
self._set_mem_pool(batch, self.target_worker.model_runner)
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
if batch.forward_mode.is_decode():
prev_spec_info = batch.spec_info
self._swap_mem_pool(batch, self.model_runner)
# Draft
self._set_mem_pool(batch, self.model_runner)
for i in range(self.server_args.speculative_num_steps):
self.forward_draft_decode(batch)
batch.spec_info.clear_draft_cache(batch)
self._swap_mem_pool(batch, self.target_worker.model_runner)
self._set_mem_pool(batch, self.target_worker.model_runner)
# Verify
(
next_draft_input,
logits_output,
verified_id,
self.finish_extend_len,
accept_length_cpu,
model_worker_batch,
) = self.verify(batch)
next_draft_input.init(self.server_args)
next_draft_input.load_server_args(self.server_args)
batch.spec_info = next_draft_input
# if it is None, means all requsets are finished
if batch.spec_info.verified_id is not None:
self.forward_extend_after_decode(batch)
batch.spec_info = prev_spec_info
return logits_output, verified_id, model_worker_batch, next_draft_input
self.forward_draft_extend_after_decode(batch)
return (
logits_output,
verified_id,
model_worker_batch,
sum(accept_length_cpu),
)
else:
spec_info = EAGLEDraftInput()
spec_info.init(self.server_args)
# Forward with the target model and get hidden states.
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_info = spec_info
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
)
model_worker_batch.spec_info.verified_id = next_token_ids
model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
# Forward with the draft model.
spec_info = EAGLEDraftInput()
spec_info.load_server_args(self.server_args)
spec_info.hidden_states = logits_output.hidden_states
spec_info.verified_id = next_token_ids
batch.spec_info = spec_info
self.forward_draft_extend(batch)
batch.spec_info = None
return logits_output, next_token_ids, model_worker_batch, spec_info
return logits_output, next_token_ids, model_worker_batch, 0
def verify(self, batch: ScheduleBatch):
verify_input = batch.spec_info.prepare_for_verify(batch)
batch.forward_mode = ForwardMode.TARGET_VERIFY
verify_input.prepare_for_verify(batch)
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = verify_input
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch = batch.get_model_worker_batch()
......@@ -119,38 +128,41 @@ class EAGLEWorker(TpModelWorker):
batch.forward_mode = ForwardMode.DECODE
return res + (model_worker_batch,)
def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
batch.token_to_kv_pool = runner.token_to_kv_pool
batch.req_to_token_pool = runner.req_to_token_pool
def forward_extend_after_decode(self, batch: ScheduleBatch):
self._swap_mem_pool(batch, self.model_runner)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
if batch.spec_info.has_finished:
index = batch.spec_info.unfinished_index
seq_lens = batch.seq_lens
batch.seq_lens = batch.seq_lens[index]
batch.spec_info.prepare_extend_after_decode(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
batch.spec_info.hidden_states = logits_output.hidden_states
self.capture_for_decode(logits_output, forward_batch)
batch.forward_mode = ForwardMode.DECODE
if batch.spec_info.has_finished:
batch.seq_lens = seq_lens
self._swap_mem_pool(batch, self.target_worker.model_runner)
self._set_mem_pool(batch, self.target_worker.model_runner)
def capture_for_decode(self, logits_output, forward_batch):
if isinstance(logits_output, LogitsProcessorOutput):
logits = logits_output.next_token_logits
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
):
sample_output = torch.softmax(
logits, dim=-1
) # TODO: Support more sampling method @kavioyu
forward_batch.spec_info.capture_for_decode(
sample_output, logits_output.hidden_states, forward_batch.forward_mode
)
logits_output.next_token_logits, dim=-1
) # TODO(kavioyu): Support more sampling methods
spec_info = forward_batch.spec_info
spec_info.sample_output = sample_output
spec_info.hidden_states = logits_output.hidden_states
spec_info.prev_mode = forward_batch.forward_mode
# Don't support prefix share now.
def finish_request(self, reqs: Union[Req, List[Req]]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment