"git@developer.sourcefind.cn:change/sglang.git" did not exist on "7282ab741a4d07dfb775cbba7fd442b68fddfeeb"
Unverified commit 3815b23c, authored by Lianmin Zheng, committed by GitHub

Clean up wrapper in flashinfer backend (#2638)

parent fd34f2da
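For orientation, the core of this cleanup replaces the bare tuples previously stored in FlashInferAttnBackend.forward_metadata with small dataclasses, so call sites read named fields instead of unpacking by position. A minimal sketch of the before/after pattern, simplified from the hunks below (the wrapper element types are elided here):

from dataclasses import dataclass
from typing import List

# Before: forward_metadata was a positional tuple such as
#   self.forward_metadata = (self.decode_wrappers,)          # decode
#   self.forward_metadata = (use_ragged, extend_no_prefix)   # prefill
# and callers unpacked it by position.

# After: named containers, as introduced in this diff.
@dataclass
class DecodeMetadata:
    decode_wrappers: List  # BatchDecodeWithPagedKVCacheWrapper instances

@dataclass
class PrefillMetadata:
    prefill_wrappers: List  # BatchPrefillWithPagedKVCacheWrapper instances
    use_ragged: bool
    extend_no_prefix: bool

# Callers now read fields by name, e.g.
#   self.forward_metadata = DecodeMetadata(self.decode_wrappers)
#   if self.forward_metadata.use_ragged: ...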
@@ -331,6 +331,7 @@ def throughput_test(
         extra_request_body=extra_request_body,
         profile=bench_args.profile,
     )
+    backend.shutdown()

     if bench_args.result_filename:
         with open(bench_args.result_filename, "a") as fout:
...
@@ -131,10 +131,8 @@ class ModelConfig:
         # Verify quantization
         self._verify_quantization()

-        # Text attrs
+        # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
-        # Multimodel attrs
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)

         # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
...
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
 from typing import Optional

 import torch
-from torch import nn

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
...
@@ -8,8 +8,9 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 """

 import os
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Union

 import torch
 import triton
@@ -38,12 +39,25 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()


+@dataclass
+class DecodeMetadata:
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+
+
+@dataclass
+class PrefillMetadata:
+    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+    use_ragged: bool
+    extend_no_prefix: bool
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""

     def __init__(self, model_runner: ModelRunner):
         super().__init__()

+        # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
@@ -52,7 +66,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 model_runner.tp_size
             ),
         )
-
         self.max_context_len = model_runner.model_config.context_len

         assert not (
@@ -120,8 +133,8 @@ class FlashInferAttnBackend(AttentionBackend):
             )

         # Other metadata
-        self.forward_metadata = None
-        self.cuda_graph_metadata = {}
+        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+        self.decode_cuda_graph_metadata = {}

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
@@ -129,10 +142,10 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                decode_wrappers=None,
+                decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
             )
-            self.forward_metadata = (self.decode_wrappers,)
+            self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
@@ -149,11 +162,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
+                prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
             )
-            self.forward_metadata = (use_ragged, extend_no_prefix)
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+            )

     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(
@@ -194,8 +209,8 @@ class FlashInferAttnBackend(AttentionBackend):
             decode_wrappers=decode_wrappers,
             encoder_lens=encoder_lens,
         )
-        self.cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = (decode_wrappers,)
+        self.decode_cuda_graph_metadata[bs] = decode_wrappers
+        self.forward_metadata = DecodeMetadata(decode_wrappers)

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -209,7 +224,7 @@ class FlashInferAttnBackend(AttentionBackend):
             req_pool_indices[:bs],
             seq_lens[:bs],
             seq_lens_sum,
-            decode_wrappers=self.cuda_graph_metadata[bs],
+            decode_wrappers=self.decode_cuda_graph_metadata[bs],
             encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )
@@ -225,18 +240,16 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.prefill_wrappers_paged[
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
-        use_ragged, extend_no_prefix = self.forward_metadata
-
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )

-        if not use_ragged:
+        if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
@@ -260,7 +273,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 logits_soft_cap=layer.logit_cap,
             )

-            if extend_no_prefix:
+            if self.forward_metadata.extend_no_prefix:
                 o = o1
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -287,7 +300,9 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+        decode_wrapper = self.forward_metadata.decode_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
@@ -322,7 +337,7 @@

 class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -340,9 +355,8 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.decode_wrappers = attn_backend.decode_wrappers

-        # Dispatch
+        # Dispatch the update function
        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            self.update = self.update_sliding_window
        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -356,7 +370,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -367,7 +381,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
@@ -385,11 +399,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
@@ -419,11 +431,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Normal attention
@@ -446,7 +456,7 @@ class FlashInferIndicesUpdaterDecode:

     def call_begin_forward(
         self,
-        wrapper,
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -486,7 +496,7 @@

 class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -505,10 +515,9 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
-        self.wrappers_paged = attn_backend.prefill_wrappers_paged
+        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

-        # Dispatch
+        # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
         elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -523,6 +532,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -535,6 +545,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -546,8 +557,8 @@ class FlashInferIndicesUpdaterPrefill:
             paged_kernel_lens_sum = seq_lens_sum

         self.call_begin_forward(
-            self.wrapper_ragged,
-            self.wrappers_paged[0],
+            self.prefill_wrapper_ragged,
+            prefill_wrappers[0],
             req_pool_indices,
             paged_kernel_lens,
             paged_kernel_lens_sum,
@@ -565,6 +576,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -584,8 +596,8 @@ class FlashInferIndicesUpdaterPrefill:
                 kv_start_idx = seq_lens - paged_kernel_lens

             self.call_begin_forward(
-                self.wrapper_ragged,
-                self.wrappers_paged[wrapper_id],
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,
@@ -603,6 +615,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -619,8 +632,8 @@ class FlashInferIndicesUpdaterPrefill:
                 paged_kernel_lens_sum = paged_kernel_lens.sum().item()

             self.call_begin_forward(
-                self.wrapper_ragged,
-                self.wrappers_paged[wrapper_id],
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,
@@ -634,8 +647,8 @@ class FlashInferIndicesUpdaterPrefill:

     def call_begin_forward(
         self,
-        wrapper_ragged,
-        wrapper_paged,
+        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
...
@@ -24,7 +24,11 @@ from vllm.distributed import (
 )

 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)


 @dataclasses.dataclass
@@ -46,6 +50,10 @@ class LogitsProcessorOutput:
     output_top_logprobs_val: List = None
     output_top_logprobs_idx: List = None

+    # Used by speculative decoding (EAGLE)
+    # The output of transformer layers
+    hidden_states: Optional[torch.Tensor] = None
+

 @dataclasses.dataclass
 class LogitsMetadata:
@@ -61,6 +69,8 @@ class LogitsMetadata:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         extend_logprob_pruned_lens_cpu = None
@@ -78,6 +88,11 @@ class LogitsMetadata:
         else:
             return_top_logprob = False

+        if forward_batch.spec_info:
+            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+        else:
+            capture_hidden_mode = CaptureHiddenMode.NULL
+
         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
@@ -87,6 +102,7 @@ class LogitsMetadata:
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
+            capture_hidden_mode=capture_hidden_mode,
         )
@@ -116,7 +132,10 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
-        if logits_metadata.forward_mode.is_decode():
+        if (
+            logits_metadata.forward_mode.is_decode()
+            or logits_metadata.forward_mode.is_target_verify()
+        ):
             last_index = None
             last_hidden = hidden_states
         else:
@@ -137,6 +156,15 @@
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
+                hidden_states=(
+                    hidden_states
+                    if logits_metadata.capture_hidden_mode.is_full()
+                    else (
+                        last_hidden
+                        if logits_metadata.capture_hidden_mode.is_last()
+                        else None
+                    )
+                ),
             )
         else:
             last_logprobs = self.compute_temp_top_p_normalized_logprobs(
...
@@ -843,8 +843,8 @@ class ScheduleBatch:
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)

-    def check_decode_mem(self):
-        bs = len(self.reqs)
+    def check_decode_mem(self, buf_multiplier=1):
+        bs = len(self.reqs) * buf_multiplier
         if self.token_to_kv_pool.available_size() >= bs:
             return True
...
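The new buf_multiplier parameter simply scales the number of KV-cache slots that must be free before a decode batch is admitted, presumably so callers that generate more than one token per request per step (such as the speculative-decoding paths added below) can reserve enough space. A hedged sketch of the arithmetic, not taken from the diff:

# Illustrative only: the memory check reduces to this comparison.
def has_decode_headroom(num_reqs: int, available_kv_slots: int, buf_multiplier: int = 1) -> bool:
    # Each request needs buf_multiplier new KV slots this step (1 for normal decode).
    return available_kv_slots >= num_reqs * buf_multiplier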
@@ -90,7 +90,7 @@ from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

-# Test retract decode
+# Test retract decode for debugging purposes
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
@@ -129,12 +129,12 @@ class Scheduler:
             )

             if server_args.skip_tokenizer_init:
-                # Directly send to the tokenizer/api
+                # Directly send to the TokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(
                     context, zmq.PUSH, port_args.tokenizer_ipc_name
                 )
             else:
-                # Send to the detokenizer
+                # Send to the DetokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(
                     context, zmq.PUSH, port_args.detokenizer_ipc_name
                 )
@@ -385,7 +385,8 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-            if self.server_args.enable_dp_attention:
+
+            if self.server_args.enable_dp_attention:  # TODO: simplify this
                 batch = self.prepare_dp_attn_batch(batch)

             self.cur_batch = batch
@@ -394,7 +395,7 @@
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                # Self-check and re-init some states when the server is idle
+                # When the server is idle, so self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
@@ -411,12 +412,13 @@
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
+
             if batch:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))

                 if self.last_batch is None:
-                    # A dummy first batch to start the pipeline for overlap scheduler.
+                    # Create a dummy first batch to start the pipeline for overlap scheduler.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -426,19 +428,21 @@
                     self.process_batch_result(tmp_batch, None)

             if self.last_batch:
+                # Process the results of the last batch
                 tmp_batch, tmp_result = result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # Self-check and re-init some states when the server is idle
+                # When the server is idle, so self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

             self.last_batch = batch

-    def recv_requests(self):
+    def recv_requests(self) -> List[Req]:
+        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -812,6 +816,8 @@
             if res == AddReqResult.NO_TOKEN:
                 self.batch_is_full = True
                 break
+            if self.server_args.prefill_only_one_req:
+                break

         # Update waiting queue
         can_run_list = adder.can_run_list
@@ -1528,18 +1534,20 @@ def run_scheduler_process(
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])

+    # Configure the logger
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+    suppress_other_loggers()

-    # set cpu affinity to this gpu process
+    # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-    suppress_other_loggers()

     parent_process = psutil.Process().parent()

+    # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send(
...
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm


 class ForwardMode(IntEnum):
@@ -59,6 +60,11 @@ class ForwardMode(IntEnum):
     # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
     IDLE = auto()

+    # Used in speculative decoding: verify a batch in the target model.
+    TARGET_VERIFY = auto()
+    # Used in speculative decoding: extend a batch in the draft model.
+    DRAFT_EXTEND = auto()
+
     # A dummy first batch to start the pipeline for overlap scheduler.
     # It is now used for triggering the sampling_info_done event for the first prefill batch.
     DUMMY_FIRST = auto()
@@ -67,7 +73,12 @@
         return self == ForwardMode.PREFILL

     def is_extend(self):
-        return self == ForwardMode.EXTEND or self == ForwardMode.MIXED
+        return (
+            self == ForwardMode.EXTEND
+            or self == ForwardMode.MIXED
+            or self == ForwardMode.DRAFT_EXTEND
+            or self == self.TARGET_VERIFY
+        )

     def is_decode(self):
         return self == ForwardMode.DECODE
@@ -78,6 +89,15 @@
     def is_idle(self):
         return self == ForwardMode.IDLE

+    def is_target_verify(self):
+        return self == ForwardMode.TARGET_VERIFY
+
+    def is_draft_extend(self):
+        return self == ForwardMode.DRAFT_EXTEND
+
+    def is_cuda_graph(self):
+        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
@@ -141,14 +161,18 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None

-    # For Qwen2-VL
-    mrope_positions: torch.Tensor = None
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None

     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False

+    # For Qwen2-VL
+    mrope_positions: torch.Tensor = None
+
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -351,3 +375,18 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
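For context, CaptureHiddenMode lets a caller request no hidden states, the full sequence, or only the last position. A hedged sketch of how it is consumed, mirroring the hidden_states selection added to LogitsProcessor above (the standalone helper name is an assumption, not part of this commit):

from typing import Optional
import torch

# Hypothetical helper that mirrors the selection logic in the LogitsProcessor
# hunk above; illustrative only.
def select_hidden_states(
    capture_hidden_mode,            # a CaptureHiddenMode value
    hidden_states: torch.Tensor,    # hidden states for all positions
    last_hidden: torch.Tensor,      # hidden state of the last position per sequence
) -> Optional[torch.Tensor]:
    if capture_hidden_mode.is_full():
        return hidden_states
    if capture_hidden_mode.is_last():
        return last_hidden
    return None  # CaptureHiddenMode.NULL: nothing is captured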
@@ -516,6 +516,17 @@ class LlamaForCausalLM(nn.Module):
             )
         return None

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
...
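The two new LlamaForCausalLM helpers expose and overwrite the embedding and LM-head weights. A hedged usage sketch of how one model might reuse another's weights via these methods (everything outside the two method names is an assumption, not taken from this diff):

# Illustrative only: share the target model's embedding and LM head with a
# second model (e.g. a draft model) through the helpers added above.
def tie_weights(target_model, other_model) -> None:
    embed, head = target_model.get_embed_and_head()
    other_model.set_embed_and_head(embed, head)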
@@ -503,7 +503,7 @@ def launch_engine(
         )
         scheduler_infos.append(data)

-    # Assume all schedulers have same max_total_num_tokens
+    # Assume all schedulers have same scheduler_info
     scheduler_info = scheduler_infos[0]
@@ -890,7 +890,7 @@ class Runtime:
     using the command line interface.

     It is mainly used for the frontend language.
-    You should use the Engine class if you want to do normal offline processing.
+    You should use the Engine class above if you want to do normal offline processing.
     """

     def __init__(
...
@@ -55,7 +55,7 @@ class ServerArgs:
     is_embedding: bool = False
     revision: Optional[str] = None

-    # Port
+    # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000
@@ -68,6 +68,7 @@
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
+    prefill_only_one_req: bool = False

     # Other runtime options
     tp_size: int = 1
@@ -94,6 +95,7 @@
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+
     # Expert parallelism
     ep_size: int = 1
@@ -217,6 +219,13 @@
             )
             self.disable_cuda_graph = True

+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
         # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
@@ -229,12 +238,6 @@
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
                 "Overlap scheduler is disabled."
             )
-        # Expert parallelism
-        if self.enable_ep_moe:
-            self.ep_size = self.tp_size
-            logger.info(
-                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
-            )

         # GGUF
         if (
@@ -430,13 +433,18 @@
             default=ServerArgs.schedule_conservativeness,
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
             "--cpu-offload-gb",
             type=int,
             default=ServerArgs.cpu_offload_gb,
             help="How many GBs of RAM to reserve for CPU offloading",
         )
+        parser.add_argument(
+            "--prefill-only-one-req",
+            type=bool,
+            help="If true, we only prefill one request at one prefill batch",
+            default=ServerArgs.prefill_only_one_req,
+        )

         # Other runtime options
         parser.add_argument(
@@ -555,6 +563,7 @@
                 "shortest_queue",
             ],
         )
+
         # Expert parallelism
         parser.add_argument(
             "--expert-parallel-size",
@@ -777,28 +786,6 @@
             help="Delete the model checkpoint after loading the model.",
         )

-        # Deprecated arguments
-        parser.add_argument(
-            "--enable-overlap-schedule",
-            action=DeprecatedAction,
-            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
-        )
-        parser.add_argument(
-            "--disable-flashinfer",
-            action=DeprecatedAction,
-            help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
-        )
-        parser.add_argument(
-            "--disable-flashinfer-sampling",
-            action=DeprecatedAction,
-            help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
-        )
-        parser.add_argument(
-            "--disable-disk-cache",
-            action=DeprecatedAction,
-            help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
...
+from enum import IntEnum, auto
+
+
+class SpeculativeAlgorithm(IntEnum):
+    EAGLE = auto()
+
+    def is_eagle(self):
+        return self == SpeculativeAlgorithm.EAGLE
+
+    @staticmethod
+    def from_string(name: str):
+        name_map = {
+            "EAGLE": SpeculativeAlgorithm.EAGLE,
+        }
+        return name_map[name]
+
+
+class SpecInfo:
+    pass
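The new spec_info module above only declares the algorithm enum and an empty SpecInfo base class; concrete speculative-decoding state is presumably meant to subclass it in follow-up work. A minimal usage sketch (the import path matches the file added above; the subclass is purely illustrative):

from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm

algo = SpeculativeAlgorithm.from_string("EAGLE")
assert algo.is_eagle()

class DummySpecInfo(SpecInfo):
    # Placeholder subclass; a real implementation would carry draft/verify state
    # such as capture_hidden_mode (see the LogitsMetadata change above).
    pass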