Unverified Commit ad20b795 authored by Lianmin Zheng, committed by GitHub

Eagle speculative decoding part 3: small modifications to the general scheduler (#2709)


Co-authored-by: kavioyu <kavioyu@tencent.com>
parent 9183c23e
@@ -61,10 +61,10 @@ For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `
 ```bash
 # node 1
-python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
+python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
 # node 2
-python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
+python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
 ```
 If you have two H100 nodes, the usage is similar to the aforementioned H20.
...
@@ -63,6 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
@@ -214,6 +215,7 @@ def extend(reqs, model_runner):
         tree_cache=None,
         model_config=model_runner.model_config,
         enable_overlap=False,
+        spec_algorithm=SpeculativeAlgorithm.NONE,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
...
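The practical effect for callers outside the scheduler, such as this benchmark script, is that the speculative algorithm is no longer an optional keyword: it must be passed explicitly, and plain decoding passes `SpeculativeAlgorithm.NONE`. A minimal sketch of the old versus new calling convention, with stub signatures (not the real `ScheduleBatch.init_new`):

```python
from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):  # stub mirroring the new enum
    NONE = auto()
    EAGLE = auto()

# Before: the algorithm was an optional keyword that defaulted to None.
def init_new_before(reqs, speculative_algorithm=None): ...

# After: every caller states the algorithm explicitly; plain decoding
# (as in this benchmark) passes SpeculativeAlgorithm.NONE.
def init_new_after(reqs, spec_algorithm: SpeculativeAlgorithm): ...

init_new_after(reqs=[], spec_algorithm=SpeculativeAlgorithm.NONE)
```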
@@ -26,7 +26,7 @@ class AttentionBackend(ABC):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
-        num_token: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
...
@@ -227,7 +227,7 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
-        num_token: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
@@ -243,9 +243,11 @@ class FlashInferAttnBackend(AttentionBackend):
                     "NHD",
                     use_cuda_graph=True,
                     use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
+                    paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
                     paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
+                    paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                        :num_tokens
+                    ],
                 )
             )
         seq_lens_sum = seq_lens.sum().item()
...
@@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
-        num_token: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
...
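The `num_token` → `num_tokens` rename matters mostly for the speculative path: one captured CUDA-graph batch can now cover several token slots per request, so the capture buffers are sliced by token count rather than by batch size. A standalone sketch of the slicing used by a FlashInfer-style backend, with illustrative buffer sizes:

```python
import torch

def slice_capture_buffers(bs: int, num_tokens_per_bs: int, max_num_tokens: int = 256):
    # num_tokens equals bs for plain decoding (one token per request) and is
    # larger for EAGLE draft / verify batches.
    num_tokens = bs * num_tokens_per_bs

    kv_indptr = torch.zeros(max_num_tokens + 1, dtype=torch.int32)
    kv_last_page_len = torch.ones(max_num_tokens, dtype=torch.int32)

    # Mirrors the FlashInfer backend: indptr needs num_tokens + 1 entries,
    # last_page_len needs one entry per token slot in the graph.
    return kv_indptr[: num_tokens + 1], kv_last_page_len[:num_tokens]

indptr, last_page_len = slice_capture_buffers(bs=4, num_tokens_per_bs=8)
print(indptr.shape, last_page_len.shape)  # torch.Size([33]) torch.Size([32])
```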
@@ -575,8 +575,8 @@ class ScheduleBatch:
     device: str = "cuda"

     # Speculative decoding
+    spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
-    spec_algorithm: Optional[SpeculativeAlgorithm] = None

     @classmethod
     def init_new(
@@ -587,7 +587,7 @@ class ScheduleBatch:
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
-        speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
+        spec_algorithm: SpeculativeAlgorithm,
     ):
         return cls(
             reqs=reqs,
@@ -600,7 +600,7 @@ class ScheduleBatch:
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
-            spec_algorithm=speculative_algorithm,
+            spec_algorithm=spec_algorithm,
         )

     def batch_size(self):
@@ -1010,6 +1010,8 @@ class ScheduleBatch:
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE

+        if self.spec_algorithm.is_eagle():
+            return

         self.input_ids = self.output_ids
         self.output_ids = None
@@ -1172,6 +1174,7 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
+            spec_algorithm=self.spec_algorithm,
         )

     def __str__(self):
@@ -1232,8 +1235,8 @@ class ModelWorkerBatch:
     input_embeds: Optional[torch.tensor] = None

     # Speculative decoding
+    spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
-    spec_algorithm: Optional[SpeculativeAlgorithm] = None

 @triton.jit
...
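A minimal, self-contained sketch of what the new `spec_algorithm` field changes in `prepare_for_decode`: under EAGLE the method returns early, because the speculative worker manages input/output ids itself. The class below is a stand-in, not the real `ScheduleBatch`:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MiniBatch:
    # Stand-in for ScheduleBatch; is_eagle mimics spec_algorithm.is_eagle().
    is_eagle: bool
    input_ids: Optional[List[int]] = None
    output_ids: Optional[List[int]] = None

    def prepare_for_decode(self):
        if self.is_eagle:
            # The EAGLE worker feeds its own input ids, so the generic
            # decode preparation is skipped entirely.
            return
        # Plain decoding: the last sampled tokens become the next input.
        self.input_ids = self.output_ids
        self.output_ids = None

b = MiniBatch(is_eagle=True, input_ids=[1], output_ids=[2])
b.prepare_for_decode()
print(b.input_ids, b.output_ids)  # [1] [2] -- untouched under EAGLE
```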
@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -116,6 +117,14 @@ class Scheduler:
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.decode_mem_cache_buf_multiplier = (
+            self.server_args.speculative_num_draft_tokens
+            if not self.spec_algorithm.is_none()
+            else 1
+        )

         # Init inter-process communication
         context = zmq.Context(2)
@@ -199,6 +208,21 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )

+        # Launch worker for speculative decoding if need
+        if self.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+            self.draft_worker = EAGLEWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        else:
+            self.draft_worker = None

         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -855,6 +879,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
+            self.spec_algorithm,
         )
         new_batch.prepare_for_extend()
@@ -888,11 +913,15 @@ class Scheduler:
             return None

         # Check if decode out of memory
-        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
+        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
+            test_retract and batch.batch_size() > 10
+        ):
             old_ratio = self.new_token_ratio

             retracted_reqs, new_token_ratio = batch.retract_decode()
             self.new_token_ratio = new_token_ratio
+            if self.draft_worker:
+                self.draft_worker.finish_request(retracted_reqs)

             logger.info(
                 "Decode out of memory happened. "
@@ -926,11 +955,17 @@ class Scheduler:
         self.forward_ct += 1

         if self.is_generation:
-            model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    model_worker_batch
-                )
+                if self.spec_algorithm.is_none():
+                    model_worker_batch = batch.get_model_worker_batch()
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    logits_output, next_token_ids, model_worker_batch, spec_info = (
+                        self.draft_worker.forward_batch_speculative_generation(batch)
+                    )
+                    batch.spec_info = spec_info
             elif batch.forward_mode.is_idle():
                 model_worker_batch = batch.get_model_worker_batch()
                 self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -1077,7 +1112,10 @@ class Scheduler:
                     self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                     continue

-                req.output_ids.append(next_token_id)
+                if batch.spec_algorithm.is_none():
+                    # speculative worker will solve the output_ids in speculative decoding
+                    req.output_ids.append(next_token_id)
+
                 req.check_finished()

                 if req.finished():
@@ -1252,6 +1290,9 @@ class Scheduler:
                 # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                 or (not req.stream and len(req.output_ids) % 50 == 0)
             ):
+                if self.draft_worker and req.finished():
+                    self.draft_worker.finish_request(req)
+
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
@@ -1383,6 +1424,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
+            self.spec_algorithm,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
...
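A standalone sketch of the dispatch added to the scheduler's batch-execution path: with no speculative algorithm the target TP worker runs exactly as before, otherwise the draft worker runs a speculative step and also returns the model-worker batch it built plus a `spec_info` object that is stored back on the batch. The worker classes and return values below are stubs:

```python
class StubTargetWorker:
    def forward_batch_generation(self, model_worker_batch):
        return "logits", [42]  # placeholder logits / next-token ids

class StubDraftWorker:
    def forward_batch_speculative_generation(self, batch):
        # The draft worker additionally returns the model-worker batch it
        # built and a spec_info object describing the speculative state.
        return "logits", [42, 43, 44], "model_worker_batch", {"draft_state": None}

def run_one_batch(batch, spec_is_none, tp_worker, draft_worker):
    if spec_is_none:
        model_worker_batch = batch["get_model_worker_batch"]()
        logits_output, next_token_ids = tp_worker.forward_batch_generation(
            model_worker_batch
        )
    else:
        logits_output, next_token_ids, model_worker_batch, spec_info = (
            draft_worker.forward_batch_speculative_generation(batch)
        )
        batch["spec_info"] = spec_info
    return logits_output, next_token_ids

batch = {"get_model_worker_batch": lambda: "mwb", "spec_info": None}
print(run_one_batch(batch, True, StubTargetWorker(), StubDraftWorker()))
print(run_one_batch(batch, False, StubTargetWorker(), StubDraftWorker()))
```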
@@ -45,13 +45,18 @@ class TpModelWorker:
         tp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.tp_rank = tp_rank

         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path,
+            (
+                server_args.model_path
+                if not is_draft_worker
+                else server_args.speculative_draft_model_path
+            ),
             trust_remote_code=server_args.trust_remote_code,
             revision=server_args.revision,
             context_length=server_args.context_length,
@@ -68,6 +73,7 @@ class TpModelWorker:
             tp_size=server_args.tp_size,
             nccl_port=nccl_port,
             server_args=server_args,
+            is_draft_worker=is_draft_worker,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
...
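The `is_draft_worker` flag mainly decides which checkpoint the worker's `ModelConfig` points at: the draft worker loads `speculative_draft_model_path`, everything else is shared. A small sketch of that selection; the example model paths are illustrative only:

```python
from dataclasses import dataclass

@dataclass
class Args:
    model_path: str
    speculative_draft_model_path: str

def pick_model_path(server_args: Args, is_draft_worker: bool) -> str:
    # Mirrors the conditional passed to ModelConfig in TpModelWorker.__init__.
    return (
        server_args.model_path
        if not is_draft_worker
        else server_args.speculative_draft_model_path
    )

# Illustrative target/draft checkpoint pair.
args = Args("meta-llama/Llama-2-7b-chat-hf", "yuhuili/EAGLE-llama2-chat-7B")
print(pick_model_path(args, is_draft_worker=False))
print(pick_model_path(args, is_draft_worker=True))
```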
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
+from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -106,11 +106,6 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024

-@maybe_torch_compile(dynamic=True)
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
@@ -157,6 +152,17 @@ class CudaGraphRunner:
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
+        if model_runner.spec_algorithm.is_eagle():
+            if self.model_runner.is_draft_worker:
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_eagle_topk
+                )
+            else:
+                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_num_draft_tokens
+                )
+
         self.compile_bs = (
             [
                 bs
@@ -192,6 +198,13 @@ class CudaGraphRunner:
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)

+        # Speculative_inference
+        if model_runner.spec_algorithm.is_eagle():
+            self.hidden_states = torch.zeros(
+                (self.max_num_token, self.model_runner.model_config.hidden_size),
+                dtype=self.model_runner.dtype,
+            )
+
         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
             self.encoder_lens = torch.full(
@@ -234,9 +247,6 @@ class CudaGraphRunner:
             self.model_runner.model.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
-        if not forward_batch.forward_mode.is_cuda_graph():
-            return False
-
         if self.enable_dp_attention:
             min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
                 forward_batch.global_num_tokens
@@ -291,21 +301,18 @@ class CudaGraphRunner:
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
-        num_token = bs * self.num_tokens_per_bs
+        num_tokens = bs * self.num_tokens_per_bs

         # Common inputs
-        input_ids = self.input_ids[:num_token]
+        input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        out_cache_loc = self.out_cache_loc[:num_token]
-        positions = self.positions[:num_token]
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-        seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]

         if self.enable_dp_attention:
@@ -325,20 +332,22 @@ class CudaGraphRunner:
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens_sum,
+            seq_lens_sum=seq_lens.sum(),
             encoder_lens=encoder_lens,
             return_logprob=False,
-            top_logprobs_nums=[0] * num_token,
+            top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
             mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
+            spec_algorithm=self.model_runner.spec_algorithm,
+            spec_info=self.get_spec_info(num_tokens, positions),
         )

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
-            num_token,
+            num_tokens,
             req_pool_indices,
             seq_lens,
             encoder_lens,
@@ -394,14 +403,16 @@ class CudaGraphRunner:
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
-        positions = clamp_position(forward_batch.seq_lens)
-        self.positions[:raw_num_token].copy_(positions)
+        self.positions[:raw_num_token].copy_(forward_batch.positions)

         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if hasattr(forward_batch.spec_info, "hidden_states"):
+            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
@@ -424,3 +435,36 @@ class CudaGraphRunner:
             ),
         )
         return logits_output
+
+    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
+        spec_info = None
+        if self.model_runner.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_utils import (
+                EAGLEDraftInput,
+                EagleVerifyInput,
+            )
+
+            if self.model_runner.is_draft_worker:
+                spec_info = EAGLEDraftInput()
+                spec_info.hidden_states = self.hidden_states[:num_tokens]
+                spec_info.positions = positions
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                spec_info.init(self.model_runner.server_args)
+            else:
+                spec_info = EagleVerifyInput(
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.model_runner.server_args.speculative_num_draft_tokens,
+                )
+                spec_info.custom_mask = torch.zeros(
+                    (num_tokens * self.model_runner.model_config.context_len),
+                    dtype=torch.bool,
+                    device="cuda",
+                )
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+
+        return spec_info
...
@@ -38,6 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.utils import maybe_torch_compile

 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -276,10 +277,21 @@ class ForwardBatch:
         )

         if ret.forward_mode.is_idle():
+            ret.positions = torch.empty((0,), device=device)
             return ret

+        # Override the positions with spec_info
+        if (
+            ret.spec_info is not None
+            and getattr(ret.spec_info, "positions", None) is not None
+        ):
+            ret.positions = ret.spec_info.positions
+
         # Init position information
-        if not ret.forward_mode.is_decode():
+        if ret.forward_mode.is_decode():
+            if ret.positions is None:
+                ret.positions = clamp_position(batch.seq_lens)
+        else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
@@ -288,13 +300,15 @@ class ForwardBatch:
             ).to(device, non_blocking=True)
             if model_runner.server_args.attention_backend != "torch_native":
                 ret.extend_num_tokens = batch.extend_num_tokens
-                ret.positions, ret.extend_start_loc = compute_position_triton(
+                positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
                 )
             else:
-                ret.positions, ret.extend_start_loc = compute_position_torch(
+                positions, ret.extend_start_loc = compute_position_torch(
                     ret.extend_prefix_lens, ret.extend_seq_lens
                 )
+            if ret.positions is None:
+                ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
@@ -383,6 +397,11 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc

+@maybe_torch_compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
     FULL = auto()
...
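Position handling is now resolved once in `ForwardBatch.init_new`: positions supplied by `spec_info` take precedence, and plain decode falls back to `clamp_position(seq_lens)`, i.e. `seq_len - 1` clamped at zero. A runnable sketch of that fallback, without the `maybe_torch_compile` decorator:

```python
import torch

def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Decode position of the next token for each request: seq_len - 1,
    # clamped so that empty sequences map to position 0.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)

def resolve_positions(spec_positions, seq_lens):
    # spec_info-provided positions win; otherwise derive them from seq_lens.
    if spec_positions is not None:
        return spec_positions
    return clamp_position(seq_lens)

seq_lens = torch.tensor([5, 1, 0])
print(resolve_positions(None, seq_lens))                     # tensor([4, 0, 0])
print(resolve_positions(torch.tensor([7, 8, 9]), seq_lens))  # tensor([7, 8, 9])
```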
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
@@ -74,6 +75,7 @@ class ModelRunner:
         tp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.model_config = model_config
@@ -84,8 +86,12 @@ class ModelRunner:
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
+        self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )

         # Model-specific adjustment
         if (
@@ -205,14 +211,18 @@ class ModelRunner:
         else:
             dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
-        init_distributed_environment(
-            backend=backend,
-            world_size=self.tp_size,
-            rank=self.tp_rank,
-            local_rank=self.gpu_id,
-            distributed_init_method=dist_init_method,
-        )
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
+        if not self.is_draft_worker:
+            # Only initilzie the distributed environment on the target model worker.
+            init_distributed_environment(
+                backend=backend,
+                world_size=self.tp_size,
+                rank=self.tp_rank,
+                local_rank=self.gpu_id,
+                distributed_init_method=dist_init_method,
+            )
+            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
@@ -407,7 +417,6 @@ class ModelRunner:
         target_dtype = (
             dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
         )
-        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype

         assert (
             self._model_update_group is not None
@@ -506,6 +515,28 @@ class ModelRunner:
         )
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

+        if max_num_reqs is None:
+            max_num_reqs = min(
+                max(
+                    int(
+                        self.max_total_num_tokens / self.model_config.context_len * 512
+                    ),
+                    2048,
+                ),
+                4096,
+            )
+
+        if not self.spec_algorithm.is_none():
+            if self.is_draft_worker:
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+            else:
+                self.server_args.draft_runner_cache_size = (
+                    self.max_total_num_tokens
+                    + max_num_reqs * self.server_args.speculative_num_steps
+                    + 100
+                )
+
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
                 logging.warning(
@@ -520,17 +551,6 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

-        if max_num_reqs is None:
-            max_num_reqs = min(
-                max(
-                    int(
-                        self.max_total_num_tokens / self.model_config.context_len * 512
-                    ),
-                    2048,
-                ),
-                4096,
-            )
-
         self.req_to_token_pool = ReqToTokenPool(
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
@@ -650,10 +670,6 @@ class ModelRunner:
             tensor_parallel(self.model, device_mesh)

     def forward_decode(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
-        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
         self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
@@ -683,14 +699,18 @@ class ModelRunner:
         )

     def forward_idle(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+        if (
+            forward_batch.forward_mode.is_cuda_graph()
+            and self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(forward_batch)
+        ):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
...
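The target `ModelRunner` now sizes the draft runner's KV cache for it: `draft_runner_cache_size = max_total_num_tokens + max_num_reqs * speculative_num_steps + 100`, and the draft runner simply adopts that value as its own `max_total_num_tokens`. A small sketch of the arithmetic, with illustrative numbers:

```python
def draft_runner_cache_size(max_total_num_tokens: int,
                            max_num_reqs: int,
                            speculative_num_steps: int) -> int:
    # Extra room: every in-flight request may hold up to
    # speculative_num_steps draft tokens, plus a small safety margin.
    return max_total_num_tokens + max_num_reqs * speculative_num_steps + 100

# Illustrative numbers only.
print(draft_runner_cache_size(max_total_num_tokens=200_000,
                              max_num_reqs=2048,
                              speculative_num_steps=5))  # 210340
```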
@@ -23,6 +23,7 @@ from typing import List, Optional
 import torch

 from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -247,6 +248,17 @@ class ServerArgs:
                 "Overlap scheduler is disabled."
             )

+        # Speculative Decoding
+        if self.speculative_algorithm == "EAGLE":
+            self.prefill_only_one_req = True
+            self.disable_cuda_graph_padding = True
+            self.disable_radix_cache = True
+            self.disable_overlap_schedule = True
+            self.chunked_prefill_size = -1
+            logger.info(
+                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
+            )
+
         # GGUF
         if (
             self.load_format == "auto" or self.load_format == "gguf"
...
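A condensed sketch of the EAGLE-specific argument adjustments above: selecting EAGLE currently forces single-request prefill and turns off CUDA-graph padding, the radix cache, chunked prefill, and the overlap scheduler. The dataclass below is a stand-in for the real `ServerArgs`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniServerArgs:
    speculative_algorithm: Optional[str] = None
    prefill_only_one_req: bool = False
    disable_cuda_graph_padding: bool = False
    disable_radix_cache: bool = False
    disable_overlap_schedule: bool = False
    chunked_prefill_size: int = 8192  # illustrative default

    def adjust_for_speculative_decoding(self):
        # Mirrors the post-init branch added for EAGLE.
        if self.speculative_algorithm == "EAGLE":
            self.prefill_only_one_req = True
            self.disable_cuda_graph_padding = True
            self.disable_radix_cache = True
            self.disable_overlap_schedule = True
            self.chunked_prefill_size = -1  # -1 disables chunked prefill

args = MiniServerArgs(speculative_algorithm="EAGLE")
args.adjust_for_speculative_decoding()
print(args.disable_radix_cache, args.chunked_prefill_size)  # True -1
```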
@@ -2,8 +2,12 @@ from enum import IntEnum, auto

 class SpeculativeAlgorithm(IntEnum):
+    NONE = auto()
     EAGLE = auto()

+    def is_none(self):
+        return self == SpeculativeAlgorithm.NONE
+
     def is_eagle(self):
         return self == SpeculativeAlgorithm.EAGLE
@@ -11,6 +15,7 @@ class SpeculativeAlgorithm(IntEnum):
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
+            None: SpeculativeAlgorithm.NONE,
         }
         return name_map[name]
...
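With the explicit `NONE` member, `from_string(None)` (the default when no speculative algorithm is configured) now returns a real enum value, so call sites can use `is_none()` / `is_eagle()` without null checks. A self-contained sketch of the enum after this change; the `@staticmethod` decorator on `from_string` is an assumption, since the hunk context does not show it:

```python
from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()

    def is_none(self):
        return self == SpeculativeAlgorithm.NONE

    def is_eagle(self):
        return self == SpeculativeAlgorithm.EAGLE

    @staticmethod  # assumed; not visible in the diff context
    def from_string(name):
        name_map = {
            "EAGLE": SpeculativeAlgorithm.EAGLE,
            None: SpeculativeAlgorithm.NONE,
        }
        return name_map[name]

print(SpeculativeAlgorithm.from_string(None).is_none())      # True
print(SpeculativeAlgorithm.from_string("EAGLE").is_eagle())  # True
```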