Unverified Commit cde5a6e3 authored by Liangsheng Yin, committed by GitHub

Abstraction for spec worker and code cleanup (#11643)

parent 3e4c7da2
@@ -1061,38 +1061,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
return req_pool_indices
-def allocate_for_eagle_v2(self):
-from sglang.srt.speculative.eagle_info import EagleDraftInput
-from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
-bs = self.batch_size()
-assert self.spec_info.is_draft_input()
-draft_input: EagleDraftInput = self.spec_info
-# FIXME(lsyin): now implementation does not enable over-allocation
-# Now seq_lens and allocate_lens are correct
-self.maybe_wait_verify_done()
-new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
-num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
-out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)
-assign_req_to_token_pool[(bs,)](
-self.req_pool_indices,
-self.req_to_token_pool.req_to_token,
-draft_input.allocate_lens,
-new_allocate_lens,
-out_cache_loc,
-self.req_to_token_pool.req_to_token.shape[1],
-next_power_of_2(bs),
-)
-draft_input.allocate_lens = new_allocate_lens
-# FIXME(lsyin): remove seq_lens_sum calculation
-self.seq_lens_cpu = self.seq_lens.cpu()
-self.seq_lens_sum = self.seq_lens_cpu.sum().item()
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
self.encoder_cached = []
@@ -1522,8 +1490,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
bs = len(self.reqs)
if self.is_v2_eagle:
-# FIXME(lsyin): make this sync optional
-self.allocate_for_eagle_v2()
+# TODO(spec-v2): all v2 spec should go through this path
+from sglang.srt.speculative.eagle_info import EagleDraftInput
+draft_input: EagleDraftInput = self.spec_info
+draft_input.prepare_for_decode(self)
if not self.spec_algorithm.is_none():
# if spec decoding is used, the decode batch is prepared inside
...
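For orientation, a minimal sketch of the bookkeeping this decode-time path performs, using made-up numbers (`ALLOC_LEN_PER_DECODE` is the class constant introduced later in this commit; its value here is assumed):

```python
import torch

ALLOC_LEN_PER_DECODE = 8  # assumed value for illustration; set elsewhere in the real code

seq_lens = torch.tensor([10, 25, 7])        # current lengths of three running requests
allocate_lens = torch.tensor([12, 30, 10])  # KV slots already reserved per request

new_allocate_lens = seq_lens + ALLOC_LEN_PER_DECODE            # tensor([18, 33, 15])
num_needed_tokens = (new_allocate_lens - allocate_lens).sum()  # 6 + 3 + 5 = 14
print(int(num_needed_tokens))  # 14 fresh slots requested from the token allocator
```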
@@ -215,10 +215,10 @@ class GenerationBatchResult:
delay_sample_func: Optional[callable] = None
future_indices: Optional[FutureIndices] = None
-# FIXME(lsyin): maybe move to <BetterPlace> ?
+# FIXME(lsyin): maybe move to a better place?
# sync path: forward stream -> output processor
accept_lens: Optional[torch.Tensor] = None
-last_batch_allocate_lens: Optional[torch.Tensor] = None
+allocate_lens: Optional[torch.Tensor] = None
# relay path: forward stream -> next step forward
next_draft_input: Optional[EagleDraftInput] = None
@@ -246,10 +246,8 @@ class GenerationBatchResult:
if self.accept_lens is not None:
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
-if self.last_batch_allocate_lens is not None:
-self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
-"cpu", non_blocking=True
-)
+if self.allocate_lens is not None:
+self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
self.copy_done.record()
...
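The `copy_done` handshake above is the usual CUDA event pattern: the forward stream queues non-blocking device-to-host copies and records an event, and the output processor synchronizes on that event before reading the host tensors. A minimal sketch, assuming a CUDA device (the function names are illustrative):

```python
import torch

def snapshot_result(accept_lens: torch.Tensor, allocate_lens: torch.Tensor):
    # Producer (forward stream): enqueue async D2H copies, then record an event.
    accept_lens_cpu = accept_lens.to("cpu", non_blocking=True)
    allocate_lens_cpu = allocate_lens.to("cpu", non_blocking=True)
    copy_done = torch.cuda.Event()
    copy_done.record()
    return accept_lens_cpu, allocate_lens_cpu, copy_done

def read_result(accept_lens_cpu, allocate_lens_cpu, copy_done):
    # Consumer (output processor): wait until the copies have landed on the host.
    copy_done.synchronize()
    return accept_lens_cpu.tolist(), allocate_lens_cpu.tolist()
```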
@@ -42,23 +42,21 @@ class SchedulerOutputProcessorMixin:
skip_stream_req = None
if self.is_generation:
+if result.copy_done is not None:
+result.copy_done.synchronize()
(
logits_output,
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
-copy_done,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
-result.copy_done,
)
-if copy_done is not None:
-copy_done.synchronize()
# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
@@ -199,57 +197,52 @@ class SchedulerOutputProcessorMixin:
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
-def hacky_process_eagle_overlap_result(
+def _resolve_spec_overlap_token_ids(
self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
-):
-# TODO(lsyin): try use a copy stream to share SMs with forward
-# FIXME(lsyin): better organize this token free logic in eagle-overlap
-last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist()
-accept_lens_cpu = result.accept_lens.tolist()
+) -> List[List[int]]:
+"""Resolve the padding next token ids for speculative decoding with overlap."""
+assert result.next_token_ids.is_cpu
+assert result.accept_lens.is_cpu
+assert result.allocate_lens.is_cpu
next_token_ids = result.next_token_ids.tolist()
+accept_lens = result.accept_lens.tolist()
+result.num_accepted_tokens = sum(accept_lens)
predict_tokens = []
-num_draft_tokens = self.draft_worker.speculative_num_draft_tokens
+stride = self.draft_worker.speculative_num_draft_tokens
for i, req in enumerate(batch.reqs):
predict_tokens.append(
-next_token_ids[
-i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]
-]
+next_token_ids[i * stride : i * stride + accept_lens[i]]
)
-# FIXME(lsyin): move this update elsewhere
req.spec_verify_ct += 1
-return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens
+return predict_tokens
def process_batch_result_decode(
self: Scheduler,
batch: ScheduleBatch,
result: GenerationBatchResult,
):
+if result.copy_done is not None:
+result.copy_done.synchronize()
-logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
+logits_output, next_token_ids, can_run_cuda_graph = (
result.logits_output,
result.next_token_ids,
result.can_run_cuda_graph,
-result.copy_done,
)
-self.num_generated_tokens += len(batch.reqs)
-if copy_done is not None:
-copy_done.synchronize()
if batch.spec_algorithm.is_none():
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()
elif batch.is_v2_eagle:
-(
-last_batch_allocate_lens_cpu,
-accept_lens_cpu,
-next_token_ids,
-) = self.hacky_process_eagle_overlap_result(result, batch)
-result.num_accepted_tokens = sum(accept_lens_cpu)
-# FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result
+next_token_ids = self._resolve_spec_overlap_token_ids(result, batch)
+allocate_lens_list = result.allocate_lens.tolist()
+accept_lens_list = result.accept_lens.tolist()
+self.num_generated_tokens += len(batch.reqs)
if not self.spec_algorithm.is_none():
self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)
@@ -264,43 +257,38 @@ class SchedulerOutputProcessorMixin:
continue
if self.enable_overlap and req.finished():
+indices_to_free = None
if self.page_size == 1:
if batch.spec_algorithm.is_eagle():
-from sglang.srt.speculative.eagle_worker_v2 import (
-free_spec_dec_tokens_page_size_1,
-)
-free_spec_dec_tokens_page_size_1(
-self.req_to_token_pool,
-self.token_to_kv_pool_allocator,
-req,
-last_batch_allocate_lens_cpu[i],
-None,
-)
+from sglang.srt.speculative.eagle_info import EagleDraftInput
+end_p = allocate_lens_list[i]
+start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
+indices_to_free = self.req_to_token_pool.req_to_token[
+req.req_pool_idx
+][start_p:end_p]
else:
# Free the one extra delayed token
-self.token_to_kv_pool_allocator.free(
-batch.out_cache_loc[i : i + 1]
-)
+indices_to_free = batch.out_cache_loc[i : i + 1]
else:
if batch.spec_algorithm.is_eagle():
-# TODO(lsyin): support eagle with page_size > 1
+# TODO(spec-v2): support eagle with page_size > 1
raise NotImplementedError()
else:
if (
len(req.origin_input_ids) + len(req.output_ids) - 1
) % self.page_size == 0:
# Only free when the extra token is in a new page
-self.token_to_kv_pool_allocator.free(
-batch.out_cache_loc[i : i + 1]
-)
+indices_to_free = batch.out_cache_loc[i : i + 1]
+if indices_to_free is not None:
+self.token_to_kv_pool_allocator.free(indices_to_free)
continue
if batch.spec_algorithm.is_none():
req.output_ids.append(next_token_id)
elif batch.is_v2_eagle:
-# FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding
-# !!!unify the logic here!!!
+# Only v2 eagle's output_ids are updated here.
req.output_ids.extend(next_token_id)
req.check_finished()
@@ -308,24 +296,13 @@ class SchedulerOutputProcessorMixin:
if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
# FIXME(lsyin): fix the messy logic here
# 1) when not overlap (v2 impl), we free the extra tokens in the req
-# 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch
-from sglang.srt.speculative.eagle_worker_v2 import (
-free_spec_dec_tokens_page_size_1,
-)
-new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
-# FIXME(lsyin): remove this assert
-assert new_seq_len == int(
-batch.seq_lens_cpu[i] + accept_lens_cpu[i]
-), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}"
-free_spec_dec_tokens_page_size_1(
-self.req_to_token_pool,
-self.token_to_kv_pool_allocator,
-req,
-last_batch_allocate_lens_cpu[i],
-new_seq_len,
-)
+# 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
+start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
+end_p = allocate_lens_list[i]
+indices_to_free = self.req_to_token_pool.req_to_token[
+req.req_pool_idx
+][start_p:end_p]
+self.token_to_kv_pool_allocator.free(indices_to_free)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
...
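To make the slicing in `_resolve_spec_overlap_token_ids` concrete: the verifier emits a fixed block of `speculative_num_draft_tokens` candidate tokens per request, and only the first `accept_lens[i]` entries of each block are real. A small example with made-up values:

```python
# Two requests, stride (speculative_num_draft_tokens) = 4; zeros are padding.
next_token_ids = [11, 12, 13, 0, 21, 22, 0, 0]
accept_lens = [3, 2]
stride = 4

predict_tokens = [
    next_token_ids[i * stride : i * stride + accept_lens[i]]
    for i in range(len(accept_lens))
]
print(predict_tokens)  # [[11, 12, 13], [21, 22]]
```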
@@ -15,6 +15,7 @@
from __future__ import annotations
import logging
+from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
import torch
@@ -54,7 +55,140 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class TpModelWorker:
+class BaseTpWorker(ABC):
@abstractmethod
def forward_batch_generation(self, forward_batch: ForwardBatch):
pass
@property
@abstractmethod
def model_runner(self) -> ModelRunner:
pass
@property
def sliding_window_size(self) -> Optional[int]:
return self.model_runner.sliding_window_size
@property
def is_hybrid(self) -> bool:
return self.model_runner.is_hybrid is not None
def get_tokens_per_layer_info(self):
return (
self.model_runner.full_max_total_num_tokens,
self.model_runner.swa_max_total_num_tokens,
)
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)
def get_tp_group(self):
return self.model_runner.tp_group
def get_attention_tp_group(self):
return self.model_runner.attention_tp_group
def get_attention_tp_cpu_group(self):
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool_allocator,
)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format
)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.model_runner.init_weights_update_group(
recv_req.master_address,
recv_req.master_port,
recv_req.rank_offset,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
return success, message
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
success, message = self.model_runner.destroy_weights_update_group(
recv_req.group_name,
)
return success, message
def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
success, message = (
self.model_runner.init_weights_send_group_for_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_rank,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
)
return success, message
def send_weights_to_remote_instance(
self, recv_req: SendWeightsToRemoteInstanceReqInput
):
success, message = self.model_runner.send_weights_to_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_name,
)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.model_runner.update_weights_from_distributed(
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
monkey_patch_torch_reductions()
success, message = self.model_runner.update_weights_from_tensor(
named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors[self.tp_rank]
),
load_format=recv_req.load_format,
)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
)
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output, _ = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings
class TpModelWorker(BaseTpWorker):
"""A tensor parallel model worker.""" """A tensor parallel model worker."""
def __init__( def __init__(
...@@ -92,7 +226,7 @@ class TpModelWorker: ...@@ -92,7 +226,7 @@ class TpModelWorker:
is_draft_model=is_draft_worker, is_draft_model=is_draft_worker,
) )
self.model_runner = ModelRunner( self._model_runner = ModelRunner(
model_config=self.model_config, model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static, mem_fraction_static=server_args.mem_fraction_static,
gpu_id=gpu_id, gpu_id=gpu_id,
...@@ -171,6 +305,10 @@ class TpModelWorker: ...@@ -171,6 +305,10 @@ class TpModelWorker:
self.enable_overlap = not server_args.disable_overlap_schedule self.enable_overlap = not server_args.disable_overlap_schedule
self.hicache_layer_transfer_counter = None self.hicache_layer_transfer_counter = None
@property
def model_runner(self) -> ModelRunner:
return self._model_runner
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter): def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
self.hicache_layer_transfer_counter = counter self.hicache_layer_transfer_counter = counter
...@@ -193,38 +331,6 @@ class TpModelWorker: ...@@ -193,38 +331,6 @@ class TpModelWorker:
self.model_runner.token_to_kv_pool.size, self.model_runner.token_to_kv_pool.size,
) )
@property
def sliding_window_size(self) -> Optional[int]:
return self.model_runner.sliding_window_size
@property
def is_hybrid(self) -> bool:
return self.model_runner.is_hybrid is not None
def get_tokens_per_layer_info(self):
return (
self.model_runner.full_max_total_num_tokens,
self.model_runner.swa_max_total_num_tokens,
)
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)
def get_tp_group(self):
return self.model_runner.tp_group
def get_attention_tp_group(self):
return self.model_runner.attention_tp_group
def get_attention_tp_cpu_group(self):
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool_allocator,
)
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
@@ -313,93 +419,3 @@ class TpModelWorker:
pp_hidden_states_proxy_tensors=pp_proxy_tensors,
can_run_cuda_graph=can_run_cuda_graph,
)
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output, _ = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format
)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.model_runner.init_weights_update_group(
recv_req.master_address,
recv_req.master_port,
recv_req.rank_offset,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
return success, message
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
success, message = self.model_runner.destroy_weights_update_group(
recv_req.group_name,
)
return success, message
def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
success, message = (
self.model_runner.init_weights_send_group_for_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_rank,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
)
return success, message
def send_weights_to_remote_instance(
self, recv_req: SendWeightsToRemoteInstanceReqInput
):
success, message = self.model_runner.send_weights_to_remote_instance(
recv_req.master_address,
recv_req.ports,
recv_req.group_name,
)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.model_runner.update_weights_from_distributed(
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
monkey_patch_torch_reductions()
success, message = self.model_runner.update_weights_from_tensor(
named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors[self.tp_rank]
),
load_format=recv_req.load_format,
)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
)
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
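With the shared plumbing hoisted into `BaseTpWorker`, a new worker only has to supply the forward path and expose its runner; the weight-update, LoRA, and memory-pool helpers are inherited. A hypothetical subclass for illustration (not a class from this patch; `BaseTpWorker`, `ModelRunner`, and `ForwardBatch` are the names used in the diff above):

```python
# Hypothetical sketch only; this subclass itself is made up.
class MinimalTpWorker(BaseTpWorker):
    def __init__(self, model_runner: "ModelRunner"):
        self._model_runner = model_runner

    @property
    def model_runner(self) -> "ModelRunner":
        # Satisfies the abstract property that the base-class helpers rely on.
        return self._model_runner

    def forward_batch_generation(self, forward_batch: "ForwardBatch"):
        # Delegate to the runner, mirroring the two-value unpack used in
        # forward_batch_embedding above.
        logits_output, can_run_cuda_graph = self.model_runner.forward(forward_batch)
        return logits_output, can_run_cuda_graph
```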
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_bool_env_var,
-get_device_memory_capacity,
is_hip,
log_info_on_rank0,
require_attn_tp_gather,
@@ -274,7 +273,6 @@ class CudaGraphRunner:
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
-# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
...
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sglang.srt.managers.tp_worker import TpModelWorker
class BaseDraftWorker(ABC):
@abstractmethod
def draft():
pass
@abstractmethod
def draft_extend():
pass
class BaseSpecWorker(ABC):
@property
@abstractmethod
def target_worker(self) -> TpModelWorker:
pass
@property
@abstractmethod
def draft_worker(self) -> BaseDraftWorker:
pass
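The new `BaseDraftWorker`/`BaseSpecWorker` classes only pin down the interface; a concrete speculative worker would pair a target `TpModelWorker` with a draft worker behind the two properties. A hypothetical sketch (every name other than the two abstract base classes is invented):

```python
# Illustrative only: shows the shape of a concrete implementation, not real code.
class NoopDraftWorker(BaseDraftWorker):
    def draft(self):
        pass  # would propose draft tokens here

    def draft_extend(self):
        pass  # would run the draft model over newly accepted tokens here

class SimpleSpecWorker(BaseSpecWorker):
    def __init__(self, target_worker, draft_worker: BaseDraftWorker):
        self._target_worker = target_worker
        self._draft_worker = draft_worker

    @property
    def target_worker(self):
        return self._target_worker

    @property
    def draft_worker(self) -> BaseDraftWorker:
        return self._draft_worker
```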
@@ -40,7 +40,11 @@ class EAGLEDraftCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker):
# Parse args
self.eagle_worker = eagle_worker
-self.model_runner = model_runner = eagle_worker.model_runner
+if not hasattr(eagle_worker, "model_runner"):
+# V2: EagleDraftWorker
+self.model_runner = model_runner = eagle_worker.draft_runner
+else:
+self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
...
@@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker):
# Parse args
self.eagle_worker = eagle_worker
-self.model_runner = model_runner = eagle_worker.model_runner
+if not hasattr(eagle_worker, "model_runner"):
+# V2: EagleDraftWorker
+self.model_runner = model_runner = eagle_worker.draft_runner
+else:
+self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner:
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
-ret = self.eagle_worker.draft_model_runner.model.forward(
+ret = self.model_runner.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
...
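Both CUDA-graph runners now use the same compatibility probe: a v1 `EAGLEWorker` exposes `model_runner`, while the v2 `EagleDraftWorker` exposes `draft_runner`. The attribute check above is equivalent to a small helper like this (illustrative sketch, not code from the patch):

```python
def resolve_draft_runner(eagle_worker):
    # V1 EAGLEWorker carries `model_runner`; the V2 EagleDraftWorker carries `draft_runner`.
    if hasattr(eagle_worker, "model_runner"):
        return eagle_worker.model_runner
    return eagle_worker.draft_runner
```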
@@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
@dataclass
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
+# Constant: alloc length per decode step
+ALLOC_LEN_PER_DECODE: ClassVar[int] = None
# The inputs for decode
# shape: (b, topk)
topk_p: torch.Tensor = None
@@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
new_seq_lens: Optional[torch.Tensor] = None
verify_done: Optional[torch.cuda.Event] = None
-# FIXME(lsyin): remove this hack
-ALLOC_LEN_PER_DECODE: ClassVar[int] = None
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_DRAFT)
...
@@ -9,7 +9,8 @@ import triton
import triton.language as tl
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
+from sglang.srt.mem_cache.common import alloc_token_slots
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
@@ -72,6 +73,34 @@ def assign_draft_cache_locs_page_size_1(
@dataclass
class EagleDraftInputV2Mixin:
+def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
+from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
+bs = batch.batch_size()
+# TODO(lsyin): implement over-allocation
+# Now seq_lens and allocate_lens are correct
+batch.maybe_wait_verify_done()
+new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
+num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
+out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
+assign_req_to_token_pool[(bs,)](
+batch.req_pool_indices,
+batch.req_to_token_pool.req_to_token,
+self.allocate_lens,
+new_allocate_lens,
+out_cache_loc,
+batch.req_to_token_pool.req_to_token.shape[1],
+next_power_of_2(bs),
+)
+self.allocate_lens = new_allocate_lens
+# FIXME(lsyin): make this sync optional
+batch.seq_lens_cpu = batch.seq_lens.cpu()
+batch.seq_lens_sum = batch.seq_lens_cpu.sum().item()
def prepare_for_v2_draft(
self: EagleDraftInput,
req_to_token_pool: ReqToTokenPool,
...
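A rough picture of what the `assign_req_to_token_pool[(bs,)](...)` launch above does: for each request it writes that request's share of the freshly allocated slots into its row of `req_to_token`, covering positions `[allocate_lens[i], new_allocate_lens[i])`. A pure-PyTorch sketch of the same effect, inferred from the call site (the real kernel is a Triton grid launch):

```python
import torch

def assign_req_to_token_pool_ref(req_pool_indices, req_to_token, old_lens, new_lens, out_cache_loc):
    # Reference semantics only: each request consumes a contiguous chunk of out_cache_loc.
    offset = 0
    for req_idx, start, end in zip(
        req_pool_indices.tolist(), old_lens.tolist(), new_lens.tolist()
    ):
        n = end - start
        req_to_token[req_idx, start:end] = out_cache_loc[offset : offset + n]
        offset += n

# Tiny example: two requests sharing 5 newly allocated slots.
req_to_token = torch.zeros(4, 16, dtype=torch.int64)
assign_req_to_token_pool_ref(
    torch.tensor([0, 2]),                    # req_pool_indices
    req_to_token,
    torch.tensor([3, 6]),                    # allocate_lens before this step
    torch.tensor([5, 9]),                    # new_allocate_lens
    torch.tensor([100, 101, 102, 103, 104]), # out_cache_loc
)
```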
import logging
-import os
import time
-from contextlib import contextmanager
from typing import List, Optional, Tuple
import torch
-from huggingface_hub import snapshot_download
-from sglang.srt.distributed import (
-GroupCoordinator,
-get_tp_group,
-patch_tensor_parallel_group,
-)
+from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import (
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_utils import (
assign_draft_cache_locs,
+detect_nan,
+draft_tp_context,
fast_topk,
generate_token_bitmask,
+load_token_map,
select_top_k_tokens,
)
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_bool_env_var,
-is_blackwell,
is_cuda,
next_power_of_2,
)
@@ -67,14 +62,6 @@ logger = logging.getLogger(__name__)
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
-@contextmanager
-def draft_tp_context(tp_group: GroupCoordinator):
-# Draft model doesn't use dp and has its own tp group.
-# We disable mscclpp now because it doesn't support 2 comm groups.
-with patch_tensor_parallel_group(tp_group):
-yield
class EAGLEWorker(TpModelWorker):
def __init__(
@@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
-self.padded_static_len = -1
# Override the context length of the draft model to be the same as the target model.
server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker):
logits_output, _ = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True
)
-self._detect_nan_if_needed(logits_output)
+if self.server_args.enable_nan_detection:
+detect_nan(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
@@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker):
# and will be applied to produce wrong results
batch.sampling_info.vocab_mask = None
-self._detect_nan_if_needed(logits_output)
+if self.enable_nan_detection:
+detect_nan(logits_output)
spec_info.hidden_states = logits_output.hidden_states
res: EagleVerifyOutput = spec_info.verify(
batch,
@@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker):
)
forward_batch.return_logprob = False
logits_output, _ = self.draft_model_runner.forward(forward_batch)
-self._detect_nan_if_needed(logits_output)
+if self.enable_nan_detection:
+detect_nan(logits_output)
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
@@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker):
)
self.capture_for_decode(logits_output, forward_batch.spec_info)
-self._detect_nan_if_needed(logits_output)
+if self.enable_nan_detection:
+detect_nan(logits_output)
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
@@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker):
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
draft_input.hidden_states = logits_output.hidden_states
-def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
-if self.enable_nan_detection:
-logits = logits_output.next_token_logits
-if torch.any(torch.isnan(logits)):
-logger.error("Detected errors during sampling! NaN in the logits.")
-raise ValueError("Detected errors during sampling! NaN in the logits.")
-def load_token_map(token_map_path: str) -> List[int]:
-if not os.path.exists(token_map_path):
-cache_dir = snapshot_download(
-os.path.dirname(token_map_path),
-ignore_patterns=["*.bin", "*.safetensors"],
-)
-token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
-hot_token_id = torch.load(token_map_path, weights_only=True)
-return torch.tensor(hot_token_id, dtype=torch.int64)
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
...
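For context on the relocated `load_token_map`: it returns the draft model's hot-token ids as an int64 tensor, fetching the file from the Hugging Face Hub when the local path does not exist. A hedged usage sketch (the path is a placeholder, not a real artifact):

```python
# Placeholder path for illustration only.
hot_token_id = load_token_map("some-org/some-eagle-head/token_map.pt")
# The worker keeps this tensor as `self.hot_token_id` and maps top-k indices from the
# reduced draft vocabulary back to full-vocabulary ids, e.g. hot_token_id[topk_index].
```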
from __future__ import annotations
import logging
+import os
import time
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, List
import torch
import triton
import triton.language as tl
+from huggingface_hub import snapshot_download
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
+from sglang.srt.distributed.parallel_state import (
+GroupCoordinator,
+patch_tensor_parallel_group,
+)
from sglang.srt.environ import envs
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.utils import is_cuda, is_hip
+if TYPE_CHECKING:
+from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.speculative.eagle_info import EagleVerifyInput
if is_cuda():
from sgl_kernel import fast_topk
elif is_hip():
from sgl_kernel import fast_topk
-if TYPE_CHECKING:
-from sglang.srt.speculative.eagle_info import EagleVerifyInput
logger = logging.getLogger(__name__)
@@ -603,3 +615,29 @@ def generate_token_bitmask(
verify_input.grammar = grammar
return allocate_token_bitmask
+def load_token_map(token_map_path: str) -> List[int]:
+if not os.path.exists(token_map_path):
+cache_dir = snapshot_download(
+os.path.dirname(token_map_path),
+ignore_patterns=["*.bin", "*.safetensors"],
+)
+token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+hot_token_id = torch.load(token_map_path, weights_only=True)
+return torch.tensor(hot_token_id, dtype=torch.int64)
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+# Draft model doesn't use dp and has its own tp group.
+# We disable mscclpp now because it doesn't support 2 comm groups.
+with patch_tensor_parallel_group(tp_group):
+yield
+def detect_nan(logits_output: LogitsProcessorOutput):
+logits = logits_output.next_token_logits
+if torch.any(torch.isnan(logits)):
+logger.error("Detected errors during sampling! NaN in the logits.")
+raise ValueError("Detected errors during sampling! NaN in the logits.")
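Typical call patterns for the two helpers that moved into this module, shown as a hedged sketch (`draft_runner`, `forward_batch`, and `enable_nan_detection` are placeholders standing in for the worker's own objects and flags):

```python
# Run the draft model under its own tensor-parallel group.
with draft_tp_context(draft_runner.tp_group):
    logits_output, _ = draft_runner.forward(forward_batch)

# Optional NaN guard, as used by EAGLEWorker after each draft/verify forward.
if enable_nan_detection:
    detect_nan(logits_output)  # raises ValueError if next-token logits contain NaNs
```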
import logging
-from contextlib import contextmanager
from typing import Optional
import torch
-from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.eagle_worker import EAGLEWorker
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
if is_cuda():
@@ -18,14 +17,6 @@ logger = logging.getLogger(__name__)
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
-@contextmanager
-def draft_tp_context(tp_group: GroupCoordinator):
-# Draft model doesn't use dp and has its own tp group.
-# We disable mscclpp now because it doesn't support 2 comm groups.
-with patch_tensor_parallel_group(tp_group):
-yield
class StandaloneWorker(EAGLEWorker):
def __init__(
@@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker):
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
-self.padded_static_len = -1
# Override the context length of the draft model to be the same as the target model.
server_args.context_length = target_worker.model_runner.model_config.context_len
...