Unverified Commit b0524c37 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Eagle speculative decoding part 2: Fix cuda graph + DP attention hanging (#2684)


Co-authored-by: default avataryukavio <kavioyu@gmail.com>
parent 6c42fa22
......@@ -92,7 +92,7 @@ jobs:
python3 test_data_parallelism.py
- name: Evaluate MLA accuracy (TP=2)
timeout-minutes: 20
timeout-minutes: 10
run: |
cd test/srt
python3 test_mla.py
......
......@@ -146,7 +146,10 @@ class LogitsProcessor(nn.Module):
# Compute logits
last_logits = self._get_logits(last_hidden, lm_head)
if not logits_metadata.extend_return_logprob:
if (
not logits_metadata.extend_return_logprob
or logits_metadata.capture_hidden_mode.need_capture()
):
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=last_logits,
......
from __future__ import annotations
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -29,7 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
import dataclasses
import logging
from typing import List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
import numpy as np
import torch
......@@ -47,6 +49,10 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
......@@ -565,9 +571,13 @@ class ScheduleBatch:
# Has grammar
has_grammar: bool = False
# device
# Device
device: str = "cuda"
# Speculative decoding
spec_info: Optional[SpecInfo] = None
spec_algorithm: Optional[SpeculativeAlgorithm] = None
@classmethod
def init_new(
cls,
......@@ -577,6 +587,7 @@ class ScheduleBatch:
tree_cache: BasePrefixCache,
model_config: ModelConfig,
enable_overlap: bool,
speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
):
return cls(
reqs=reqs,
......@@ -589,6 +600,7 @@ class ScheduleBatch:
has_stream=any(req.stream for req in reqs),
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
spec_algorithm=speculative_algorithm,
)
def batch_size(self):
......@@ -1103,6 +1115,9 @@ class ScheduleBatch:
self.has_stream |= other.has_stream
self.has_grammar |= other.has_grammar
if self.spec_info:
self.spec_info.merge_batch(other.spec_info)
def get_model_worker_batch(self):
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
......@@ -1144,6 +1159,8 @@ class ScheduleBatch:
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
input_embeds=self.input_embeds,
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
)
def copy(self):
......@@ -1214,6 +1231,10 @@ class ModelWorkerBatch:
# The input Embeds
input_embeds: Optional[torch.tensor] = None
# Speculative decoding
spec_info: Optional[SpecInfo] = None
spec_algorithm: Optional[SpeculativeAlgorithm] = None
@triton.jit
def write_req_to_token_pool_triton(
......
......@@ -150,12 +150,18 @@ class TpModelWorker:
self,
model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None,
skip_sample: bool = False,
):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
if launch_done:
launch_done.set()
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
if skip_sample:
next_token_ids = None
else:
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
return logits_output, next_token_ids
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
......
......@@ -375,9 +375,7 @@ class CudaGraphRunner:
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size
# In normal decoding case, raw_bs == raw_num_token
# But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
raw_num_token = forward_batch.input_ids.numel()
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
if self.enable_dp_attention:
......
......@@ -96,7 +96,11 @@ class ForwardMode(IntEnum):
return self == ForwardMode.DRAFT_EXTEND
def is_cuda_graph(self):
return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY
return (
self == ForwardMode.DECODE
or self == ForwardMode.TARGET_VERIFY
or self == ForwardMode.IDLE
)
def is_dummy_first(self):
return self == ForwardMode.DUMMY_FIRST
......@@ -161,15 +165,15 @@ class ForwardBatch:
token_to_kv_pool: BaseTokenToKVPool = None
attn_backend: AttentionBackend = None
# Speculative decoding
spec_info: SpecInfo = None
spec_algorithm: SpeculativeAlgorithm = None
# For DP attention
global_num_tokens: Optional[List[int]] = None
gathered_buffer: Optional[torch.Tensor] = None
can_run_dp_cuda_graph: bool = False
# Speculative decoding
spec_info: SpecInfo = None
spec_algorithm: SpeculativeAlgorithm = None
# For Qwen2-VL
mrope_positions: torch.Tensor = None
......@@ -258,6 +262,8 @@ class ForwardBatch:
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
spec_algorithm=batch.spec_algorithm,
spec_info=batch.spec_info,
input_embeds=batch.input_embeds,
)
......
......@@ -108,14 +108,6 @@ class ServerArgs:
# Model override args in JSON
json_model_override_args: str = "{}"
# Double Sparsity
enable_double_sparsity: bool = False
ds_channel_config_path: str = None
ds_heavy_channel_num: int = 32
ds_heavy_token_num: int = 256
ds_heavy_channel_type: str = "qk"
ds_sparse_decode_threshold: int = 4096
# LoRA
lora_paths: Optional[List[str]] = None
max_loras_per_batch: int = 8
......@@ -125,6 +117,21 @@ class ServerArgs:
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = "outlines"
# Speculative decoding
speculative_draft_model_path: Optional[str] = None
speculative_algorithm: Optional[str] = None
speculative_num_steps: int = 5
speculative_num_draft_tokens: int = 64
speculative_eagle_topk: int = 8
# Double Sparsity
enable_double_sparsity: bool = False
ds_channel_config_path: str = None
ds_heavy_channel_num: int = 32
ds_heavy_token_num: int = 256
ds_heavy_channel_type: str = "qk"
ds_sparse_decode_threshold: int = 4096
# Optimization/debug options
disable_radix_cache: bool = False
disable_jump_forward: bool = False
......@@ -602,43 +609,6 @@ class ServerArgs:
default=ServerArgs.json_model_override_args,
)
# Double Sparsity
parser.add_argument(
"--enable-double-sparsity",
action="store_true",
help="Enable double sparsity attention",
)
parser.add_argument(
"--ds-channel-config-path",
type=str,
default=ServerArgs.ds_channel_config_path,
help="The path of the double sparsity channel config",
)
parser.add_argument(
"--ds-heavy-channel-num",
type=int,
default=ServerArgs.ds_heavy_channel_num,
help="The number of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-token-num",
type=int,
default=ServerArgs.ds_heavy_token_num,
help="The number of heavy tokens in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-channel-type",
type=str,
default=ServerArgs.ds_heavy_channel_type,
help="The type of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-sparse-decode-threshold",
type=int,
default=ServerArgs.ds_sparse_decode_threshold,
help="The type of heavy channels in double sparsity attention",
)
# LoRA
parser.add_argument(
"--lora-paths",
......@@ -678,6 +648,75 @@ class ServerArgs:
help="Choose the backend for grammar-guided decoding.",
)
# Speculative decoding
parser.add_argument(
"--speculative-algorithm",
type=str,
choices=["EAGLE"],
help="Speculative algorithm.",
)
parser.add_argument(
"--speculative-draft-model-path",
type=str,
help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
"--speculative-num-steps",
type=int,
help="The number of steps sampled from draft model in Speculative Decoding.",
default=ServerArgs.speculative_num_steps,
)
parser.add_argument(
"--speculative-num-draft-tokens",
type=int,
help="The number of token sampled from draft model in Speculative Decoding.",
default=ServerArgs.speculative_num_draft_tokens,
)
parser.add_argument(
"--speculative-eagle-topk",
type=int,
help="The number of token sampled from draft model in eagle2 each step.",
choices=[1, 2, 4, 8],
default=ServerArgs.speculative_eagle_topk,
)
# Double Sparsity
parser.add_argument(
"--enable-double-sparsity",
action="store_true",
help="Enable double sparsity attention",
)
parser.add_argument(
"--ds-channel-config-path",
type=str,
default=ServerArgs.ds_channel_config_path,
help="The path of the double sparsity channel config",
)
parser.add_argument(
"--ds-heavy-channel-num",
type=int,
default=ServerArgs.ds_heavy_channel_num,
help="The number of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-token-num",
type=int,
default=ServerArgs.ds_heavy_token_num,
help="The number of heavy tokens in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-channel-type",
type=str,
default=ServerArgs.ds_heavy_channel_type,
help="The type of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-sparse-decode-threshold",
type=int,
default=ServerArgs.ds_sparse_decode_threshold,
help="The type of heavy channels in double sparsity attention",
)
# Optimization/debug options
parser.add_argument(
"--disable-radix-cache",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment