Unverified Commit b8574f69, authored by Lianmin Zheng, committed by GitHub

Clean up eagle code (#2756)

parent 2855caa4
@@ -74,11 +74,6 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        if forward_batch.spec_info:
-            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
-        else:
-            capture_hidden_mode = CaptureHiddenMode.NULL
         if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
             extend_return_logprob = True
             extend_return_top_logprob = any(
@@ -98,7 +93,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
-            capture_hidden_mode=capture_hidden_mode,
+            capture_hidden_mode=forward_batch.capture_hidden_mode,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
...
@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -1163,6 +1163,11 @@ class ScheduleBatch:
             input_embeds=self.input_embeds,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            capture_hidden_mode=(
+                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
+                if self.spec_info
+                else CaptureHiddenMode.NULL
+            ),
         )

     def copy(self):
@@ -1237,6 +1242,7 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
+    capture_hidden_mode: CaptureHiddenMode = None

 @triton.jit
...
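Note on the two hunks above: capture_hidden_mode is now computed once, when ScheduleBatch.get_model_worker_batch() builds the ModelWorkerBatch, and then travels as plain data instead of being re-derived from spec_info downstream. A minimal runnable sketch of the fallback logic, using stand-in types (only the conditional expression mirrors the diff; StubSpecInfo and derive_capture_hidden_mode are hypothetical names):

from dataclasses import dataclass
from enum import IntEnum, auto
from typing import Optional


class CaptureHiddenMode(IntEnum):
    NULL = auto()
    FULL = auto()
    LAST = auto()


@dataclass
class StubSpecInfo:
    # Stand-in for EAGLEDraftInput / EagleVerifyInput, which carry this field.
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.LAST


def derive_capture_hidden_mode(spec_info: Optional[StubSpecInfo]) -> CaptureHiddenMode:
    # Mirrors the expression added in get_model_worker_batch(): fall back to
    # NULL when there is no spec_info or it lacks a capture_hidden_mode.
    return (
        getattr(spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
        if spec_info
        else CaptureHiddenMode.NULL
    )


assert derive_capture_hidden_mode(None) is CaptureHiddenMode.NULL
assert derive_capture_hidden_mode(StubSpecInfo()) is CaptureHiddenMode.LAST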
@@ -962,10 +962,13 @@ class Scheduler:
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             else:
-                logits_output, next_token_ids, model_worker_batch, spec_info = (
-                    self.draft_worker.forward_batch_speculative_generation(batch)
-                )
-                batch.spec_info = spec_info
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.num_generated_tokens += num_accepted_tokens
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
             self.tp_worker.forward_batch_idle(model_worker_batch)
...
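Why the counter changes: with speculative decoding, one scheduler step can emit several tokens per request, so the generated-token counter advances by the number of accepted tokens rather than by one per forward pass. A toy illustration with invented numbers:

# Per-request accepted draft-token counts for one speculative step (made up).
accept_length_cpu = [3, 1, 2]
num_accepted_tokens = sum(accept_length_cpu)

num_generated_tokens = 0
num_generated_tokens += num_accepted_tokens  # mirrors the scheduler hunk above
assert num_generated_tokens == 6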
@@ -322,6 +322,8 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None

+        spec_info = self.get_spec_info(num_tokens, positions)
+
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
             batch_size=bs,
@@ -341,7 +343,10 @@
             mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
-            spec_info=self.get_spec_info(num_tokens, positions),
+            spec_info=spec_info,
+            capture_hidden_mode=(
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            ),
         )

         # Attention backend
@@ -446,10 +451,10 @@
         if self.model_runner.is_draft_worker:
             spec_info = EAGLEDraftInput()
+            spec_info.load_server_args(self.model_runner.server_args)
             spec_info.hidden_states = self.hidden_states[:num_tokens]
             spec_info.positions = positions
             spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
-            spec_info.init(self.model_runner.server_args)
         else:
             spec_info = EagleVerifyInput(
                 None,
...
@@ -107,6 +107,21 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DUMMY_FIRST


+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
+
+
 @dataclass
 class ForwardBatch:
     """Store all inputs of a forward pass."""
@@ -174,6 +189,7 @@ class ForwardBatch:
     # Speculative decoding
     spec_info: SpecInfo = None
     spec_algorithm: SpeculativeAlgorithm = None
+    capture_hidden_mode: CaptureHiddenMode = None

     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -265,6 +281,7 @@ class ForwardBatch:
             sampling_info=batch.sampling_info,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
+            capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
         )
@@ -400,18 +417,3 @@ def compute_position_torch(
 @maybe_torch_compile(dynamic=True)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
-class CaptureHiddenMode(IntEnum):
-    NULL = auto()
-    FULL = auto()
-    LAST = auto()
-
-    def need_capture(self):
-        return self != CaptureHiddenMode.NULL
-
-    def is_full(self):
-        return self == CaptureHiddenMode.FULL
-
-    def is_last(self):
-        return self == CaptureHiddenMode.LAST
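The relocated CaptureHiddenMode is a tri-state switch telling the forward pass how many hidden states to keep for the speculative path: none, every position, or only the last position per sequence. A self-contained sketch of how a consumer might gate on it (maybe_capture is illustrative, not code from this commit):

from enum import IntEnum, auto

import torch


class CaptureHiddenMode(IntEnum):
    NULL = auto()  # capture nothing
    FULL = auto()  # capture hidden states for every position
    LAST = auto()  # capture only the last position

    def need_capture(self):
        return self != CaptureHiddenMode.NULL

    def is_full(self):
        return self == CaptureHiddenMode.FULL


def maybe_capture(hidden: torch.Tensor, mode: CaptureHiddenMode):
    # Illustrative consumer: return the hidden states a logits processor would keep.
    if not mode.need_capture():
        return None
    return hidden if mode.is_full() else hidden[-1:]


h = torch.randn(4, 8)  # (num_tokens, hidden_dim)
assert maybe_capture(h, CaptureHiddenMode.NULL) is None
assert maybe_capture(h, CaptureHiddenMode.FULL).shape == (4, 8)
assert maybe_capture(h, CaptureHiddenMode.LAST).shape == (1, 8)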
@@ -9,12 +9,11 @@ import triton.language as tl
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
 from sglang.srt.speculative.spec_info import SpecInfo

 if TYPE_CHECKING:
-    from python.sglang.srt.layers.sampler import SampleOutput
     from python.sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.server_args import ServerArgs
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
 class EAGLEDraftInput(SpecInfo):
-    hidden_states: torch.Tensor = None
-    verified_id: torch.Tensor = None
-    positions: torch.Tensor = None
-    accept_length: torch.Tensor = None
-    has_finished: bool = False
-    unfinished_index: List[int] = None
-
-    def init(self, server_args: ServerArgs):
+    def __init__(self):
         self.prev_mode = ForwardMode.DECODE
         self.sample_output = None
-        self.topk: int = server_args.speculative_eagle_topk
-        self.num_verify_token: int = server_args.speculative_num_draft_tokens
-        self.spec_steps = server_args.speculative_num_steps

         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
         self.parents_list: List[torch.Tensor] = []
         self.cache_list: List[torch.Tensor] = []
         self.iter = 0
-        self.root_token: int = None

-        assert self.topk <= 10, "topk should <= 10"
+        self.hidden_states: torch.Tensor = None
+        self.verified_id: torch.Tensor = None
+        self.positions: torch.Tensor = None
+        self.accept_length: torch.Tensor = None
+        self.has_finished: bool = False
+        self.unfinished_index: List[int] = None
+
+    def load_server_args(self, server_args: ServerArgs):
+        self.topk: int = server_args.speculative_eagle_topk
+        self.num_verify_token: int = server_args.speculative_num_draft_tokens
+        self.spec_steps = server_args.speculative_num_steps
-    def prepare_for_extend(self, batch: ForwardBatch):
+    def prepare_for_extend(self, batch: ScheduleBatch):
         req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
         out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
         batch.out_cache_loc = out_cache_loc
@@ -226,81 +224,72 @@ class EAGLEDraftInput(SpecInfo):
             pt += req.extend_input_len

-        seq_lens = [0] + batch.extend_lens
-        input_ids = batch.input_ids.tolist()
-        verified_id = batch.spec_info.verified_id.tolist()
-        model_input_ids = []
-        for i in range(len(seq_lens) - 1):
-            model_input_ids.extend(
-                input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
-            )
-        batch.input_ids = torch.tensor(
-            model_input_ids, dtype=torch.int32, device="cuda"
-        )
-
-    def capture_for_decode(
-        self,
-        sample_output: SampleOutput,
-        hidden_states: torch.Tensor,
-        prev_mode: ForwardMode,
-    ):
-        self.sample_output = sample_output
-        self.prev_mode = prev_mode
-        self.hidden_states = hidden_states
+        # TODO: support batching inputs
+        assert len(batch.extend_lens) == 1
+        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
     def prepare_for_decode(self, batch: ScheduleBatch):
-        prob = self.sample_output  # b * (1/topk), vocab
+        prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
         top = torch.topk(prob, self.topk, dim=-1)
-        topk_index, topk_p = top.indices, top.values  # b * (1/topk), topk
-        if self.prev_mode == ForwardMode.DECODE:
+        topk_index, topk_p = (
+            top.indices,
+            top.values,
+        )  # shape: (b * top_k, top_k) or (b, top_k)
+        if self.prev_mode.is_decode():
             scores = torch.mul(
                 self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
-            )  # (b, topk) mul (b * topk, topk) -> (b, topk, topk)
+            )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
             topk_cs = torch.topk(
                 scores.flatten(start_dim=1), self.topk, dim=-1
             )  # (b, topk)
             topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
-            self.scores = topk_cs_p

-            selected_input_index = topk_cs_index.flatten() // self.topk  # b * topk
+            selected_input_index = (
+                topk_cs_index.flatten() // self.topk
+            )  # shape: (b * topk)
             batch.spec_info.hidden_states = batch.spec_info.hidden_states[
                 selected_input_index, :
             ]

             topk_index = topk_index.reshape(-1, self.topk**2)
             batch.input_ids = torch.gather(
                 topk_index, index=topk_cs_index, dim=1
             ).flatten()
-            batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))

-            self.score_list.append(scores)  # b, topk, topk
-            self.token_list.append(topk_index)  # b, topk*topk
+            self.scores = topk_cs_p
+            self.score_list.append(scores)  # (b, topk, topk)
+            self.token_list.append(topk_index)  # (b, topk * topk)
             self.origin_score_list.append(topk_p.reshape(topk_index.shape))
             self.parents_list.append(
                 topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
-            )  # b, topk
-
-        elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
-            self.scores = topk_p  # b, top_k
-            self.score_list.append(topk_p.unsqueeze(1))
-            self.token_list.append(topk_index)
-            self.origin_score_list.append(topk_p)
+            )  # shape: (b, topk)
+        else:
+            # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
             batch.spec_info.hidden_states = (
-                batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
+                batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
             )
             batch.input_ids = topk_index.flatten()
             batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
+
+            self.scores = topk_p  # shape: (b, topk)
+            self.score_list.append(topk_p.unsqueeze(1))  # shape: (b, 1, topk)
+            self.token_list.append(topk_index)  # shape: (b, topk)
+            self.origin_score_list.append(topk_p)
             self.parents_list.append(
                 torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
                 .unsqueeze(0)
                 .repeat(self.scores.shape[0], 1)
-            )  # b, topk+1
+            )  # shape: (b, topk + 1)

         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
             + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
         ).flatten()

-        bs = batch.seq_lens.numel()
+        bs = len(batch.seq_lens)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
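The decode branch of prepare_for_decode above is one level of EAGLE's tree expansion: each surviving path's cumulative score is multiplied by its top-k continuation probabilities, the global top-k of the resulting topk * topk candidates survive, and integer division by topk recovers which parent path each winner extends. A worked toy example (the shapes follow the diff; the probabilities are made up):

import torch

b, topk = 1, 2
# Cumulative scores of the paths kept so far: shape (b, topk).
prev_scores = torch.tensor([[0.6, 0.4]])
# Top-k continuation probabilities per kept path: (b * topk, topk) -> (b, topk, topk).
topk_p = torch.tensor([[0.5, 0.3], [0.7, 0.1]]).reshape(b, topk, topk)

# Joint score of every (parent path, child token) pair.
scores = prev_scores.unsqueeze(2) * topk_p  # (b, topk, 1) x (b, topk, topk)
topk_cs = torch.topk(scores.flatten(start_dim=1), topk, dim=-1)

# Flat index // topk identifies the parent path each winner extends.
selected_input_index = topk_cs.indices // topk

print(topk_cs.values)        # tensor([[0.3000, 0.2800]]), i.e. 0.6*0.5 and 0.4*0.7
print(selected_input_index)  # tensor([[0, 1]]): one child of each parent survived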
@@ -419,11 +408,6 @@ class EAGLEDraftInput(SpecInfo):
         )
         return bs, kv_indices, cum_kv_seq_len

-    def clear(self):
-        self.iter = 0
-        self.score_list.clear()
-        self.positions = None
-
     def clear_draft_cache(self, batch):
         draft_cache = torch.cat(self.cache_list, dim=0)
         batch.token_to_kv_pool.free(draft_cache)
@@ -460,7 +444,6 @@ class EAGLEDraftInput(SpecInfo):
             [self.hidden_states, spec_info.hidden_states], axis=0
         )
         self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
-        # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
         self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
@@ -568,9 +551,6 @@ class EagleVerifyInput(SpecInfo):
         )
         accept_index = accept_index[accept_index != -1]
-        # extract_index = extract_index[extract_index != 0]
-
-        draft_input = EAGLEDraftInput()
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
@@ -596,6 +576,7 @@ class EagleVerifyInput(SpecInfo):
         # retracted_reqs, new_token_ratio = batch.retract_decode()
         low = 0
+        draft_input = EAGLEDraftInput()
         for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
             req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
             req.check_finished()
@@ -615,4 +596,10 @@ class EagleVerifyInput(SpecInfo):
         draft_input.unfinished_index = unfinished_index

         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return draft_input, logits_output, verified_id, finished_extend_len
+        return (
+            draft_input,
+            logits_output,
+            verified_id,
+            finished_extend_len,
+            accept_length_cpu,
+        )
@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_for_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)

     def forward_draft_extend(self, batch: ScheduleBatch):
-        self._swap_mem_pool(batch, self.model_runner)
+        self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
-        self._swap_mem_pool(batch, self.target_worker.model_runner)
+        self._set_mem_pool(batch, self.target_worker.model_runner)

     def forward_batch_speculative_generation(self, batch: ScheduleBatch):
         if batch.forward_mode.is_decode():
-            prev_spec_info = batch.spec_info
-            self._swap_mem_pool(batch, self.model_runner)
+            # Draft
+            self._set_mem_pool(batch, self.model_runner)
             for i in range(self.server_args.speculative_num_steps):
                 self.forward_draft_decode(batch)
             batch.spec_info.clear_draft_cache(batch)
-            self._swap_mem_pool(batch, self.target_worker.model_runner)
+            self._set_mem_pool(batch, self.target_worker.model_runner)

+            # Verify
             (
                 next_draft_input,
                 logits_output,
                 verified_id,
                 self.finish_extend_len,
+                accept_length_cpu,
                 model_worker_batch,
             ) = self.verify(batch)
-            next_draft_input.init(self.server_args)
+            next_draft_input.load_server_args(self.server_args)
             batch.spec_info = next_draft_input

             # If it is None, it means all requests are finished
             if batch.spec_info.verified_id is not None:
-                self.forward_extend_after_decode(batch)
-            batch.spec_info = prev_spec_info
-            return logits_output, verified_id, model_worker_batch, next_draft_input
+                self.forward_draft_extend_after_decode(batch)
+            return (
+                logits_output,
+                verified_id,
+                model_worker_batch,
+                sum(accept_length_cpu),
+            )
         else:
-            spec_info = EAGLEDraftInput()
-            spec_info.init(self.server_args)
+            # Forward with the target model and get hidden states.
+            # We need the full hidden states to prefill the KV cache of the draft model.
             model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.spec_info = spec_info
-            spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
             logits_output, next_token_ids = self.target_worker.forward_batch_generation(
                 model_worker_batch
             )
-            model_worker_batch.spec_info.verified_id = next_token_ids
-            model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
+
+            # Forward with the draft model.
+            spec_info = EAGLEDraftInput()
+            spec_info.load_server_args(self.server_args)
+            spec_info.hidden_states = logits_output.hidden_states
+            spec_info.verified_id = next_token_ids
             batch.spec_info = spec_info
             self.forward_draft_extend(batch)
-            batch.spec_info = None
-            return logits_output, next_token_ids, model_worker_batch, spec_info
+            return logits_output, next_token_ids, model_worker_batch, 0

     def verify(self, batch: ScheduleBatch):
         verify_input = batch.spec_info.prepare_for_verify(batch)
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
         verify_input.prepare_for_verify(batch)
+        batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = verify_input
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
         model_worker_batch = batch.get_model_worker_batch()
@@ -119,38 +128,41 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.DECODE
         return res + (model_worker_batch,)

-    def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
+    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
         batch.token_to_kv_pool = runner.token_to_kv_pool
         batch.req_to_token_pool = runner.req_to_token_pool

-    def forward_extend_after_decode(self, batch: ScheduleBatch):
-        self._swap_mem_pool(batch, self.model_runner)
+    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         if batch.spec_info.has_finished:
             index = batch.spec_info.unfinished_index
             seq_lens = batch.seq_lens
             batch.seq_lens = batch.seq_lens[index]

         batch.spec_info.prepare_extend_after_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)

         batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
         batch.forward_mode = ForwardMode.DECODE
         if batch.spec_info.has_finished:
             batch.seq_lens = seq_lens
-        self._swap_mem_pool(batch, self.target_worker.model_runner)
+        self._set_mem_pool(batch, self.target_worker.model_runner)

-    def capture_for_decode(self, logits_output, forward_batch):
-        if isinstance(logits_output, LogitsProcessorOutput):
-            logits = logits_output.next_token_logits
-            sample_output = torch.softmax(
-                logits, dim=-1
-            )  # TODO: Support more sampling method @kavioyu
-            forward_batch.spec_info.capture_for_decode(
-                sample_output, logits_output.hidden_states, forward_batch.forward_mode
-            )
+    def capture_for_decode(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ):
+        sample_output = torch.softmax(
+            logits_output.next_token_logits, dim=-1
+        )  # TODO(kavioyu): Support more sampling methods
+        spec_info = forward_batch.spec_info
+        spec_info.sample_output = sample_output
+        spec_info.hidden_states = logits_output.hidden_states
+        spec_info.prev_mode = forward_batch.forward_mode

     # Prefix sharing is not supported yet.
     def finish_request(self, reqs: Union[Req, List[Req]]):
...
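Summary of the restructured decode path in forward_batch_speculative_generation: draft for speculative_num_steps steps, verify against the target model, re-prime the draft model's state, and report the accepted-token count to the scheduler. A control-flow skeleton with stubbed bodies (every name and number here is a simplified stand-in for the methods in the diff):

class EagleWorkerSketch:
    """Stubbed control-flow skeleton of the speculative decode step."""

    def __init__(self, num_steps: int = 3):
        self.num_steps = num_steps
        self.trace = []

    def forward_draft_decode(self, batch):
        self.trace.append("draft_decode")  # draft model proposes candidate tokens

    def verify(self, batch):
        self.trace.append("verify")  # target model accepts or rejects candidates
        accept_length_cpu = [1, 2]  # per-request accepted counts (invented)
        return "next_draft_input", "logits_output", "verified_id", accept_length_cpu

    def forward_draft_extend_after_decode(self, batch):
        self.trace.append("draft_extend")  # refill the draft model's hidden state

    def speculative_decode_step(self, batch):
        for _ in range(self.num_steps):
            self.forward_draft_decode(batch)
        *_, accept_length_cpu = self.verify(batch)
        self.forward_draft_extend_after_decode(batch)
        return sum(accept_length_cpu)  # what the scheduler adds to its counter


worker = EagleWorkerSketch()
assert worker.speculative_decode_step(batch=None) == 3
assert worker.trace == ["draft_decode"] * 3 + ["verify", "draft_extend"]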