Unverified Commit cde5a6e3 authored by Liangsheng Yin, committed by GitHub

Abstraction for spec worker and code cleanup (#11643)

parent 3e4c7da2
......@@ -1061,38 +1061,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
return req_pool_indices
def allocate_for_eagle_v2(self):
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
bs = self.batch_size()
assert self.spec_info.is_draft_input()
draft_input: EagleDraftInput = self.spec_info
# FIXME(lsyin): the current implementation does not enable over-allocation
# Now seq_lens and allocate_lens are correct
self.maybe_wait_verify_done()
new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)
assign_req_to_token_pool[(bs,)](
self.req_pool_indices,
self.req_to_token_pool.req_to_token,
draft_input.allocate_lens,
new_allocate_lens,
out_cache_loc,
self.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
draft_input.allocate_lens = new_allocate_lens
# FIXME(lsyin): remove seq_lens_sum calculation
self.seq_lens_cpu = self.seq_lens.cpu()
self.seq_lens_sum = self.seq_lens_cpu.sum().item()
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
self.encoder_cached = []
......@@ -1522,8 +1490,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
bs = len(self.reqs)
if self.is_v2_eagle:
# FIXME(lsyin): make this sync optional
self.allocate_for_eagle_v2()
# TODO(spec-v2): all v2 spec workers should go through this path
from sglang.srt.speculative.eagle_info import EagleDraftInput
draft_input: EagleDraftInput = self.spec_info
draft_input.prepare_for_decode(self)
if not self.spec_algorithm.is_none():
# if spec decoding is used, the decode batch is prepared inside
......
......@@ -215,10 +215,10 @@ class GenerationBatchResult:
delay_sample_func: Optional[callable] = None
future_indices: Optional[FutureIndices] = None
# FIXME(lsyin): maybe move to <BetterPlace> ?
# FIXME(lsyin): maybe move to a better place?
# sync path: forward stream -> output processor
accept_lens: Optional[torch.Tensor] = None
last_batch_allocate_lens: Optional[torch.Tensor] = None
allocate_lens: Optional[torch.Tensor] = None
# relay path: forward stream -> next step forward
next_draft_input: Optional[EagleDraftInput] = None
......@@ -246,10 +246,8 @@ class GenerationBatchResult:
if self.accept_lens is not None:
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
if self.last_batch_allocate_lens is not None:
self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
"cpu", non_blocking=True
)
if self.allocate_lens is not None:
self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
self.copy_done.record()
......
......@@ -42,23 +42,21 @@ class SchedulerOutputProcessorMixin:
skip_stream_req = None
if self.is_generation:
if result.copy_done is not None:
result.copy_done.synchronize()
(
logits_output,
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
copy_done,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
result.copy_done,
)
if copy_done is not None:
copy_done.synchronize()
# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
......@@ -199,57 +197,52 @@ class SchedulerOutputProcessorMixin:
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def hacky_process_eagle_overlap_result(
def _resolve_spec_overlap_token_ids(
self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
):
# TODO(lsyin): try using a copy stream to share SMs with the forward pass
# FIXME(lsyin): better organize this token-freeing logic in eagle-overlap
last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist()
accept_lens_cpu = result.accept_lens.tolist()
) -> List[List[int]]:
"""Resolve the padding next token ids for speculative decoding with overlap."""
assert result.next_token_ids.is_cpu
assert result.accept_lens.is_cpu
assert result.allocate_lens.is_cpu
next_token_ids = result.next_token_ids.tolist()
accept_lens = result.accept_lens.tolist()
result.num_accepted_tokens = sum(accept_lens)
predict_tokens = []
num_draft_tokens = self.draft_worker.speculative_num_draft_tokens
stride = self.draft_worker.speculative_num_draft_tokens
for i, req in enumerate(batch.reqs):
predict_tokens.append(
next_token_ids[
i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]
]
next_token_ids[i * stride : i * stride + accept_lens[i]]
)
# FIXME(lsyin): move this update elsewhere
req.spec_verify_ct += 1
return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens
return predict_tokens
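For illustration, a minimal sketch of the slicing above on made-up values, where the stride equals speculative_num_draft_tokens and each request's outputs are padded to that stride:

# Two requests, stride = 4 padded slots each; all values are hypothetical.
next_token_ids = [11, 12, 13, 0, 21, 22, 0, 0]
accept_lens = [3, 2]
stride = 4

predict_tokens = [
    next_token_ids[i * stride : i * stride + accept_lens[i]]
    for i in range(len(accept_lens))
]
assert predict_tokens == [[11, 12, 13], [21, 22]]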
def process_batch_result_decode(
self: Scheduler,
batch: ScheduleBatch,
result: GenerationBatchResult,
):
logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
if result.copy_done is not None:
result.copy_done.synchronize()
logits_output, next_token_ids, can_run_cuda_graph = (
result.logits_output,
result.next_token_ids,
result.can_run_cuda_graph,
result.copy_done,
)
self.num_generated_tokens += len(batch.reqs)
if copy_done is not None:
copy_done.synchronize()
if batch.spec_algorithm.is_none():
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()
elif batch.is_v2_eagle:
(
last_batch_allocate_lens_cpu,
accept_lens_cpu,
next_token_ids,
) = self.hacky_process_eagle_overlap_result(result, batch)
result.num_accepted_tokens = sum(accept_lens_cpu)
next_token_ids = self._resolve_spec_overlap_token_ids(result, batch)
allocate_lens_list = result.allocate_lens.tolist()
accept_lens_list = result.accept_lens.tolist()
# FIXME(lsyin): we assume result.num_accepted_tokens has already been set at this point
self.num_generated_tokens += len(batch.reqs)
if not self.spec_algorithm.is_none():
self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)
......@@ -264,43 +257,38 @@ class SchedulerOutputProcessorMixin:
continue
if self.enable_overlap and req.finished():
indices_to_free = None
if self.page_size == 1:
if batch.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker_v2 import (
free_spec_dec_tokens_page_size_1,
)
from sglang.srt.speculative.eagle_info import EagleDraftInput
free_spec_dec_tokens_page_size_1(
self.req_to_token_pool,
self.token_to_kv_pool_allocator,
req,
last_batch_allocate_lens_cpu[i],
None,
)
end_p = allocate_lens_list[i]
start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
indices_to_free = self.req_to_token_pool.req_to_token[
req.req_pool_idx
][start_p:end_p]
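# The finished request runs no further iteration, so the whole ALLOC_LEN_PER_DECODE
# allocation of the final decode step is unused and can be freed.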
else:
# Free the one extra delayed token
self.token_to_kv_pool_allocator.free(
batch.out_cache_loc[i : i + 1]
)
indices_to_free = batch.out_cache_loc[i : i + 1]
else:
if batch.spec_algorithm.is_eagle():
# TODO(lsyin): support eagle with page_size > 1
# TODO(spec-v2): support eagle with page_size > 1
raise NotImplementedError()
else:
if (
len(req.origin_input_ids) + len(req.output_ids) - 1
) % self.page_size == 0:
# Only free when the extra token is in a new page
self.token_to_kv_pool_allocator.free(
batch.out_cache_loc[i : i + 1]
)
indices_to_free = batch.out_cache_loc[i : i + 1]
if indices_to_free is not None:
self.token_to_kv_pool_allocator.free(indices_to_free)
continue
if batch.spec_algorithm.is_none():
req.output_ids.append(next_token_id)
elif batch.is_v2_eagle:
# FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding
# !!!unify the logic here!!!
# Only v2 eagle's output_ids are updated here.
req.output_ids.extend(next_token_id)
req.check_finished()
......@@ -308,24 +296,13 @@ class SchedulerOutputProcessorMixin:
if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
# FIXME(lsyin): fix the messy logic here
# 1) when not overlap (v2 impl), we free the extra tokens in the req
# 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch
from sglang.srt.speculative.eagle_worker_v2 import (
free_spec_dec_tokens_page_size_1,
)
new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
# FIXME(lsyin): remove this assert
assert new_seq_len == int(
batch.seq_lens_cpu[i] + accept_lens_cpu[i]
), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}"
free_spec_dec_tokens_page_size_1(
self.req_to_token_pool,
self.token_to_kv_pool_allocator,
req,
last_batch_allocate_lens_cpu[i],
new_seq_len,
)
# 2) overlap eagle where the current batch is prefill: this seq will not run an extra iteration.
start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
end_p = allocate_lens_list[i]
indices_to_free = self.req_to_token_pool.req_to_token[
req.req_pool_idx
][start_p:end_p]
self.token_to_kv_pool_allocator.free(indices_to_free)
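# The freed range [seq_len + accept_len, allocate_len) holds only over-allocated
# slots, since this request will not run another draft/verify iteration.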
if self.server_args.disaggregation_decode_enable_offload_kvcache:
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
......
......@@ -15,6 +15,7 @@
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
import torch
......@@ -54,7 +55,140 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class TpModelWorker:
class BaseTpWorker(ABC):
@abstractmethod
def forward_batch_generation(self, forward_batch: ForwardBatch):
pass
@property
@abstractmethod
def model_runner(self) -> ModelRunner:
pass
@property
def sliding_window_size(self) -> Optional[int]:
return self.model_runner.sliding_window_size
@property
def is_hybrid(self) -> bool:
return self.model_runner.is_hybrid is not None
def get_tokens_per_layer_info(self):
return (
self.model_runner.full_max_total_num_tokens,
self.model_runner.swa_max_total_num_tokens,
)
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)
def get_tp_group(self):
return self.model_runner.tp_group
def get_attention_tp_group(self):
return self.model_runner.attention_tp_group
def get_attention_tp_cpu_group(self):
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool_allocator,
)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format
)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.model_runner.init_weights_update_group(
recv_req.master_address,
recv_req.master_port,
recv_req.rank_offset,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
return success, message
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
success, message = self.model_runner.destroy_weights_update_group(
recv_req.group_name,
)
return success, message
def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
success, message = (
self.model_runner.init_weights_send_group_for_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_rank,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
)
return success, message
def send_weights_to_remote_instance(
self, recv_req: SendWeightsToRemoteInstanceReqInput
):
success, message = self.model_runner.send_weights_to_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_name,
)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.model_runner.update_weights_from_distributed(
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
monkey_patch_torch_reductions()
success, message = self.model_runner.update_weights_from_tensor(
named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors[self.tp_rank]
),
load_format=recv_req.load_format,
)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
)
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output, _ = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings
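As a rough sketch (class name is illustrative), a concrete worker only needs to store a runner, expose it through the model_runner property, and implement forward_batch_generation; every helper above is then inherited. TpModelWorker below follows this exact pattern.

class MinimalTpWorker(BaseTpWorker):
    def __init__(self, model_runner: ModelRunner):
        # Keep the runner private; BaseTpWorker helpers reach it via the property.
        self._model_runner = model_runner

    @property
    def model_runner(self) -> ModelRunner:
        return self._model_runner

    def forward_batch_generation(self, forward_batch: ForwardBatch):
        # Delegate to the runner; a real worker also handles sampling, overlap, etc.
        return self.model_runner.forward(forward_batch)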
class TpModelWorker(BaseTpWorker):
"""A tensor parallel model worker."""
def __init__(
......@@ -92,7 +226,7 @@ class TpModelWorker:
is_draft_model=is_draft_worker,
)
self.model_runner = ModelRunner(
self._model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=gpu_id,
......@@ -171,6 +305,10 @@ class TpModelWorker:
self.enable_overlap = not server_args.disable_overlap_schedule
self.hicache_layer_transfer_counter = None
@property
def model_runner(self) -> ModelRunner:
return self._model_runner
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
self.hicache_layer_transfer_counter = counter
......@@ -193,38 +331,6 @@ class TpModelWorker:
self.model_runner.token_to_kv_pool.size,
)
@property
def sliding_window_size(self) -> Optional[int]:
return self.model_runner.sliding_window_size
@property
def is_hybrid(self) -> bool:
return self.model_runner.is_hybrid is not None
def get_tokens_per_layer_info(self):
return (
self.model_runner.full_max_total_num_tokens,
self.model_runner.swa_max_total_num_tokens,
)
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)
def get_tp_group(self):
return self.model_runner.tp_group
def get_attention_tp_group(self):
return self.model_runner.attention_tp_group
def get_attention_tp_cpu_group(self):
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool_allocator,
)
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
......@@ -313,93 +419,3 @@ class TpModelWorker:
pp_hidden_states_proxy_tensors=pp_proxy_tensors,
can_run_cuda_graph=can_run_cuda_graph,
)
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output, _ = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format
)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.model_runner.init_weights_update_group(
recv_req.master_address,
recv_req.master_port,
recv_req.rank_offset,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
return success, message
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
success, message = self.model_runner.destroy_weights_update_group(
recv_req.group_name,
)
return success, message
def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
success, message = (
self.model_runner.init_weights_send_group_for_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_rank,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
)
return success, message
def send_weights_to_remote_instance(
self, recv_req: SendWeightsToRemoteInstanceReqInput
):
success, message = self.model_runner.send_weights_to_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_name,
)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.model_runner.update_weights_from_distributed(
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
monkey_patch_torch_reductions()
success, message = self.model_runner.update_weights_from_tensor(
named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors[self.tp_rank]
),
load_format=recv_req.load_format,
)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
)
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
......@@ -53,7 +53,6 @@ from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_bool_env_var,
get_device_memory_capacity,
is_hip,
log_info_on_rank0,
require_attn_tp_gather,
......@@ -274,7 +273,6 @@ class CudaGraphRunner:
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
......
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sglang.srt.managers.tp_worker import TpModelWorker
class BaseDraftWorker(ABC):
@abstractmethod
def draft(self):
pass
@abstractmethod
def draft_extend(self):
pass
class BaseSpecWorker(ABC):
@property
@abstractmethod
def target_worker(self) -> TpModelWorker:
pass
@property
@abstractmethod
def draft_worker(self) -> BaseDraftWorker:
pass
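A minimal sketch of a concrete implementation, mirroring how EAGLEWorkerV2 below satisfies this interface (class name is illustrative):

class MinimalSpecWorker(BaseSpecWorker):
    def __init__(self, target_worker: TpModelWorker, draft_worker: BaseDraftWorker):
        self._target_worker = target_worker
        self._draft_worker = draft_worker

    @property
    def target_worker(self) -> TpModelWorker:
        return self._target_worker

    @property
    def draft_worker(self) -> BaseDraftWorker:
        return self._draft_worker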
......@@ -40,7 +40,11 @@ class EAGLEDraftCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker):
# Parse args
self.eagle_worker = eagle_worker
self.model_runner = model_runner = eagle_worker.model_runner
if not hasattr(eagle_worker, "model_runner"):
# V2: EagleDraftWorker
self.model_runner = model_runner = eagle_worker.draft_runner
else:
self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
......
......@@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker):
# Parse args
self.eagle_worker = eagle_worker
self.model_runner = model_runner = eagle_worker.model_runner
if not hasattr(eagle_worker, "model_runner"):
# V2: EagleDraftWorker
self.model_runner = model_runner = eagle_worker.draft_runner
else:
self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
......@@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner:
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
ret = self.eagle_worker.draft_model_runner.model.forward(
ret = self.model_runner.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
......
......@@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
@dataclass
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
# Constant: alloc length per decode step
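# Set once at worker init as max(speculative_num_steps * topk, speculative_num_draft_tokens).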
ALLOC_LEN_PER_DECODE: ClassVar[int] = None
# The inputs for decode
# shape: (b, topk)
topk_p: torch.Tensor = None
......@@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
new_seq_lens: Optional[torch.Tensor] = None
verify_done: Optional[torch.cuda.Event] = None
# FIXME(lsyin): remove this hack
ALLOC_LEN_PER_DECODE: ClassVar[int] = None
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_DRAFT)
......
......@@ -9,7 +9,8 @@ import triton
import triton.language as tl
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
from sglang.srt.mem_cache.common import alloc_token_slots
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
......@@ -72,6 +73,34 @@ def assign_draft_cache_locs_page_size_1(
@dataclass
class EagleDraftInputV2Mixin:
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
bs = batch.batch_size()
# TODO(lsyin): implement over-allocation
# Now seq_lens and allocate_lens are correct
batch.maybe_wait_verify_done()
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
self.allocate_lens,
new_allocate_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
self.allocate_lens = new_allocate_lens
# FIXME(lsyin): make this sync optional
batch.seq_lens_cpu = batch.seq_lens.cpu()
batch.seq_lens_sum = batch.seq_lens_cpu.sum().item()
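For intuition, the allocation arithmetic above on made-up lengths (ALLOC_LEN_PER_DECODE is assumed to be 8 here):

import torch

seq_lens = torch.tensor([10, 25])       # verified lengths after the last step
allocate_lens = torch.tensor([14, 30])  # slots already reserved per request
ALLOC_LEN_PER_DECODE = 8

new_allocate_lens = seq_lens + ALLOC_LEN_PER_DECODE  # tensor([18, 33])
num_needed_tokens = (new_allocate_lens - allocate_lens).sum().item()
assert num_needed_tokens == 7  # (18 - 14) + (33 - 30)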
def prepare_for_v2_draft(
self: EagleDraftInput,
req_to_token_pool: ReqToTokenPool,
......
import logging
import os
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple
import torch
from huggingface_hub import snapshot_download
from sglang.srt.distributed import (
GroupCoordinator,
get_tp_group,
patch_tensor_parallel_group,
)
from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
......@@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import (
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_utils import (
assign_draft_cache_locs,
detect_nan,
draft_tp_context,
fast_topk,
generate_token_bitmask,
load_token_map,
select_top_k_tokens,
)
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_bool_env_var,
is_blackwell,
is_cuda,
next_power_of_2,
)
......@@ -67,14 +62,6 @@ logger = logging.getLogger(__name__)
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with patch_tensor_parallel_group(tp_group):
yield
class EAGLEWorker(TpModelWorker):
def __init__(
......@@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.padded_static_len = -1
# Override the context length of the draft model to be the same as the target model.
server_args.context_length = target_worker.model_runner.model_config.context_len
......@@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker):
logits_output, _ = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True
)
self._detect_nan_if_needed(logits_output)
if self.server_args.enable_nan_detection:
detect_nan(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
......@@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker):
# and will be applied to produce wrong results
batch.sampling_info.vocab_mask = None
self._detect_nan_if_needed(logits_output)
if self.enable_nan_detection:
detect_nan(logits_output)
spec_info.hidden_states = logits_output.hidden_states
res: EagleVerifyOutput = spec_info.verify(
batch,
......@@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker):
)
forward_batch.return_logprob = False
logits_output, _ = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
if self.enable_nan_detection:
detect_nan(logits_output)
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
......@@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker):
)
self.capture_for_decode(logits_output, forward_batch.spec_info)
self._detect_nan_if_needed(logits_output)
if self.enable_nan_detection:
detect_nan(logits_output)
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
......@@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker):
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
draft_input.hidden_states = logits_output.hidden_states
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
if self.enable_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.error("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")
def load_token_map(token_map_path: str) -> List[int]:
if not os.path.exists(token_map_path):
cache_dir = snapshot_download(
os.path.dirname(token_map_path),
ignore_patterns=["*.bin", "*.safetensors"],
)
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
hot_token_id = torch.load(token_map_path, weights_only=True)
return torch.tensor(hot_token_id, dtype=torch.int64)
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
......
import contextlib
import logging
from typing import List, Optional
import time
from typing import List, Optional, Tuple
import torch
from torch.cuda import Stream as CudaStream
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.base_spec_worker import BaseDraftWorker, BaseSpecWorker
from sglang.srt.speculative.draft_utils import DraftBackendFactory
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.eagle_info_v2 import (
assign_extend_cache_locs,
......@@ -22,69 +28,214 @@ from sglang.srt.speculative.eagle_info_v2 import (
select_top_k_tokens_tmp,
)
from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient
from sglang.srt.speculative.eagle_worker import EAGLEWorker
from sglang.srt.utils.common import fast_topk, next_power_of_2
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_utils import (
detect_nan,
draft_tp_context,
load_token_map,
)
from sglang.srt.utils.common import (
empty_context,
fast_topk,
get_available_gpu_memory,
next_power_of_2,
)
logger = logging.getLogger(__name__)
class EAGLEWorkerV2(EAGLEWorker):
def _get_plan_stream(
device: str,
) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]:
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
plan_stream: CudaStream = torch.get_device_module(device).Stream()
plan_stream_ctx = torch.cuda.stream(plan_stream)
return plan_stream, plan_stream_ctx
else:
return None, contextlib.nullcontext()
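A minimal usage sketch of the returned pair (build_metadata is a hypothetical placeholder): planning work runs on the side stream when it is enabled, and the main compute stream synchronizes on it before launching kernels that depend on the result.

plan_stream, plan_stream_ctx = _get_plan_stream("cuda")

with plan_stream_ctx:
    # Build forward metadata (e.g. the draft-extend ForwardBatch) off the main stream.
    metadata = build_metadata()

if plan_stream is not None:
    # Make the main compute stream wait for the work queued on the plan stream.
    torch.cuda.current_stream().wait_stream(plan_stream)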
class EagleDraftWorker(BaseDraftWorker):
def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
dp_rank: int,
moe_ep_rank: int,
nccl_port: int,
target_worker: TpModelWorker,
):
super().__init__(
server_args,
gpu_id,
tp_rank,
dp_rank,
moe_ep_rank,
nccl_port,
target_worker,
# copy args
self.server_args = server_args
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.dp_rank = dp_rank
self.moe_ep_rank = moe_ep_rank
self.nccl_port = nccl_port
self.target_worker = target_worker
# Args for easy access
self.device = server_args.device
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
# Set constant
EagleDraftInput.ALLOC_LEN_PER_DECODE = max(
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
)
# Do not capture cuda graph in `TpModelWorker` init,
# will capture later with init_cuda_graphs()
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True
# Share the allocator with the target worker.
# The draft and target workers each own their own KV cache pool.
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
with empty_context():
# Init draft worker
self.draft_worker = TpModelWorker(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
pp_rank=0, # FIXME
dp_rank=dp_rank,
moe_ep_rank=moe_ep_rank,
nccl_port=nccl_port,
is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
# Alias for better readability
self.draft_runner = self.draft_worker.model_runner
self.init_token_map()
self.init_lm_head()
# Init attention backend and cuda graphs
self.draft_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
self.draft_tp_context = (
draft_tp_context if server_args.enable_dp_attention else empty_context
)
with self.draft_tp_context(self.draft_runner.tp_group):
self.init_attention_backend()
self.init_cuda_graphs()
self.tree_mask_mode = TreeMaskMode.FULL_MASK
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
def init_token_map(self):
# Load hot token ids
if self.speculative_algorithm.is_eagle3():
if self.server_args.speculative_token_map is not None:
logger.warning(
"Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
)
self.hot_token_id = None
elif self.server_args.speculative_token_map is not None:
self.hot_token_id = load_token_map(self.server_args.speculative_token_map)
self.server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
)
else:
self.plan_stream = None
self.plan_stream_ctx = contextlib.nullcontext()
self.hot_token_id = None
def init_lm_head(self):
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if self.speculative_algorithm.is_eagle3():
# In most cases, EAGLE3 models don't share the lm_head,
# but some models (e.g. nvidia/gpt-oss-120b-Eagle3) do.
if (
hasattr(self.draft_runner.model, "load_lm_head_from_target")
and self.draft_runner.model.load_lm_head_from_target
):
self.draft_runner.model.set_embed_and_head(embed, head)
else:
self.draft_runner.model.set_embed(embed)
# grab hot token ids
if self.draft_runner.model.hot_token_id is not None:
self.hot_token_id = self.draft_runner.model.hot_token_id.to(
embed.device
)
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
if model_worker_batch.forward_mode.is_decode():
# FIXME(lsyin): why shall we use spec_info for both draft and verify?
draft_input: EagleDraftInput = model_worker_batch.spec_info
assert draft_input.is_draft_input()
verify_input: EagleVerifyInput = self.draft(model_worker_batch)
assert verify_input.is_verify_input()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
return batch_output
else:
# Target prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
batch_output = self.target_worker.forward_batch_generation(
model_worker_batch
if self.hot_token_id is not None:
head = head.clone()
self.hot_token_id = self.hot_token_id.to(head.device)
head.data = head.data[self.hot_token_id]
# Share the embedding and lm_head
self.draft_runner.model.set_embed_and_head(embed, head)
def init_attention_backend(self):
# Create multi-step attn backends and cuda graph runners
self.has_prefill_wrapper_verify = False
self.draft_extend_attn_backend = None
draft_backend_factory = DraftBackendFactory(
self.server_args,
self.draft_runner,
self.topk,
self.speculative_num_steps,
)
# Initialize decode attention backend
self.draft_attn_backend = draft_backend_factory.create_decode_backend()
# Initialize draft extend attention backend (respects speculative_attention_mode setting)
self.draft_extend_attn_backend = (
draft_backend_factory.create_draft_extend_backend()
)
self.draft_runner.draft_attn_backend = self.draft_attn_backend
self.tree_mask_mode = TreeMaskMode.FULL_MASK
def init_cuda_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None
self.cuda_graph_runner_for_draft_extend = None
if self.server_args.disable_cuda_graph:
return
# Capture draft
if self.speculative_num_steps > 1:
tic = time.perf_counter()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
)
# Draft prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
batch_output.next_draft_input = self.forward_draft_extend(
model_worker_batch,
batch_output.logits_output.hidden_states,
batch_output.next_token_ids,
# Capture extend
if self.draft_extend_attn_backend:
tic = time.perf_counter()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
self
)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
)
return batch_output
def draft(self, model_worker_batch: ModelWorkerBatch):
draft_input: EagleDraftInput = model_worker_batch.spec_info
......@@ -92,7 +243,7 @@ class EAGLEWorkerV2(EAGLEWorker):
self.req_to_token_pool,
model_worker_batch,
self.cuda_graph_runner,
self.draft_model_runner,
self.draft_runner,
self.topk,
self.speculative_num_steps,
)
......@@ -201,10 +352,11 @@ class EAGLEWorkerV2(EAGLEWorker):
spec_info.hidden_states = hidden_states
# Run forward
logits_output = self.draft_model_runner.model.forward(
logits_output = self.draft_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
self._detect_nan_if_needed(logits_output)
if self.server_args.enable_nan_detection:
detect_nan(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
......@@ -233,10 +385,190 @@ class EAGLEWorkerV2(EAGLEWorker):
return parent_list, top_scores_index, draft_tokens
def draft_extend(self):
pass
def _draft_extend_for_prefill(
self,
batch: ModelWorkerBatch,
target_hidden_states: torch.Tensor,
next_token_ids: torch.Tensor,
):
"""
Run draft model extend to correctly fill the KV cache.
Args:
batch: The batch to run.
target_hidden_states: Hidden states from the target model forward pass.
next_token_ids: Next token ids generated from the target forward.
"""
# Construct input_ids
pt = 0
for i, extend_len in enumerate(batch.extend_seq_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], next_token_ids[i].reshape(1))
)
pt += extend_len
# Construct spec_info
next_draft_input = EagleDraftInput(
hidden_states=target_hidden_states,
verified_id=next_token_ids,
new_seq_lens=batch.seq_lens,
allocate_lens=batch.seq_lens,
)
batch.spec_info = next_draft_input
# Run forward
forward_batch = ForwardBatch.init_new(batch, self.draft_runner)
logits_output, _ = self.draft_runner.forward(forward_batch)
# Update spec_info for the next draft step
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
probs, self.topk, dim=-1
)
next_draft_input.hidden_states = logits_output.hidden_states
return next_draft_input
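For illustration, the per-request shift performed in the input_ids loop above (made-up ids, a single request of length 4):

import torch

input_ids = torch.tensor([5, 6, 7, 8])  # tokens the target model just prefilled
next_token_id = torch.tensor([9])       # first verified token from the target

# The draft model predicts token t+1 from token t, so drop the first token and
# append the newly verified one at the end.
shifted = torch.cat((input_ids[1:], next_token_id))
assert shifted.tolist() == [6, 7, 8, 9]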
def _draft_extend_for_decode(
self, batch: ModelWorkerBatch, batch_result: GenerationBatchResult
):
# Batch 2: Draft extend
draft_input = EagleDraftInput(
hidden_states=batch_result.logits_output.hidden_states,
)
select_index = (
torch.arange(len(batch.seq_lens), device=self.device)
* self.speculative_num_draft_tokens
+ batch_result.accept_lens
- 1
)
# Prepare for draft extend in a separate stream
with self.plan_stream_ctx:
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
batch,
batch_result.next_token_ids,
self.speculative_num_draft_tokens,
self.draft_runner,
)
if self.plan_stream:
torch.cuda.current_stream().wait_stream(self.plan_stream)
# Run draft extend batch in the main compute stream
draft_logits_output = self.draft_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
# Reorganize the spec info for the next batch
draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
select_index
]
draft_logits_output.hidden_states = draft_logits_output.hidden_states[
select_index
]
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
ret_hidden_states = draft_logits_output.hidden_states
# Construct the return values
next_draft_input = batch_result.next_draft_input
(
next_draft_input.topk_p,
next_draft_input.topk_index,
next_draft_input.hidden_states,
) = (
ret_topk_p,
ret_topk_index,
ret_hidden_states,
)
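For intuition, the select_index arithmetic above on made-up values: each request contributes speculative_num_draft_tokens rows, and the index picks the row of its last accepted token.

import torch

speculative_num_draft_tokens = 4
accept_lens = torch.tensor([3, 1])  # accepted tokens per request

select_index = (
    torch.arange(len(accept_lens)) * speculative_num_draft_tokens + accept_lens - 1
)
assert select_index.tolist() == [2, 4]  # request 0 -> row 2, request 1 -> row 4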
class EAGLEWorkerV2(BaseSpecWorker):
def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
moe_ep_rank: int,
nccl_port: int,
target_worker: TpModelWorker,
):
# Parse arguments
self.server_args = server_args
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
self.enable_nan_detection = server_args.enable_nan_detection
self.gpu_id = gpu_id
self.device = server_args.device
self._target_worker = target_worker
self.page_size = server_args.page_size
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
# Override the context length of the draft model to be the same as the target model.
server_args.context_length = target_worker.model_runner.model_config.context_len
self._draft_worker = EagleDraftWorker(
server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker
)
# Some dummy tensors
self.num_new_pages_per_topk = torch.empty(
(), dtype=torch.int64, device=self.device
)
self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
@property
def target_worker(self):
return self._target_worker
@property
def draft_worker(self):
return self._draft_worker
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
if model_worker_batch.forward_mode.is_decode():
draft_input: EagleDraftInput = model_worker_batch.spec_info
assert draft_input.is_draft_input()
verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
assert verify_input.is_verify_input()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
return batch_output
else:
# Target prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
batch_output = self.target_worker.forward_batch_generation(
model_worker_batch
)
# Draft prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill(
model_worker_batch,
batch_output.logits_output.hidden_states,
batch_output.next_token_ids,
)
return batch_output
def verify(
self,
batch: ModelWorkerBatch,
pre_draft_allocate_lens: torch.Tensor,
cur_allocate_lens: torch.Tensor,
):
# Since batch.seq_lens is allocated in another stream, we need record_stream()
# to prevent PyTorch from freeing and reusing the GPU memory prematurely.
......@@ -284,7 +616,8 @@ class EAGLEWorkerV2(EAGLEWorker):
logits_output = forward_batch_output.logits_output
# Sample
self._detect_nan_if_needed(logits_output)
if self.enable_nan_detection:
detect_nan(logits_output)
(
predict,
accept_length,
......@@ -303,53 +636,11 @@ class EAGLEWorkerV2(EAGLEWorker):
self.speculative_num_draft_tokens,
)
# Batch 2: Draft extend
draft_input = EagleDraftInput(
hidden_states=logits_output.hidden_states,
)
select_index = (
torch.arange(len(batch.seq_lens), device=self.device)
* self.speculative_num_draft_tokens
+ accept_length
- 1
)
# Prepare for draft extend in a separate stream
with self.plan_stream_ctx:
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
batch,
predict,
self.speculative_num_draft_tokens,
self.draft_model_runner,
)
if self.plan_stream:
torch.cuda.current_stream().wait_stream(self.plan_stream)
# Run draft extend batch in the main compute stream
draft_logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
# Reorganize the spec info for the next batch
draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
select_index
]
draft_logits_output.hidden_states = draft_logits_output.hidden_states[
select_index
]
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
ret_hidden_states = draft_logits_output.hidden_states
# Construct the return values
# Construct the next draft input
next_draft_input = EagleDraftInput(
topk_p=ret_topk_p,
topk_index=ret_topk_index,
hidden_states=ret_hidden_states,
verified_id=verified_id,
new_seq_lens=new_seq_lens,
allocate_lens=pre_draft_allocate_lens,
allocate_lens=cur_allocate_lens,
verify_done=verify_done,
)
......@@ -359,52 +650,8 @@ class EAGLEWorkerV2(EAGLEWorker):
can_run_cuda_graph=can_run_cuda_graph,
next_draft_input=next_draft_input,
accept_lens=accept_length,
last_batch_allocate_lens=pre_draft_allocate_lens,
)
def forward_draft_extend(
self,
batch: ModelWorkerBatch,
target_hidden_states: torch.Tensor,
next_token_ids: torch.Tensor,
):
"""
Run draft model extend to correctly fill the KV cache.
Args:
batch: The batch to run.
target_hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
# Construct input_ids
pt = 0
for i, extend_len in enumerate(batch.extend_seq_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], next_token_ids[i].reshape(1))
)
pt += extend_len
# Construct spec_info
next_draft_input = EagleDraftInput(
hidden_states=target_hidden_states,
verified_id=next_token_ids,
new_seq_lens=batch.seq_lens,
allocate_lens=batch.seq_lens,
)
batch.spec_info = next_draft_input
# Run forward
forward_batch = ForwardBatch.init_new(batch, self.draft_model_runner)
logits_output, _ = self.draft_model_runner.forward(forward_batch)
# Update spec_info for the next draft step
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
probs, self.topk, dim=-1
allocate_lens=cur_allocate_lens,
)
next_draft_input.hidden_states = logits_output.hidden_states
return next_draft_input
def move_accepted_tokens_to_target_kvcache(
self,
......@@ -449,32 +696,3 @@ class EAGLEWorkerV2(EAGLEWorker):
self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
tgt_cache_loc, accepted_out_cache_loc
)
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
if self.enable_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.error("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")
def free_spec_dec_tokens_page_size_1(
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
req: Req,
allocate_len: int,
new_seq_len: int,
):
# FIXME(lsyin): move this function elsewhere
# free extra allocated tokens
if new_seq_len is None:
# True only for overlap eagle and the current batch is decode. This seq will be part of the decode, so the final iteration's allocation is not used (i.e. this case).
start_len = allocate_len - EagleDraftInput.ALLOC_LEN_PER_DECODE
else:
# True for 1) non-overlap; 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration, so start_lens is passed in.
start_len = new_seq_len
indices_to_free = req_to_token_pool.req_to_token[req.req_pool_idx][
start_len:allocate_len
]
token_to_kv_pool_allocator.free(indices_to_free)
from __future__ import annotations
import logging
import os
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, List
import torch
import triton
import triton.language as tl
from huggingface_hub import snapshot_download
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.distributed.parallel_state import (
GroupCoordinator,
patch_tensor_parallel_group,
)
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.utils import is_cuda, is_hip
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.speculative.eagle_info import EagleVerifyInput
if is_cuda():
from sgl_kernel import fast_topk
elif is_hip():
from sgl_kernel import fast_topk
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_info import EagleVerifyInput
logger = logging.getLogger(__name__)
......@@ -603,3 +615,29 @@ def generate_token_bitmask(
verify_input.grammar = grammar
return allocate_token_bitmask
def load_token_map(token_map_path: str) -> List[int]:
if not os.path.exists(token_map_path):
cache_dir = snapshot_download(
os.path.dirname(token_map_path),
ignore_patterns=["*.bin", "*.safetensors"],
)
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
hot_token_id = torch.load(token_map_path, weights_only=True)
return torch.tensor(hot_token_id, dtype=torch.int64)
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with patch_tensor_parallel_group(tp_group):
yield
def detect_nan(logits_output: LogitsProcessorOutput):
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.error("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")
import logging
from contextlib import contextmanager
from typing import Optional
import torch
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
from sglang.srt.speculative.eagle_worker import EAGLEWorker
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
if is_cuda():
......@@ -18,14 +17,6 @@ logger = logging.getLogger(__name__)
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with patch_tensor_parallel_group(tp_group):
yield
class StandaloneWorker(EAGLEWorker):
def __init__(
......@@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker):
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.padded_static_len = -1
# Override the context length of the draft model to be the same as the target model.
server_args.context_length = target_worker.model_runner.model_config.context_len
......