Unverified commit b0524c37, authored by Lianmin Zheng, committed by GitHub

Eagle speculative decoding part 2: Fix cuda graph + DP attention hanging (#2684)


Co-authored-by: yukavio <kavioyu@gmail.com>
parent 6c42fa22
@@ -92,7 +92,7 @@ jobs:
           python3 test_data_parallelism.py

       - name: Evaluate MLA accuracy (TP=2)
-        timeout-minutes: 20
+        timeout-minutes: 10
         run: |
           cd test/srt
           python3 test_mla.py
...
@@ -146,7 +146,10 @@ class LogitsProcessor(nn.Module):
         # Compute logits
         last_logits = self._get_logits(last_hidden, lm_head)
-        if not logits_metadata.extend_return_logprob:
+        if (
+            not logits_metadata.extend_return_logprob
+            or logits_metadata.capture_hidden_mode.need_capture()
+        ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
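The new `or` branch also routes EAGLE's hidden-state-capturing passes through the fast path, since those passes never need extend logprobs. Below is a minimal, self-contained sketch of that gate, with a toy `CaptureHiddenMode` standing in for sglang's real enum (the member names here are assumptions):

```python
from enum import IntEnum, auto


class CaptureHiddenMode(IntEnum):
    # Toy stand-in for sglang's enum; member names are assumed.
    NULL = auto()  # no hidden states requested
    LAST = auto()  # capture the last token's hidden state
    FULL = auto()  # capture hidden states for all tokens

    def need_capture(self) -> bool:
        return self != CaptureHiddenMode.NULL


def take_fast_path(extend_return_logprob: bool, mode: CaptureHiddenMode) -> bool:
    # Mirrors the rewritten condition in LogitsProcessor above.
    return not extend_return_logprob or mode.need_capture()


assert take_fast_path(False, CaptureHiddenMode.NULL)     # plain decode/extend
assert take_fast_path(True, CaptureHiddenMode.FULL)      # EAGLE capture pass
assert not take_fast_path(True, CaptureHiddenMode.NULL)  # extend + logprobs
```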
...
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,7 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 import dataclasses
 import logging
-from typing import List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
@@ -47,6 +49,10 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

+if TYPE_CHECKING:
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
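The `TYPE_CHECKING` guard exists because the speculative module itself imports scheduler types; importing it eagerly here would create an import cycle. Combined with `from __future__ import annotations` (added at the top of the file), the names stay usable in annotations while the import never runs at runtime. A generic sketch of the pattern, using hypothetical modules:

```python
# batch.py -- hypothetical module illustrating the pattern used above
from __future__ import annotations  # annotations are evaluated lazily as strings

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime, so no circular
    # import even if spec.py imports batch.py back.
    from spec import SpecInfo


class Batch:
    spec_info: Optional[SpecInfo] = None  # annotation never evaluated at runtime
```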
@@ -565,9 +571,13 @@ class ScheduleBatch:
     # Has grammar
     has_grammar: bool = False

-    # device
+    # Device
     device: str = "cuda"

+    # Speculative decoding
+    spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None
+
     @classmethod
     def init_new(
         cls,
@@ -577,6 +587,7 @@ class ScheduleBatch:
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
+        speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
     ):
         return cls(
             reqs=reqs,
@@ -589,6 +600,7 @@ class ScheduleBatch:
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
+            spec_algorithm=speculative_algorithm,
         )

     def batch_size(self):
@@ -1103,6 +1115,9 @@ class ScheduleBatch:
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar

+        if self.spec_info:
+            self.spec_info.merge_batch(other.spec_info)
+
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
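When two running batches are merged, per-batch speculative state (e.g., draft-tree bookkeeping) must be concatenated in the same order as the requests, or the merged batch would point at the wrong draft entries. A toy illustration of that contract; `ToySpecInfo` is hypothetical, and the real `SpecInfo.merge_batch` lives in `sglang.srt.speculative`:

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class ToySpecInfo:
    # Hypothetical stand-in for SpecInfo: per-request draft state kept
    # alongside the batch, in request order.
    draft_scores: List[float]

    def merge_batch(self, other: "ToySpecInfo") -> None:
        # Concatenate in the same order the requests themselves are merged.
        self.draft_scores.extend(other.draft_scores)


a, b = ToySpecInfo([0.9, 0.7]), ToySpecInfo([0.8])
a.merge_batch(b)
assert a.draft_scores == [0.9, 0.7, 0.8]
```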
@@ -1144,6 +1159,8 @@ class ScheduleBatch:
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            spec_algorithm=self.spec_algorithm,
+            spec_info=self.spec_info,
         )

     def copy(self):
@@ -1214,6 +1231,10 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None

+    # Speculative decoding
+    spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None
+

 @triton.jit
 def write_req_to_token_pool_triton(
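These two fields now travel the whole `ScheduleBatch -> ModelWorkerBatch -> ForwardBatch` pipeline described in this module's docstring. A condensed sketch of the hand-off, trimmed to just the relevant fields (class names here are illustrative):

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class ScheduleBatchSketch:
    spec_info: Optional[Any] = None
    spec_algorithm: Optional[str] = None

    def get_model_worker_batch(self) -> "ModelWorkerBatchSketch":
        # Pass the speculative state through unchanged.
        return ModelWorkerBatchSketch(self.spec_info, self.spec_algorithm)


@dataclasses.dataclass
class ModelWorkerBatchSketch:
    spec_info: Optional[Any]
    spec_algorithm: Optional[str]


batch = ScheduleBatchSketch(spec_info={"draft": "state"}, spec_algorithm="EAGLE")
assert batch.get_model_worker_batch().spec_algorithm == "EAGLE"
```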
...
@@ -150,12 +150,18 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
+        skip_sample: bool = False,
     ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         if launch_done:
             launch_done.set()
-        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
+        if skip_sample:
+            next_token_ids = None
+        else:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
         return logits_output, next_token_ids

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
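The `skip_sample` flag lets speculative verification reuse the target model's forward pass without drawing fresh tokens: during an EAGLE verify step, the next tokens come from accepting or rejecting draft tokens against the logits, not from sampling. A toy sketch of the two call patterns (the forward pass and verifier are stubbed out):

```python
from typing import List, Optional, Tuple


def forward_batch_generation(
    batch: str, skip_sample: bool = False
) -> Tuple[str, Optional[List[str]]]:
    # Stub mirroring the shape of TpModelWorker.forward_batch_generation.
    logits_output = f"logits({batch})"  # stand-in for the real forward pass
    next_token_ids = None if skip_sample else ["<sampled>"]
    return logits_output, next_token_ids


# Normal decode step: sample as before.
_, tokens = forward_batch_generation("decode_batch")
assert tokens == ["<sampled>"]

# EAGLE verify step: keep the logits, let the verifier choose the tokens.
logits, tokens = forward_batch_generation("verify_batch", skip_sample=True)
assert tokens is None  # accepted tokens are derived from the draft tree instead
```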
...
@@ -375,9 +375,7 @@ class CudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
-        # In normal decoding case, raw_bs == raw_num_token
-        # But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
-        raw_num_token = forward_batch.input_ids.numel()
+        raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
         if self.enable_dp_attention:
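The replay size can no longer be read off the (possibly padded) `input_ids`: in a verify pass each request contributes `num_tokens_per_bs` draft tokens, so the token count is derived from the batch size directly. A worked example with assumed numbers:

```python
# Assumed numbers for illustration only.
raw_bs = 8               # requests in the batch
num_tokens_per_bs = 4    # 1 for normal decode; >1 when verifying draft tokens
raw_num_token = raw_bs * num_tokens_per_bs

capture_bs = [1, 2, 4, 8, 16]                              # captured graph sizes
padded_bs = next(bs for bs in capture_bs if bs >= raw_bs)  # -> 8, no padding here
print(raw_num_token, padded_bs)                            # 32 8
```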
...
@@ -96,7 +96,11 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND

     def is_cuda_graph(self):
-        return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY
+        return (
+            self == ForwardMode.DECODE
+            or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.IDLE
+        )

     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
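Treating `IDLE` as cuda-graph-eligible is the core of the DP-attention hang fix: with DP attention, every rank's forward contains collective communication, so a rank with no requests must still replay a (dummy) graph instead of skipping the step, or the other ranks block forever inside the collective. A toy simulation of the control flow (no real `torch.distributed` involved):

```python
# Each entry in rank_batches is the number of requests on that DP rank.
def step(rank_batches, idle_runs_graph: bool) -> bool:
    participating = []
    for rank, batch_size in enumerate(rank_batches):
        if batch_size > 0:
            participating.append(rank)  # DECODE replay -> enters the collective
        elif idle_runs_graph:
            participating.append(rank)  # IDLE replay -> enters the same collective
        # else: rank skips the forward entirely; the collective never completes
    return len(participating) == len(rank_batches)  # True iff no rank is missing


assert not step([3, 0], idle_runs_graph=False)  # rank 1 skips -> the others hang
assert step([3, 0], idle_runs_graph=True)       # IDLE keeps every rank in lockstep
```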
@@ -161,15 +165,15 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None

-    # Speculative decoding
-    spec_info: SpecInfo = None
-    spec_algorithm: SpeculativeAlgorithm = None
-
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False

+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -258,6 +262,8 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            spec_algorithm=batch.spec_algorithm,
+            spec_info=batch.spec_info,
             input_embeds=batch.input_embeds,
         )
...
@@ -108,14 +108,6 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"

-    # Double Sparsity
-    enable_double_sparsity: bool = False
-    ds_channel_config_path: str = None
-    ds_heavy_channel_num: int = 32
-    ds_heavy_token_num: int = 256
-    ds_heavy_channel_type: str = "qk"
-    ds_sparse_decode_threshold: int = 4096
-
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
@@ -125,6 +117,21 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = "outlines"

+    # Speculative decoding
+    speculative_draft_model_path: Optional[str] = None
+    speculative_algorithm: Optional[str] = None
+    speculative_num_steps: int = 5
+    speculative_num_draft_tokens: int = 64
+    speculative_eagle_topk: int = 8
+
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
@@ -602,43 +609,6 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )

-        # Double Sparsity
-        parser.add_argument(
-            "--enable-double-sparsity",
-            action="store_true",
-            help="Enable double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-channel-config-path",
-            type=str,
-            default=ServerArgs.ds_channel_config_path,
-            help="The path of the double sparsity channel config",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-num",
-            type=int,
-            default=ServerArgs.ds_heavy_channel_num,
-            help="The number of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-token-num",
-            type=int,
-            default=ServerArgs.ds_heavy_token_num,
-            help="The number of heavy tokens in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-type",
-            type=str,
-            default=ServerArgs.ds_heavy_channel_type,
-            help="The type of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-sparse-decode-threshold",
-            type=int,
-            default=ServerArgs.ds_sparse_decode_threshold,
-            help="The threshold for sparse decoding in double sparsity attention",
-        )
-
        # LoRA
         parser.add_argument(
             "--lora-paths",
@@ -678,6 +648,75 @@ class ServerArgs:
             help="Choose the backend for grammar-guided decoding.",
         )

+        # Speculative decoding
+        parser.add_argument(
+            "--speculative-algorithm",
+            type=str,
+            choices=["EAGLE"],
+            help="Speculative algorithm.",
+        )
+        parser.add_argument(
+            "--speculative-draft-model-path",
+            type=str,
+            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
+        )
+        parser.add_argument(
+            "--speculative-num-steps",
+            type=int,
+            help="The number of steps sampled from the draft model in speculative decoding.",
+            default=ServerArgs.speculative_num_steps,
+        )
+        parser.add_argument(
+            "--speculative-num-draft-tokens",
+            type=int,
+            help="The number of tokens sampled from the draft model in speculative decoding.",
+            default=ServerArgs.speculative_num_draft_tokens,
+        )
+        parser.add_argument(
+            "--speculative-eagle-topk",
+            type=int,
+            help="The number of tokens sampled from the draft model per step in EAGLE-2.",
+            choices=[1, 2, 4, 8],
+            default=ServerArgs.speculative_eagle_topk,
+        )
+
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The threshold for sparse decoding in double sparsity attention",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
...