Commit ac238727 authored by Lianmin Zheng

Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
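
Background for the first item in the commit title: frequency/presence penalties adjust next-token logits based on how often each token has already been generated. The snippet below is an illustrative sketch of that adjustment only, not code from this commit.

import torch


def apply_penalties(
    logits: torch.Tensor,             # [batch, vocab]
    output_counts: torch.Tensor,      # [batch, vocab] counts of already-generated tokens
    frequency_penalty: torch.Tensor,  # [batch, 1]
    presence_penalty: torch.Tensor,   # [batch, 1]
) -> torch.Tensor:
    # OpenAI-style penalties: scale by count (frequency) and by presence (0/1).
    logits = logits - frequency_penalty * output_counts
    logits = logits - presence_penalty * (output_counts > 0)
    return logits


logits = torch.zeros(1, 5)
counts = torch.tensor([[0.0, 2.0, 0.0, 1.0, 0.0]])
out = apply_penalties(logits, counts, torch.tensor([[0.5]]), torch.tensor([[0.1]]))
print(out)  # tokens 1 and 3 are penalized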
@@ -35,12 +35,12 @@ class SessionReqNode:
         for req_node in self.childs:
             req_node.clear(req_dict)
-        if self.req.finished_reason == None:
+        if self.req.finished_reason is None:
             self.req.to_abort = True
         del req_dict[self.req.rid]

     def abort(self):
-        if self.req.finished_reason == None:
+        if self.req.finished_reason is None:
             self.req.to_abort = True

     def __str__(self):
@@ -132,6 +132,10 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
             custom_logit_processor=req.custom_logit_processor,
+            stream=req.stream,
+            return_logprob=req.return_logprob,
+            top_logprobs_num=req.top_logprobs_num,
+            token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
@@ -15,10 +15,13 @@

 import logging
 import threading
-from typing import Optional
+from typing import Optional, Tuple
+
+import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
@@ -159,7 +162,7 @@ class TpModelWorker:
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ):
+    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         if launch_done:
@@ -175,7 +175,7 @@ class TpModelWorkerClient:
                 logits_output.next_token_logprobs.tolist()
             )
             if logits_output.input_token_logprobs is not None:
-                logits_output.input_token_logprobs = (
+                logits_output.input_token_logprobs = tuple(
                     logits_output.input_token_logprobs.tolist()
                 )
         next_token_ids = next_token_ids.tolist()
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
         model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
             sampling_info,
             sampling_info_done=threading.Event(),
-            scaling_penalties=sampling_info.scaling_penalties,
-            linear_penalties=sampling_info.linear_penalties,
+            penalizer_orchestrator=None,
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
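
The overlap worker above copies the sampling info with dataclasses.replace and sets penalizer_orchestrator=None, so the copy handed to the background thread does not share the penalizer state. A self-contained sketch of this copy-and-override pattern (field names here are hypothetical, not taken from this diff):

import dataclasses
import threading
from typing import Any, Optional


@dataclasses.dataclass
class SamplingInfoSketch:
    temperatures: Any
    penalizer_orchestrator: Optional[object] = None
    sampling_info_done: Optional[threading.Event] = None


original = SamplingInfoSketch(temperatures=[0.7], penalizer_orchestrator=object())

# Shallow copy with selected fields overridden; the original keeps its orchestrator.
copy_for_worker = dataclasses.replace(
    original,
    sampling_info_done=threading.Event(),
    penalizer_orchestrator=None,
)

assert original.penalizer_orchestrator is not None
assert copy_for_worker.penalizer_orchestrator is None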
@@ -2,7 +2,9 @@ from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+
+import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -12,7 +14,7 @@ if TYPE_CHECKING:

 class ChunkCacheEntry:
-    def __init__(self, rid, value):
+    def __init__(self, rid: str, value: torch.Tensor):
         self.rid = rid
         self.value = value

@@ -24,6 +26,7 @@ class ChunkCache(BasePrefixCache):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
+        self.entries: Dict[str, ChunkCacheEntry] = {}

         self.reset()

@@ -53,11 +56,8 @@ class ChunkCache(BasePrefixCache):
         if req.rid in self.entries:
             del self.entries[req.rid]

-    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
-        if token_ids is None:
-            token_id_len = len(req.fill_ids)
-        else:
-            token_id_len = len(token_ids)
+    def cache_unfinished_req(self, req: Req):
+        token_id_len = len(req.fill_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :token_id_len
@@ -86,5 +86,8 @@ class ChunkCache(BasePrefixCache):
     def evictable_size(self):
         return 0

+    def pretty_print(self):
+        return ""
+
     def protected_size(self):
         return 0
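
For context, the entries dict typed above caches the KV indices of a partially prefilled request keyed by its request id, so the next chunk can pick up where the previous one stopped. A minimal, self-contained sketch of that pattern (not the sglang implementation):

from typing import Dict

import torch


class ChunkEntry:
    def __init__(self, rid: str, value: torch.Tensor):
        self.rid = rid      # request id
        self.value = value  # KV-cache slot indices covered so far


class TinyChunkCache:
    def __init__(self):
        self.entries: Dict[str, ChunkEntry] = {}

    def cache_unfinished(self, rid: str, kv_indices: torch.Tensor):
        # Overwrite (or create) the entry with indices for all tokens filled so far.
        self.entries[rid] = ChunkEntry(rid, kv_indices)

    def match_prefix(self, rid: str) -> torch.Tensor:
        # Return previously cached indices so the next chunk can skip them.
        entry = self.entries.get(rid)
        return entry.value if entry is not None else torch.empty(0, dtype=torch.int64)


cache = TinyChunkCache()
cache.cache_unfinished("req-0", torch.arange(4))  # first chunk: 4 tokens
print(cache.match_prefix("req-0").tolist())       # -> [0, 1, 2, 3]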
@@ -109,11 +109,15 @@ def set_torch_compile_config():

 def get_batch_sizes_to_capture(model_runner: ModelRunner):
     server_args = model_runner.server_args
     capture_bs = server_args.cuda_graph_bs
+
     if capture_bs is None:
-        if server_args.disable_cuda_graph_padding:
-            capture_bs = list(range(1, 33)) + [64, 128]
+        if server_args.speculative_algorithm is None:
+            if server_args.disable_cuda_graph_padding:
+                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+            else:
+                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         else:
-            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+            capture_bs = list(range(1, 33))

     if is_hip_:
         capture_bs += [i * 8 for i in range(21, 33)]
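
For reference, a standalone sketch (a hypothetical helper, not part of this diff) of the batch-size selection logic introduced above:

from typing import List, Optional


def pick_capture_batch_sizes(
    cuda_graph_bs: Optional[List[int]],
    speculative: bool,
    disable_padding: bool,
) -> List[int]:
    """Mirrors the branching added in get_batch_sizes_to_capture (sketch only)."""
    if cuda_graph_bs is not None:
        return cuda_graph_bs  # a user-provided list takes precedence
    if not speculative:
        if disable_padding:
            # No padding: capture every small batch size plus a few large ones.
            return list(range(1, 33)) + [64, 96, 128, 160]
        # Padding enabled: capture a coarser grid of batch sizes.
        return [1, 2, 4] + [i * 8 for i in range(1, 21)]
    # Speculative decoding: capture batch sizes 1..32 only.
    return list(range(1, 33))


print(pick_capture_batch_sizes(None, speculative=False, disable_padding=True)[-4:])
# -> [64, 96, 128, 160]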
@@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             )
         )
     )
+
     capture_bs = [
         bs
         for bs in capture_bs
@@ -385,9 +390,6 @@ class CudaGraphRunner:
             run_once()

-            torch.cuda.synchronize()
-            self.model_runner.tp_group.barrier()
-
             torch.cuda.synchronize()
             self.model_runner.tp_group.barrier()
@@ -401,12 +403,11 @@
         global_graph_memory_pool = graph.pool()
         return graph, out

-    def replay(self, forward_batch: ForwardBatch):
-        assert forward_batch.out_cache_loc is not None
+    def recapture_if_needed(self, forward_batch: ForwardBatch):
+        # If the capture_hidden_mode changes, we need to recapture the graph
         hidden_mode_from_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
-        # If the capture_hidden_mode changes, we need to recapture the graph
         if (
             forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
             and self.capture_hidden_mode != CaptureHiddenMode.FULL
@@ -420,6 +421,9 @@
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()

+    def replay(self, forward_batch: ForwardBatch):
+        self.recapture_if_needed(forward_batch)
+
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
@@ -31,7 +31,7 @@ from __future__ import annotations

 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch
 import triton
@@ -46,7 +46,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm


 class ForwardMode(IntEnum):
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):

 class CaptureHiddenMode(IntEnum):
     NULL = auto()
+    # Capture hidden states of all tokens.
     FULL = auto()
+    # Capture a hidden state of the last token.
     LAST = auto()

     def need_capture(self):
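
For context, a minimal sketch of how such a capture-mode enum is typically used; the need_capture body shown here is an assumption for illustration, not taken from this diff:

from enum import IntEnum, auto


class CaptureHiddenModeSketch(IntEnum):
    NULL = auto()  # capture nothing
    FULL = auto()  # capture hidden states of all tokens
    LAST = auto()  # capture the hidden state of the last token only

    def need_capture(self) -> bool:
        # Assumed behavior: any mode other than NULL requires capturing hidden states.
        return self != CaptureHiddenModeSketch.NULL


assert CaptureHiddenModeSketch.FULL.need_capture()
assert not CaptureHiddenModeSketch.NULL.need_capture()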
@@ -148,6 +151,7 @@ class ForwardBatch:
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None

     # Position information
     positions: torch.Tensor = None
@@ -160,6 +164,7 @@
     extend_prefix_lens_cpu: Optional[List[int]] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
+    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

     # For multimodal
     image_inputs: Optional[List[ImageInputs]] = None
@@ -190,10 +195,13 @@
     can_run_dp_cuda_graph: bool = False

     # Speculative decoding
-    spec_info: SpecInfo = None
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None

+    # For padding
+    padded_static_len: int = -1  # -1 if not padded
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -203,8 +211,13 @@
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
         device = model_runner.device
+        extend_input_logprob_token_ids_gpu = None
+        if batch.extend_input_logprob_token_ids is not None:
+            extend_input_logprob_token_ids_gpu = (
+                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+            )
+
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -220,6 +233,7 @@
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            token_ids_logprobs=batch.token_ids_logprobs,
             global_num_tokens=batch.global_num_tokens,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
@@ -231,6 +245,7 @@
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )

         if ret.global_num_tokens is not None:
@@ -341,6 +356,7 @@
             )
             batch.image_inputs[i].mrope_position_delta = mrope_position_delta
             mrope_positions_list[i] = mrope_positions
+
         self.mrope_positions = torch.concat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
@@ -379,7 +395,7 @@ def compute_position_kernel(
     extend_seq_lens,
 ):
     BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     prefix_len = tl.load(extend_prefix_lens + pid)
     seq_len = tl.load(extend_seq_lens + pid)
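
The tl.program_id(0).to(tl.int64) cast above guards the kernel's pointer arithmetic against 32-bit overflow when pid multiplied by a stride exceeds 2^31. A minimal standalone kernel showing the same pattern (a sketch, not the kernel from this diff; requires a CUDA device):

import torch
import triton
import triton.language as tl


@triton.jit
def copy_rows_kernel(src_ptr, dst_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # Cast the program id to int64 so that pid * row_stride cannot overflow
    # int32 for very large tensors.
    pid = tl.program_id(0).to(tl.int64)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols
    row = tl.load(src_ptr + pid * row_stride + offsets, mask=mask)
    tl.store(dst_ptr + pid * row_stride + offsets, row, mask=mask)


x = torch.randn(8, 100, device="cuda")
y = torch.empty_like(x)
copy_rows_kernel[(x.shape[0],)](x, y, x.stride(0), x.shape[1], BLOCK_SIZE=128)
assert torch.equal(x, y)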
-from .orchestrator import BatchedPenalizerOrchestrator
-from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
-from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
-from .penalizers.presence_penalty import BatchedPresencePenalizer
-from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
+from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer
+from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer
+from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator
+from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer

 __all__ = [
     "BatchedFrequencyPenalizer",
     "BatchedMinNewTokensPenalizer",
     "BatchedPresencePenalizer",
-    "BatchedRepetitionPenalizer",
     "BatchedPenalizerOrchestrator",
 ]
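
After this change the penalizer classes are re-exported from the package root via absolute imports, and BatchedRepetitionPenalizer is no longer part of __all__. Importing the remaining public names then looks like this (assuming sglang is installed):

from sglang.srt.sampling.penaltylib import (
    BatchedFrequencyPenalizer,
    BatchedMinNewTokensPenalizer,
    BatchedPenalizerOrchestrator,
    BatchedPresencePenalizer,
)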