Unverified Commit cde5a6e3 authored by Liangsheng Yin, committed by GitHub

Abstraction for spec worker and code cleanup (#11643)

parent 3e4c7da2
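
The diff below replaces EAGLEWorkerV2's inheritance from EAGLEWorker with composition over new abstract bases. A minimal sketch of the resulting class hierarchy follows; the names mirror the diff (BaseTpWorker, BaseDraftWorker, BaseSpecWorker), but the bodies are placeholders rather than the sglang implementation:

    # Sketch only: the hierarchy this commit introduces, with placeholder bodies.
    from abc import ABC, abstractmethod


    class BaseTpWorker(ABC):
        """Shared helpers (weight updates, memory pools); subclasses provide a model runner."""

        @abstractmethod
        def forward_batch_generation(self, forward_batch):
            ...

        @property
        @abstractmethod
        def model_runner(self):
            ...

        def get_memory_pool(self):
            # Common helpers now live on the base class and delegate to model_runner.
            return (
                self.model_runner.req_to_token_pool,
                self.model_runner.token_to_kv_pool_allocator,
            )


    class BaseDraftWorker(ABC):
        """Runs the draft model: propose tokens, then extend its KV cache after verify."""

        @abstractmethod
        def draft(self, model_worker_batch):
            ...

        @abstractmethod
        def draft_extend(self):
            ...


    class BaseSpecWorker(ABC):
        """A speculative worker pairs a target worker with a draft worker."""

        @property
        @abstractmethod
        def target_worker(self) -> BaseTpWorker:
            ...

        @property
        @abstractmethod
        def draft_worker(self) -> BaseDraftWorker:
            ...

In the diff, TpModelWorker subclasses BaseTpWorker, EagleDraftWorker implements BaseDraftWorker, and EAGLEWorkerV2 implements BaseSpecWorker by composing the two instead of inheriting from EAGLEWorker.
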
...@@ -1061,38 +1061,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1061,38 +1061,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
) )
return req_pool_indices return req_pool_indices
def allocate_for_eagle_v2(self):
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
bs = self.batch_size()
assert self.spec_info.is_draft_input()
draft_input: EagleDraftInput = self.spec_info
# FIXME(lsyin): now implementation does not enable over-allocation
# Now seq_lens and allocate_lens are correct
self.maybe_wait_verify_done()
new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)
assign_req_to_token_pool[(bs,)](
self.req_pool_indices,
self.req_to_token_pool.req_to_token,
draft_input.allocate_lens,
new_allocate_lens,
out_cache_loc,
self.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
draft_input.allocate_lens = new_allocate_lens
# FIXME(lsyin): remove seq_lens_sum calculation
self.seq_lens_cpu = self.seq_lens.cpu()
self.seq_lens_sum = self.seq_lens_cpu.sum().item()
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = [] self.encoder_lens_cpu = []
self.encoder_cached = [] self.encoder_cached = []
...@@ -1522,8 +1490,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1522,8 +1490,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
bs = len(self.reqs) bs = len(self.reqs)
if self.is_v2_eagle: if self.is_v2_eagle:
# FIXME(lsyin): make this sync optional # TODO(spec-v2): all v2 spec should go through this path
self.allocate_for_eagle_v2() from sglang.srt.speculative.eagle_info import EagleDraftInput
draft_input: EagleDraftInput = self.spec_info
draft_input.prepare_for_decode(self)
if not self.spec_algorithm.is_none(): if not self.spec_algorithm.is_none():
# if spec decoding is used, the decode batch is prepared inside # if spec decoding is used, the decode batch is prepared inside
......
...@@ -215,10 +215,10 @@ class GenerationBatchResult: ...@@ -215,10 +215,10 @@ class GenerationBatchResult:
delay_sample_func: Optional[callable] = None delay_sample_func: Optional[callable] = None
future_indices: Optional[FutureIndices] = None future_indices: Optional[FutureIndices] = None
# FIXME(lsyin): maybe move to <BetterPlace> ? # FIXME(lsyin): maybe move to a better place?
# sync path: forward stream -> output processor # sync path: forward stream -> output processor
accept_lens: Optional[torch.Tensor] = None accept_lens: Optional[torch.Tensor] = None
last_batch_allocate_lens: Optional[torch.Tensor] = None allocate_lens: Optional[torch.Tensor] = None
# relay path: forward stream -> next step forward # relay path: forward stream -> next step forward
next_draft_input: Optional[EagleDraftInput] = None next_draft_input: Optional[EagleDraftInput] = None
...@@ -246,10 +246,8 @@ class GenerationBatchResult: ...@@ -246,10 +246,8 @@ class GenerationBatchResult:
if self.accept_lens is not None: if self.accept_lens is not None:
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True) self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
if self.last_batch_allocate_lens is not None: if self.allocate_lens is not None:
self.last_batch_allocate_lens = self.last_batch_allocate_lens.to( self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
"cpu", non_blocking=True
)
self.copy_done.record() self.copy_done.record()
......
...@@ -42,23 +42,21 @@ class SchedulerOutputProcessorMixin: ...@@ -42,23 +42,21 @@ class SchedulerOutputProcessorMixin:
skip_stream_req = None skip_stream_req = None
if self.is_generation: if self.is_generation:
if result.copy_done is not None:
result.copy_done.synchronize()
( (
logits_output, logits_output,
next_token_ids, next_token_ids,
extend_input_len_per_req, extend_input_len_per_req,
extend_logprob_start_len_per_req, extend_logprob_start_len_per_req,
copy_done,
) = ( ) = (
result.logits_output, result.logits_output,
result.next_token_ids, result.next_token_ids,
result.extend_input_len_per_req, result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req, result.extend_logprob_start_len_per_req,
result.copy_done,
) )
if copy_done is not None:
copy_done.synchronize()
# Move next_token_ids and logprobs to cpu # Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
if batch.return_logprob: if batch.return_logprob:
...@@ -199,57 +197,52 @@ class SchedulerOutputProcessorMixin: ...@@ -199,57 +197,52 @@ class SchedulerOutputProcessorMixin:
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def hacky_process_eagle_overlap_result( def _resolve_spec_overlap_token_ids(
self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
): ) -> List[List[int]]:
# TODO(lsyin): try use a copy stream to share SMs with forward """Resolve the padding next token ids for speculative decoding with overlap."""
# FIXME(lsyin): better organize this token free logic in eagle-overlap assert result.next_token_ids.is_cpu
last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist() assert result.accept_lens.is_cpu
accept_lens_cpu = result.accept_lens.tolist() assert result.allocate_lens.is_cpu
next_token_ids = result.next_token_ids.tolist() next_token_ids = result.next_token_ids.tolist()
accept_lens = result.accept_lens.tolist()
result.num_accepted_tokens = sum(accept_lens)
predict_tokens = [] predict_tokens = []
num_draft_tokens = self.draft_worker.speculative_num_draft_tokens stride = self.draft_worker.speculative_num_draft_tokens
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
predict_tokens.append( predict_tokens.append(
next_token_ids[ next_token_ids[i * stride : i * stride + accept_lens[i]]
i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]
]
) )
# FIXME(lsyin): move this update elsewhere
req.spec_verify_ct += 1 req.spec_verify_ct += 1
return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens return predict_tokens
def process_batch_result_decode( def process_batch_result_decode(
self: Scheduler, self: Scheduler,
batch: ScheduleBatch, batch: ScheduleBatch,
result: GenerationBatchResult, result: GenerationBatchResult,
): ):
logits_output, next_token_ids, can_run_cuda_graph, copy_done = ( if result.copy_done is not None:
result.copy_done.synchronize()
logits_output, next_token_ids, can_run_cuda_graph = (
result.logits_output, result.logits_output,
result.next_token_ids, result.next_token_ids,
result.can_run_cuda_graph, result.can_run_cuda_graph,
result.copy_done,
) )
self.num_generated_tokens += len(batch.reqs)
if copy_done is not None:
copy_done.synchronize()
if batch.spec_algorithm.is_none(): if batch.spec_algorithm.is_none():
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
if batch.return_logprob: if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist() next_token_logprobs = logits_output.next_token_logprobs.tolist()
elif batch.is_v2_eagle: elif batch.is_v2_eagle:
( next_token_ids = self._resolve_spec_overlap_token_ids(result, batch)
last_batch_allocate_lens_cpu, allocate_lens_list = result.allocate_lens.tolist()
accept_lens_cpu, accept_lens_list = result.accept_lens.tolist()
next_token_ids,
) = self.hacky_process_eagle_overlap_result(result, batch)
result.num_accepted_tokens = sum(accept_lens_cpu)
# FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result self.num_generated_tokens += len(batch.reqs)
if not self.spec_algorithm.is_none(): if not self.spec_algorithm.is_none():
self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens) self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)
...@@ -264,43 +257,38 @@ class SchedulerOutputProcessorMixin: ...@@ -264,43 +257,38 @@ class SchedulerOutputProcessorMixin:
continue continue
if self.enable_overlap and req.finished(): if self.enable_overlap and req.finished():
indices_to_free = None
if self.page_size == 1: if self.page_size == 1:
if batch.spec_algorithm.is_eagle(): if batch.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker_v2 import ( from sglang.srt.speculative.eagle_info import EagleDraftInput
free_spec_dec_tokens_page_size_1,
)
free_spec_dec_tokens_page_size_1( end_p = allocate_lens_list[i]
self.req_to_token_pool, start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
self.token_to_kv_pool_allocator, indices_to_free = self.req_to_token_pool.req_to_token[
req, req.req_pool_idx
last_batch_allocate_lens_cpu[i], ][start_p:end_p]
None,
)
else: else:
# Free the one extra delayed token # Free the one extra delayed token
self.token_to_kv_pool_allocator.free( indices_to_free = batch.out_cache_loc[i : i + 1]
batch.out_cache_loc[i : i + 1]
)
else: else:
if batch.spec_algorithm.is_eagle(): if batch.spec_algorithm.is_eagle():
# TODO(lsyin): support eagle with page_size > 1 # TODO(spec-v2): support eagle with page_size > 1
raise NotImplementedError() raise NotImplementedError()
else: else:
if ( if (
len(req.origin_input_ids) + len(req.output_ids) - 1 len(req.origin_input_ids) + len(req.output_ids) - 1
) % self.page_size == 0: ) % self.page_size == 0:
# Only free when the extra token is in a new page # Only free when the extra token is in a new page
self.token_to_kv_pool_allocator.free( indices_to_free = batch.out_cache_loc[i : i + 1]
batch.out_cache_loc[i : i + 1]
) if indices_to_free is not None:
self.token_to_kv_pool_allocator.free(indices_to_free)
continue continue
if batch.spec_algorithm.is_none(): if batch.spec_algorithm.is_none():
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
elif batch.is_v2_eagle: elif batch.is_v2_eagle:
# FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding # Only v2 eagle's output_ids are updated here.
# !!!unify the logic here!!!
req.output_ids.extend(next_token_id) req.output_ids.extend(next_token_id)
req.check_finished() req.check_finished()
...@@ -308,24 +296,13 @@ class SchedulerOutputProcessorMixin: ...@@ -308,24 +296,13 @@ class SchedulerOutputProcessorMixin:
if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend(): if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
# FIXME(lsyin): fix the messy logic here # FIXME(lsyin): fix the messy logic here
# 1) when not overlap (v2 impl), we free the extra tokens in the req # 1) when not overlap (v2 impl), we free the extra tokens in the req
# 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch # 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
from sglang.srt.speculative.eagle_worker_v2 import ( start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
free_spec_dec_tokens_page_size_1, end_p = allocate_lens_list[i]
) indices_to_free = self.req_to_token_pool.req_to_token[
req.req_pool_idx
new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1 ][start_p:end_p]
# FIXME(lsyin): remove this assert self.token_to_kv_pool_allocator.free(indices_to_free)
assert new_seq_len == int(
batch.seq_lens_cpu[i] + accept_lens_cpu[i]
), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}"
free_spec_dec_tokens_page_size_1(
self.req_to_token_pool,
self.token_to_kv_pool_allocator,
req,
last_batch_allocate_lens_cpu[i],
new_seq_len,
)
if self.server_args.disaggregation_decode_enable_offload_kvcache: if self.server_args.disaggregation_decode_enable_offload_kvcache:
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
...@@ -54,7 +55,140 @@ if TYPE_CHECKING: ...@@ -54,7 +55,140 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TpModelWorker: class BaseTpWorker(ABC):
@abstractmethod
def forward_batch_generation(self, forward_batch: ForwardBatch):
pass
@property
@abstractmethod
def model_runner(self) -> ModelRunner:
pass
@property
def sliding_window_size(self) -> Optional[int]:
return self.model_runner.sliding_window_size
@property
def is_hybrid(self) -> bool:
return self.model_runner.is_hybrid is not None
def get_tokens_per_layer_info(self):
return (
self.model_runner.full_max_total_num_tokens,
self.model_runner.swa_max_total_num_tokens,
)
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)
def get_tp_group(self):
return self.model_runner.tp_group
def get_attention_tp_group(self):
return self.model_runner.attention_tp_group
def get_attention_tp_cpu_group(self):
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool_allocator,
)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format
)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.model_runner.init_weights_update_group(
recv_req.master_address,
recv_req.master_port,
recv_req.rank_offset,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
return success, message
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
success, message = self.model_runner.destroy_weights_update_group(
recv_req.group_name,
)
return success, message
def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
success, message = (
self.model_runner.init_weights_send_group_for_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_rank,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
)
return success, message
def send_weights_to_remote_instance(
self, recv_req: SendWeightsToRemoteInstanceReqInput
):
success, message = self.model_runner.send_weights_to_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_name,
)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.model_runner.update_weights_from_distributed(
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
monkey_patch_torch_reductions()
success, message = self.model_runner.update_weights_from_tensor(
named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors[self.tp_rank]
),
load_format=recv_req.load_format,
)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
)
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output, _ = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings
class TpModelWorker(BaseTpWorker):
"""A tensor parallel model worker.""" """A tensor parallel model worker."""
def __init__( def __init__(
...@@ -92,7 +226,7 @@ class TpModelWorker: ...@@ -92,7 +226,7 @@ class TpModelWorker:
is_draft_model=is_draft_worker, is_draft_model=is_draft_worker,
) )
self.model_runner = ModelRunner( self._model_runner = ModelRunner(
model_config=self.model_config, model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static, mem_fraction_static=server_args.mem_fraction_static,
gpu_id=gpu_id, gpu_id=gpu_id,
...@@ -171,6 +305,10 @@ class TpModelWorker: ...@@ -171,6 +305,10 @@ class TpModelWorker:
self.enable_overlap = not server_args.disable_overlap_schedule self.enable_overlap = not server_args.disable_overlap_schedule
self.hicache_layer_transfer_counter = None self.hicache_layer_transfer_counter = None
@property
def model_runner(self) -> ModelRunner:
return self._model_runner
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter): def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
self.hicache_layer_transfer_counter = counter self.hicache_layer_transfer_counter = counter
...@@ -193,38 +331,6 @@ class TpModelWorker: ...@@ -193,38 +331,6 @@ class TpModelWorker:
self.model_runner.token_to_kv_pool.size, self.model_runner.token_to_kv_pool.size,
) )
@property
def sliding_window_size(self) -> Optional[int]:
return self.model_runner.sliding_window_size
@property
def is_hybrid(self) -> bool:
return self.model_runner.is_hybrid is not None
def get_tokens_per_layer_info(self):
return (
self.model_runner.full_max_total_num_tokens,
self.model_runner.swa_max_total_num_tokens,
)
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)
def get_tp_group(self):
return self.model_runner.tp_group
def get_attention_tp_group(self):
return self.model_runner.attention_tp_group
def get_attention_tp_cpu_group(self):
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool_allocator,
)
def forward_batch_generation( def forward_batch_generation(
self, self,
model_worker_batch: ModelWorkerBatch, model_worker_batch: ModelWorkerBatch,
...@@ -313,93 +419,3 @@ class TpModelWorker: ...@@ -313,93 +419,3 @@ class TpModelWorker:
pp_hidden_states_proxy_tensors=pp_proxy_tensors, pp_hidden_states_proxy_tensors=pp_proxy_tensors,
can_run_cuda_graph=can_run_cuda_graph, can_run_cuda_graph=can_run_cuda_graph,
) )
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output, _ = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format
)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.model_runner.init_weights_update_group(
recv_req.master_address,
recv_req.master_port,
recv_req.rank_offset,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
return success, message
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
success, message = self.model_runner.destroy_weights_update_group(
recv_req.group_name,
)
return success, message
def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
success, message = (
self.model_runner.init_weights_send_group_for_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_rank,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
)
return success, message
def send_weights_to_remote_instance(
self, recv_req: SendWeightsToRemoteInstanceReqInput
):
success, message = self.model_runner.send_weights_to_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_name,
)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.model_runner.update_weights_from_distributed(
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
monkey_patch_torch_reductions()
success, message = self.model_runner.update_weights_from_tensor(
named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors[self.tp_rank]
),
load_format=recv_req.load_format,
)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
)
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
...@@ -53,7 +53,6 @@ from sglang.srt.utils import ( ...@@ -53,7 +53,6 @@ from sglang.srt.utils import (
empty_context, empty_context,
get_available_gpu_memory, get_available_gpu_memory,
get_bool_env_var, get_bool_env_var,
get_device_memory_capacity,
is_hip, is_hip,
log_info_on_rank0, log_info_on_rank0,
require_attn_tp_gather, require_attn_tp_gather,
...@@ -274,7 +273,6 @@ class CudaGraphRunner: ...@@ -274,7 +273,6 @@ class CudaGraphRunner:
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
) )
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0 self.encoder_len_fill_value = 0
self.seq_lens_cpu = torch.full( self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
......
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sglang.srt.managers.tp_worker import TpModelWorker
class BaseDraftWorker(ABC):
@abstractmethod
def draft():
pass
@abstractmethod
def draft_extend():
pass
class BaseSpecWorker(ABC):
@property
@abstractmethod
def target_worker(self) -> TpModelWorker:
pass
@property
@abstractmethod
def draft_worker(self) -> BaseDraftWorker:
pass
...@@ -40,7 +40,11 @@ class EAGLEDraftCudaGraphRunner: ...@@ -40,7 +40,11 @@ class EAGLEDraftCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker): def __init__(self, eagle_worker: EAGLEWorker):
# Parse args # Parse args
self.eagle_worker = eagle_worker self.eagle_worker = eagle_worker
self.model_runner = model_runner = eagle_worker.model_runner if not hasattr(eagle_worker, "model_runner"):
# V2: EagleDraftWorker
self.model_runner = model_runner = eagle_worker.draft_runner
else:
self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {} self.graphs = {}
self.output_buffers = {} self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.enable_torch_compile = model_runner.server_args.enable_torch_compile
......
...@@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner: ...@@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker): def __init__(self, eagle_worker: EAGLEWorker):
# Parse args # Parse args
self.eagle_worker = eagle_worker self.eagle_worker = eagle_worker
self.model_runner = model_runner = eagle_worker.model_runner if not hasattr(eagle_worker, "model_runner"):
# V2: EagleDraftWorker
self.model_runner = model_runner = eagle_worker.draft_runner
else:
self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {} self.graphs = {}
self.output_buffers = {} self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.enable_torch_compile = model_runner.server_args.enable_torch_compile
...@@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner: ...@@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner:
output_cache_loc_backup = forward_batch.out_cache_loc output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states hidden_states_backup = forward_batch.spec_info.hidden_states
ret = self.eagle_worker.draft_model_runner.model.forward( ret = self.model_runner.model.forward(
forward_batch.input_ids, forward_batch.input_ids,
forward_batch.positions, forward_batch.positions,
forward_batch, forward_batch,
......
...@@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): ...@@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
@dataclass @dataclass
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
# Constant: alloc length per decode step
ALLOC_LEN_PER_DECODE: ClassVar[int] = None
# The inputs for decode # The inputs for decode
# shape: (b, topk) # shape: (b, topk)
topk_p: torch.Tensor = None topk_p: torch.Tensor = None
...@@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): ...@@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
new_seq_lens: Optional[torch.Tensor] = None new_seq_lens: Optional[torch.Tensor] = None
verify_done: Optional[torch.cuda.Event] = None verify_done: Optional[torch.cuda.Event] = None
# FIXME(lsyin): remove this hack
ALLOC_LEN_PER_DECODE: ClassVar[int] = None
def __post_init__(self): def __post_init__(self):
super().__init__(SpecInputType.EAGLE_DRAFT) super().__init__(SpecInputType.EAGLE_DRAFT)
......
...@@ -9,7 +9,8 @@ import triton ...@@ -9,7 +9,8 @@ import triton
import triton.language as tl import triton.language as tl
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
from sglang.srt.mem_cache.common import alloc_token_slots
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode, CaptureHiddenMode,
...@@ -72,6 +73,34 @@ def assign_draft_cache_locs_page_size_1( ...@@ -72,6 +73,34 @@ def assign_draft_cache_locs_page_size_1(
@dataclass @dataclass
class EagleDraftInputV2Mixin: class EagleDraftInputV2Mixin:
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
bs = batch.batch_size()
# TODO(lsyin): implement over-allocation
# Now seq_lens and allocate_lens are correct
batch.maybe_wait_verify_done()
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
self.allocate_lens,
new_allocate_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
self.allocate_lens = new_allocate_lens
# FIXME(lsyin): make this sync optional
batch.seq_lens_cpu = batch.seq_lens.cpu()
batch.seq_lens_sum = batch.seq_lens_cpu.sum().item()
def prepare_for_v2_draft( def prepare_for_v2_draft(
self: EagleDraftInput, self: EagleDraftInput,
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
......
import logging import logging
import os
import time import time
from contextlib import contextmanager
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from huggingface_hub import snapshot_download
from sglang.srt.distributed import ( from sglang.srt.distributed import get_tp_group
GroupCoordinator,
get_tp_group,
patch_tensor_parallel_group,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
...@@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import ( ...@@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import (
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_utils import ( from sglang.srt.speculative.spec_utils import (
assign_draft_cache_locs, assign_draft_cache_locs,
detect_nan,
draft_tp_context,
fast_topk, fast_topk,
generate_token_bitmask, generate_token_bitmask,
load_token_map,
select_top_k_tokens, select_top_k_tokens,
) )
from sglang.srt.utils import ( from sglang.srt.utils import (
empty_context, empty_context,
get_available_gpu_memory, get_available_gpu_memory,
get_bool_env_var, get_bool_env_var,
is_blackwell,
is_cuda, is_cuda,
next_power_of_2, next_power_of_2,
) )
...@@ -67,14 +62,6 @@ logger = logging.getLogger(__name__) ...@@ -67,14 +62,6 @@ logger = logging.getLogger(__name__)
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB") SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with patch_tensor_parallel_group(tp_group):
yield
class EAGLEWorker(TpModelWorker): class EAGLEWorker(TpModelWorker):
def __init__( def __init__(
...@@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker): ...@@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
self.speculative_algorithm = SpeculativeAlgorithm.from_string( self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm server_args.speculative_algorithm
) )
self.padded_static_len = -1
# Override the context length of the draft model to be the same as the target model. # Override the context length of the draft model to be the same as the target model.
server_args.context_length = target_worker.model_runner.model_config.context_len server_args.context_length = target_worker.model_runner.model_config.context_len
...@@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker): ...@@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker):
logits_output, _ = self.draft_model_runner.forward( logits_output, _ = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True forward_batch, skip_attn_backend_init=True
) )
self._detect_nan_if_needed(logits_output) if self.server_args.enable_nan_detection:
detect_nan(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1) probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None: if self.hot_token_id is not None:
...@@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker): ...@@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker):
# and will be applied to produce wrong results # and will be applied to produce wrong results
batch.sampling_info.vocab_mask = None batch.sampling_info.vocab_mask = None
self._detect_nan_if_needed(logits_output) if self.enable_nan_detection:
detect_nan(logits_output)
spec_info.hidden_states = logits_output.hidden_states spec_info.hidden_states = logits_output.hidden_states
res: EagleVerifyOutput = spec_info.verify( res: EagleVerifyOutput = spec_info.verify(
batch, batch,
...@@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker): ...@@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker):
) )
forward_batch.return_logprob = False forward_batch.return_logprob = False
logits_output, _ = self.draft_model_runner.forward(forward_batch) logits_output, _ = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output) if self.enable_nan_detection:
detect_nan(logits_output)
assert isinstance(forward_batch.spec_info, EagleDraftInput) assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info is batch.spec_info assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info) self.capture_for_decode(logits_output, forward_batch.spec_info)
...@@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker): ...@@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker):
) )
self.capture_for_decode(logits_output, forward_batch.spec_info) self.capture_for_decode(logits_output, forward_batch.spec_info)
self._detect_nan_if_needed(logits_output) if self.enable_nan_detection:
detect_nan(logits_output)
# Restore backup. # Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode` # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
...@@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker): ...@@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker):
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1) draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
draft_input.hidden_states = logits_output.hidden_states draft_input.hidden_states = logits_output.hidden_states
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
if self.enable_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.error("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")
def load_token_map(token_map_path: str) -> List[int]:
if not os.path.exists(token_map_path):
cache_dir = snapshot_download(
os.path.dirname(token_map_path),
ignore_patterns=["*.bin", "*.safetensors"],
)
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
hot_token_id = torch.load(token_map_path, weights_only=True)
return torch.tensor(hot_token_id, dtype=torch.int64)
@torch.compile(dynamic=True) @torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1( def get_last_loc_large_page_size_top_k_1(
......
import contextlib import contextlib
import logging import logging
from typing import List, Optional import time
from typing import List, Optional, Tuple
import torch import torch
from torch.cuda import Stream as CudaStream from torch.cuda import Stream as CudaStream
from sglang.srt.environ import envs from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.base_spec_worker import BaseDraftWorker, BaseSpecWorker
from sglang.srt.speculative.draft_utils import DraftBackendFactory
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.eagle_info_v2 import ( from sglang.srt.speculative.eagle_info_v2 import (
assign_extend_cache_locs, assign_extend_cache_locs,
...@@ -22,69 +28,214 @@ from sglang.srt.speculative.eagle_info_v2 import ( ...@@ -22,69 +28,214 @@ from sglang.srt.speculative.eagle_info_v2 import (
select_top_k_tokens_tmp, select_top_k_tokens_tmp,
) )
from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient
from sglang.srt.speculative.eagle_worker import EAGLEWorker from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils.common import fast_topk, next_power_of_2 from sglang.srt.speculative.spec_utils import (
detect_nan,
draft_tp_context,
load_token_map,
)
from sglang.srt.utils.common import (
empty_context,
fast_topk,
get_available_gpu_memory,
next_power_of_2,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EAGLEWorkerV2(EAGLEWorker): def _get_plan_stream(
device: str,
) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]:
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
plan_stream: CudaStream = torch.get_device_module(device).Stream()
plan_stream_ctx = torch.cuda.stream(plan_stream)
return plan_stream, plan_stream_ctx
else:
return None, contextlib.nullcontext()
class EagleDraftWorker(BaseDraftWorker):
def __init__( def __init__(
self, self,
server_args: ServerArgs, server_args: ServerArgs,
gpu_id: int, gpu_id: int,
tp_rank: int, tp_rank: int,
dp_rank: Optional[int], dp_rank: int,
moe_ep_rank: int, moe_ep_rank: int,
nccl_port: int, nccl_port: int,
target_worker: TpModelWorker, target_worker: TpModelWorker,
): ):
super().__init__( # copy args
server_args, self.server_args = server_args
gpu_id, self.gpu_id = gpu_id
tp_rank, self.tp_rank = tp_rank
dp_rank, self.dp_rank = dp_rank
moe_ep_rank, self.moe_ep_rank = moe_ep_rank
nccl_port, self.nccl_port = nccl_port
target_worker, self.target_worker = target_worker
# Args for easy access
self.device = server_args.device
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
) )
# Set constant
EagleDraftInput.ALLOC_LEN_PER_DECODE = max( EagleDraftInput.ALLOC_LEN_PER_DECODE = max(
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
) )
# Do not capture cuda graph in `TpModelWorker` init,
# will capture later with init_cuda_graphs()
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True
# Share the allocator with a target worker.
# Draft and target worker own their own KV cache pools.
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
with empty_context():
# Init draft worker
self.draft_worker = TpModelWorker(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
pp_rank=0, # FIXME
dp_rank=dp_rank,
moe_ep_rank=moe_ep_rank,
nccl_port=nccl_port,
is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
# Alias for better readability
self.draft_runner = self.draft_worker.model_runner
self.init_token_map()
self.init_lm_head()
# Init attention backend and cuda graphs
self.draft_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
self.draft_tp_context = (
draft_tp_context if server_args.enable_dp_attention else empty_context
)
with self.draft_tp_context(self.draft_runner.tp_group):
self.init_attention_backend()
self.init_cuda_graphs()
self.tree_mask_mode = TreeMaskMode.FULL_MASK self.tree_mask_mode = TreeMaskMode.FULL_MASK
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get(): self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream) def init_token_map(self):
# Load hot token ids
if self.speculative_algorithm.is_eagle3():
if self.server_args.speculative_token_map is not None:
logger.warning(
"Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
)
self.hot_token_id = None
elif self.server_args.speculative_token_map is not None:
self.hot_token_id = load_token_map(self.server_args.speculative_token_map)
self.server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
)
else: else:
self.plan_stream = None self.hot_token_id = None
self.plan_stream_ctx = contextlib.nullcontext()
def init_lm_head(self):
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if self.speculative_algorithm.is_eagle3():
# most cases EAGLE3 models don't share lm_head
# but some models (e.g. nvidia/gpt-oss-120b-Eagle3) shares
if (
hasattr(self.draft_runner.model, "load_lm_head_from_target")
and self.draft_runner.model.load_lm_head_from_target
):
self.draft_runner.model.set_embed_and_head(embed, head)
else:
self.draft_runner.model.set_embed(embed)
# grab hot token ids
if self.draft_runner.model.hot_token_id is not None:
self.hot_token_id = self.draft_runner.model.hot_token_id.to(
embed.device
)
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
if model_worker_batch.forward_mode.is_decode():
# FIXME(lsyin): why shall we use spec_info for both draft and verify?
draft_input: EagleDraftInput = model_worker_batch.spec_info
assert draft_input.is_draft_input()
verify_input: EagleVerifyInput = self.draft(model_worker_batch)
assert verify_input.is_verify_input()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
return batch_output
else: else:
# Target prefill if self.hot_token_id is not None:
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL head = head.clone()
batch_output = self.target_worker.forward_batch_generation( self.hot_token_id = self.hot_token_id.to(head.device)
model_worker_batch head.data = head.data[self.hot_token_id]
# Share the embedding and lm_head
self.draft_runner.model.set_embed_and_head(embed, head)
def init_attention_backend(self):
# Create multi-step attn backends and cuda graph runners
self.has_prefill_wrapper_verify = False
self.draft_extend_attn_backend = None
draft_backend_factory = DraftBackendFactory(
self.server_args,
self.draft_runner,
self.topk,
self.speculative_num_steps,
)
# Initialize decode attention backend
self.draft_attn_backend = draft_backend_factory.create_decode_backend()
# Initialize draft extend attention backend (respects speculative_attention_mode setting)
self.draft_extend_attn_backend = (
draft_backend_factory.create_draft_extend_backend()
)
self.draft_runner.draft_attn_backend = self.draft_attn_backend
self.tree_mask_mode = TreeMaskMode.FULL_MASK
def init_cuda_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None
self.cuda_graph_runner_for_draft_extend = None
if self.server_args.disable_cuda_graph:
return
# Capture draft
if self.speculative_num_steps > 1:
tic = time.perf_counter()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
) )
# Draft prefill # Capture extend
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST if self.draft_extend_attn_backend:
batch_output.next_draft_input = self.forward_draft_extend( tic = time.perf_counter()
model_worker_batch, before_mem = get_available_gpu_memory(self.device, self.gpu_id)
batch_output.logits_output.hidden_states, logger.info(
batch_output.next_token_ids, f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
self
)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
) )
return batch_output
def draft(self, model_worker_batch: ModelWorkerBatch): def draft(self, model_worker_batch: ModelWorkerBatch):
draft_input: EagleDraftInput = model_worker_batch.spec_info draft_input: EagleDraftInput = model_worker_batch.spec_info
...@@ -92,7 +243,7 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -92,7 +243,7 @@ class EAGLEWorkerV2(EAGLEWorker):
self.req_to_token_pool, self.req_to_token_pool,
model_worker_batch, model_worker_batch,
self.cuda_graph_runner, self.cuda_graph_runner,
self.draft_model_runner, self.draft_runner,
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
) )
...@@ -201,10 +352,11 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -201,10 +352,11 @@ class EAGLEWorkerV2(EAGLEWorker):
spec_info.hidden_states = hidden_states spec_info.hidden_states = hidden_states
# Run forward # Run forward
logits_output = self.draft_model_runner.model.forward( logits_output = self.draft_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch forward_batch.input_ids, forward_batch.positions, forward_batch
) )
self._detect_nan_if_needed(logits_output) if self.server_args.enable_nan_detection:
detect_nan(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1) probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None: if self.hot_token_id is not None:
...@@ -233,10 +385,190 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -233,10 +385,190 @@ class EAGLEWorkerV2(EAGLEWorker):
return parent_list, top_scores_index, draft_tokens return parent_list, top_scores_index, draft_tokens
def draft_extend(self):
pass
def _draft_extend_for_prefill(
self,
batch: ModelWorkerBatch,
target_hidden_states: torch.Tensor,
next_token_ids: torch.Tensor,
):
"""
Run draft model extend to correctly fill the KV cache.
Args:
batch: The batch to run.
target_hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
# Construct input_ids
pt = 0
for i, extend_len in enumerate(batch.extend_seq_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], next_token_ids[i].reshape(1))
)
pt += extend_len
# Construct spec_info
next_draft_input = EagleDraftInput(
hidden_states=target_hidden_states,
verified_id=next_token_ids,
new_seq_lens=batch.seq_lens,
allocate_lens=batch.seq_lens,
)
batch.spec_info = next_draft_input
# Run forward
forward_batch = ForwardBatch.init_new(batch, self.draft_runner)
logits_output, _ = self.draft_runner.forward(forward_batch)
# Update spec_info for the next draft step
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
probs, self.topk, dim=-1
)
next_draft_input.hidden_states = logits_output.hidden_states
return next_draft_input
def _draft_extend_for_decode(
self, batch: ModelWorkerBatch, batch_result: GenerationBatchResult
):
# Batch 2: Draft extend
draft_input = EagleDraftInput(
hidden_states=batch_result.logits_output.hidden_states,
)
select_index = (
torch.arange(len(batch.seq_lens), device=self.device)
* self.speculative_num_draft_tokens
+ batch_result.accept_lens
- 1
)
# Prepare for draft extend in a separate stream
with self.plan_stream_ctx:
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
batch,
batch_result.next_token_ids,
self.speculative_num_draft_tokens,
self.draft_runner,
)
if self.plan_stream:
torch.cuda.current_stream().wait_stream(self.plan_stream)
# Run draft extend batch in the main compute stream
draft_logits_output = self.draft_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
# Reorganize the spec info for the next batch
draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
select_index
]
draft_logits_output.hidden_states = draft_logits_output.hidden_states[
select_index
]
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
ret_hidden_states = draft_logits_output.hidden_states
# Construct the return values
next_draft_input = batch_result.next_draft_input
(
next_draft_input.topk_p,
next_draft_input.topk_index,
next_draft_input.hidden_states,
) = (
ret_topk_p,
ret_topk_index,
ret_hidden_states,
)
class EAGLEWorkerV2(BaseSpecWorker):
def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
moe_ep_rank: int,
nccl_port: int,
target_worker: TpModelWorker,
):
# Parse arguments
self.server_args = server_args
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
self.enable_nan_detection = server_args.enable_nan_detection
self.gpu_id = gpu_id
self.device = server_args.device
self._target_worker = target_worker
self.page_size = server_args.page_size
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
# Override the context length of the draft model to be the same as the target model.
server_args.context_length = target_worker.model_runner.model_config.context_len
self._draft_worker = EagleDraftWorker(
server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker
)
# Some dummy tensors
self.num_new_pages_per_topk = torch.empty(
(), dtype=torch.int64, device=self.device
)
self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
@property
def target_worker(self):
return self._target_worker
@property
def draft_worker(self):
return self._draft_worker
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
if model_worker_batch.forward_mode.is_decode():
draft_input: EagleDraftInput = model_worker_batch.spec_info
assert draft_input.is_draft_input()
verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
assert verify_input.is_verify_input()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
return batch_output
else:
# Target prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
batch_output = self.target_worker.forward_batch_generation(
model_worker_batch
)
# Draft prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill(
model_worker_batch,
batch_output.logits_output.hidden_states,
batch_output.next_token_ids,
)
return batch_output
def verify( def verify(
self, self,
batch: ModelWorkerBatch, batch: ModelWorkerBatch,
pre_draft_allocate_lens: torch.Tensor, cur_allocate_lens: torch.Tensor,
): ):
# Since batch.seq_lens is allocated in another stream, we need # Since batch.seq_lens is allocated in another stream, we need
# record_stream() to prevent pytorch gc and reuse the gpu memory # record_stream() to prevent pytorch gc and reuse the gpu memory
...@@ -284,7 +616,8 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -284,7 +616,8 @@ class EAGLEWorkerV2(EAGLEWorker):
logits_output = forward_batch_output.logits_output logits_output = forward_batch_output.logits_output
# Sample # Sample
self._detect_nan_if_needed(logits_output) if self.enable_nan_detection:
detect_nan(logits_output)
( (
predict, predict,
accept_length, accept_length,
...@@ -303,53 +636,11 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -303,53 +636,11 @@ class EAGLEWorkerV2(EAGLEWorker):
self.speculative_num_draft_tokens, self.speculative_num_draft_tokens,
) )
# Batch 2: Draft extend # Construct the next draft input
draft_input = EagleDraftInput(
hidden_states=logits_output.hidden_states,
)
select_index = (
torch.arange(len(batch.seq_lens), device=self.device)
* self.speculative_num_draft_tokens
+ accept_length
- 1
)
# Prepare for draft extend in a separate stream
with self.plan_stream_ctx:
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
batch,
predict,
self.speculative_num_draft_tokens,
self.draft_model_runner,
)
if self.plan_stream:
torch.cuda.current_stream().wait_stream(self.plan_stream)
# Run draft extend batch in the main compute stream
draft_logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
# Reorganize the spec info for the next batch
draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
select_index
]
draft_logits_output.hidden_states = draft_logits_output.hidden_states[
select_index
]
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
ret_hidden_states = draft_logits_output.hidden_states
# Construct the return values
next_draft_input = EagleDraftInput( next_draft_input = EagleDraftInput(
topk_p=ret_topk_p,
topk_index=ret_topk_index,
hidden_states=ret_hidden_states,
verified_id=verified_id, verified_id=verified_id,
new_seq_lens=new_seq_lens, new_seq_lens=new_seq_lens,
allocate_lens=pre_draft_allocate_lens, allocate_lens=cur_allocate_lens,
verify_done=verify_done, verify_done=verify_done,
) )
...@@ -359,52 +650,8 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -359,52 +650,8 @@ class EAGLEWorkerV2(EAGLEWorker):
can_run_cuda_graph=can_run_cuda_graph, can_run_cuda_graph=can_run_cuda_graph,
next_draft_input=next_draft_input, next_draft_input=next_draft_input,
accept_lens=accept_length, accept_lens=accept_length,
last_batch_allocate_lens=pre_draft_allocate_lens, allocate_lens=cur_allocate_lens,
)
def forward_draft_extend(
self,
batch: ModelWorkerBatch,
target_hidden_states: torch.Tensor,
next_token_ids: torch.Tensor,
):
"""
Run draft model extend to correctly fill the KV cache.
Args:
batch: The batch to run.
target_hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
# Construct input_ids
pt = 0
for i, extend_len in enumerate(batch.extend_seq_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], next_token_ids[i].reshape(1))
)
pt += extend_len
# Construct spec_info
next_draft_input = EagleDraftInput(
hidden_states=target_hidden_states,
verified_id=next_token_ids,
new_seq_lens=batch.seq_lens,
allocate_lens=batch.seq_lens,
)
batch.spec_info = next_draft_input
# Run forward
forward_batch = ForwardBatch.init_new(batch, self.draft_model_runner)
logits_output, _ = self.draft_model_runner.forward(forward_batch)
# Update spec_info for the next draft step
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
probs, self.topk, dim=-1
) )
next_draft_input.hidden_states = logits_output.hidden_states
return next_draft_input
def move_accepted_tokens_to_target_kvcache( def move_accepted_tokens_to_target_kvcache(
self, self,
...@@ -449,32 +696,3 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -449,32 +696,3 @@ class EAGLEWorkerV2(EAGLEWorker):
self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache( self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
tgt_cache_loc, accepted_out_cache_loc tgt_cache_loc, accepted_out_cache_loc
) )
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
if self.enable_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.error("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")
def free_spec_dec_tokens_page_size_1(
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
req: Req,
allocate_len: int,
new_seq_len: int,
):
# FIXME(lsyin): move this function elsewhere
# free extra allocated tokens
if new_seq_len is None:
# True only for overlap eagle and the current batch is decode. This seq will be part of the decode, so the final iteration's allocation is not used (i.e. this case).
start_len = allocate_len - EagleDraftInput.ALLOC_LEN_PER_DECODE
else:
# True for 1) non-overlap; 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration, so start_lens is passed in.
start_len = new_seq_len
indices_to_free = req_to_token_pool.req_to_token[req.req_pool_idx][
start_len:allocate_len
]
token_to_kv_pool_allocator.free(indices_to_free)
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
import time import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from huggingface_hub import snapshot_download
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.distributed.parallel_state import (
GroupCoordinator,
patch_tensor_parallel_group,
)
from sglang.srt.environ import envs from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.schedule_batch import Req
from sglang.srt.utils import is_cuda, is_hip from sglang.srt.utils import is_cuda, is_hip
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.speculative.eagle_info import EagleVerifyInput
if is_cuda(): if is_cuda():
from sgl_kernel import fast_topk from sgl_kernel import fast_topk
elif is_hip(): elif is_hip():
from sgl_kernel import fast_topk from sgl_kernel import fast_topk
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_info import EagleVerifyInput
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -603,3 +615,29 @@ def generate_token_bitmask( ...@@ -603,3 +615,29 @@ def generate_token_bitmask(
verify_input.grammar = grammar verify_input.grammar = grammar
return allocate_token_bitmask return allocate_token_bitmask
def load_token_map(token_map_path: str) -> List[int]:
if not os.path.exists(token_map_path):
cache_dir = snapshot_download(
os.path.dirname(token_map_path),
ignore_patterns=["*.bin", "*.safetensors"],
)
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
hot_token_id = torch.load(token_map_path, weights_only=True)
return torch.tensor(hot_token_id, dtype=torch.int64)
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with patch_tensor_parallel_group(tp_group):
yield
def detect_nan(logits_output: LogitsProcessorOutput):
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.error("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")
import logging import logging
from contextlib import contextmanager
from typing import Optional from typing import Optional
import torch import torch
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map from sglang.srt.speculative.eagle_worker import EAGLEWorker
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
if is_cuda(): if is_cuda():
...@@ -18,14 +17,6 @@ logger = logging.getLogger(__name__) ...@@ -18,14 +17,6 @@ logger = logging.getLogger(__name__)
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB") SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with patch_tensor_parallel_group(tp_group):
yield
class StandaloneWorker(EAGLEWorker): class StandaloneWorker(EAGLEWorker):
def __init__( def __init__(
...@@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker): ...@@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker):
self.speculative_algorithm = SpeculativeAlgorithm.from_string( self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm server_args.speculative_algorithm
) )
self.padded_static_len = -1
# Override the context length of the draft model to be the same as the target model. # Override the context length of the draft model to be the same as the target model.
server_args.context_length = target_worker.model_runner.model_config.context_len server_args.context_length = target_worker.model_runner.model_config.context_len
......
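
A toy illustration (hypothetical numbers; plain Python lists stand in for the torch tensors and the req_to_token table) of the two KV-slot freeing paths that process_batch_result_decode now performs inline for page_size == 1 EAGLE batches:

    # Sketch only, not part of the commit's code.
    ALLOC_LEN_PER_DECODE = 6                    # max(num_steps * topk, num_draft_tokens)

    req_to_token_row = list(range(100, 130))    # KV-cache slot indices for one request
    seq_len_before_verify = 10                  # stands in for batch.seq_lens_cpu[i]
    accept_len = 3                              # stands in for result.accept_lens[i]
    allocate_len = seq_len_before_verify + ALLOC_LEN_PER_DECODE  # allocate_lens[i]

    # Case 1: the request finished in a decode batch. The per-step over-allocation is
    # unused by any future iteration, so free the last ALLOC_LEN_PER_DECODE slots.
    end_p = allocate_len
    start_p = end_p - ALLOC_LEN_PER_DECODE
    decode_free = req_to_token_row[start_p:end_p]

    # Case 2: overlap scheduling and the next batch is a prefill, so this request runs
    # no further draft iteration. Keep the accepted tokens, free the rest of the allocation.
    start_p = seq_len_before_verify + accept_len
    end_p = allocate_len
    extend_free = req_to_token_row[start_p:end_p]

    print(decode_free)   # [110, 111, 112, 113, 114, 115]
    print(extend_free)   # [113, 114, 115]
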