Commit ac238727 authored by Lianmin Zheng

Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
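
Note on the logprob part of this commit: per-token logprobs can now be returned even when a long prompt is processed with chunked prefill. The client-side sketch below is illustrative only and is not part of the commit; it assumes a server already launched on localhost:30000, and the prompt, port, and parameter values are made-up examples.

# Hedged client-side sketch: request logprobs from a running sglang server.
# Assumes something like:
#   python -m sglang.launch_server --model-path <model> --chunked-prefill-size 2048
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0.0, "max_new_tokens": 8},
        "return_logprob": True,    # logprobs for the generated tokens
        "logprob_start_len": 0,    # also return logprobs over the input tokens
        "top_logprobs_num": 5,     # top-5 alternatives at each position
    },
)
meta = response.json()["meta_info"]
print(meta["output_token_logprobs"])  # roughly a list of (logprob, token_id, token text)
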
@@ -35,12 +35,12 @@ class SessionReqNode:
         for req_node in self.childs:
             req_node.clear(req_dict)

-        if self.req.finished_reason == None:
+        if self.req.finished_reason is None:
             self.req.to_abort = True
         del req_dict[self.req.rid]

     def abort(self):
-        if self.req.finished_reason == None:
+        if self.req.finished_reason is None:
             self.req.to_abort = True

     def __str__(self):
@@ -132,6 +132,10 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
             custom_logit_processor=req.custom_logit_processor,
+            stream=req.stream,
+            return_logprob=req.return_logprob,
+            top_logprobs_num=req.top_logprobs_num,
+            token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
...
@@ -15,10 +15,13 @@
 import logging
 import threading
-from typing import Optional
+from typing import Optional, Tuple
+
+import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
@@ -159,7 +162,7 @@ class TpModelWorker:
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ):
+    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         if launch_done:
...
@@ -175,7 +175,7 @@ class TpModelWorkerClient:
                 logits_output.next_token_logprobs.tolist()
             )
             if logits_output.input_token_logprobs is not None:
-                logits_output.input_token_logprobs = (
+                logits_output.input_token_logprobs = tuple(
                     logits_output.input_token_logprobs.tolist()
                 )
         next_token_ids = next_token_ids.tolist()
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
         model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
             sampling_info,
             sampling_info_done=threading.Event(),
-            scaling_penalties=sampling_info.scaling_penalties,
-            linear_penalties=sampling_info.linear_penalties,
+            penalizer_orchestrator=None,
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
...
@@ -2,7 +2,9 @@ from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+
+import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -12,7 +14,7 @@ if TYPE_CHECKING:

 class ChunkCacheEntry:
-    def __init__(self, rid, value):
+    def __init__(self, rid: str, value: torch.Tensor):
         self.rid = rid
         self.value = value
@@ -24,6 +26,7 @@ class ChunkCache(BasePrefixCache):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
+        self.entries: Dict[str, ChunkCacheEntry] = {}

         self.reset()
@@ -53,11 +56,8 @@ class ChunkCache(BasePrefixCache):
         if req.rid in self.entries:
             del self.entries[req.rid]

-    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
-        if token_ids is None:
-            token_id_len = len(req.fill_ids)
-        else:
-            token_id_len = len(token_ids)
+    def cache_unfinished_req(self, req: Req):
+        token_id_len = len(req.fill_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :token_id_len
@@ -86,5 +86,8 @@ class ChunkCache(BasePrefixCache):
     def evictable_size(self):
         return 0

+    def pretty_print(self):
+        return ""
+
     def protected_size(self):
         return 0
@@ -109,11 +109,15 @@ def set_torch_compile_config():
 def get_batch_sizes_to_capture(model_runner: ModelRunner):
     server_args = model_runner.server_args
     capture_bs = server_args.cuda_graph_bs
+
     if capture_bs is None:
-        if server_args.disable_cuda_graph_padding:
-            capture_bs = list(range(1, 33)) + [64, 128]
-        else:
-            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        if server_args.speculative_algorithm is None:
+            if server_args.disable_cuda_graph_padding:
+                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+            else:
+                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        else:
+            capture_bs = list(range(1, 33))

     if is_hip_:
         capture_bs += [i * 8 for i in range(21, 33)]
@@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             )
         )
     )
+
     capture_bs = [
         bs
         for bs in capture_bs
@@ -388,9 +393,6 @@ class CudaGraphRunner:
         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         global global_graph_memory_pool
         with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
             out = run_once()
@@ -401,12 +403,11 @@ class CudaGraphRunner:
             global_graph_memory_pool = graph.pool()
         return graph, out

-    def replay(self, forward_batch: ForwardBatch):
-        assert forward_batch.out_cache_loc is not None
+    def recapture_if_needed(self, forward_batch: ForwardBatch):
+        # If the capture_hidden_mode changes, we need to recapture the graph
         hidden_mode_from_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
-        # If the capture_hidden_mode changes, we need to recapture the graph
         if (
             forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
             and self.capture_hidden_mode != CaptureHiddenMode.FULL
@@ -420,6 +421,9 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()

+    def replay(self, forward_batch: ForwardBatch):
+        self.recapture_if_needed(forward_batch)
+
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
...
@@ -31,7 +31,7 @@ from __future__ import annotations
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch
 import triton
@@ -46,7 +46,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm


 class ForwardMode(IntEnum):
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
+    # Capture hidden states of all tokens.
     FULL = auto()
+    # Capture a hidden state of the last token.
     LAST = auto()

     def need_capture(self):
@@ -148,6 +151,7 @@ class ForwardBatch:
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None

     # Position information
     positions: torch.Tensor = None
@@ -160,6 +164,7 @@ class ForwardBatch:
     extend_prefix_lens_cpu: Optional[List[int]] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
+    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

     # For multimodal
     image_inputs: Optional[List[ImageInputs]] = None
@@ -190,10 +195,13 @@ class ForwardBatch:
     can_run_dp_cuda_graph: bool = False

     # Speculative decoding
-    spec_info: SpecInfo = None
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None

+    # For padding
+    padded_static_len: int = -1  # -1 if not padded
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -203,8 +211,13 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
         device = model_runner.device
+        extend_input_logprob_token_ids_gpu = None
+        if batch.extend_input_logprob_token_ids is not None:
+            extend_input_logprob_token_ids_gpu = (
+                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+            )

         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -220,6 +233,7 @@ class ForwardBatch:
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            token_ids_logprobs=batch.token_ids_logprobs,
             global_num_tokens=batch.global_num_tokens,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
@@ -231,6 +245,7 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )

         if ret.global_num_tokens is not None:
@@ -341,6 +356,7 @@ class ForwardBatch:
                 )
                 batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                 mrope_positions_list[i] = mrope_positions
+
         self.mrope_positions = torch.concat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
@@ -379,7 +395,7 @@ def compute_position_kernel(
     extend_seq_lens,
 ):
     BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     prefix_len = tl.load(extend_prefix_lens + pid)
     seq_len = tl.load(extend_seq_lens + pid)
...
@@ -25,10 +25,10 @@ import filelock
 import gguf
 import huggingface_hub.constants
 import numpy as np
+import safetensors.torch
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
-from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm

 from sglang.srt.configs.load_config import LoadConfig
@@ -62,7 +62,6 @@ enable_hf_transfer()
 class DisabledTqdm(tqdm):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs, disable=True)
-
@@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file(
     )

     # check if the tensors are the same
-    reloaded = load_file(sf_filename)
+    reloaded = safetensors.torch.load_file(sf_filename)
     for k in loaded:
         pt_tensor = loaded[k]
         sf_tensor = reloaded[k]
@@ -133,7 +132,6 @@ def convert_bin_to_safetensor_file(
 def get_quant_config(
     model_config: ModelConfig, load_config: LoadConfig
 ) -> QuantizationConfig:
-
     quant_cls = get_quantization_config(model_config.quantization)

     # GGUF doesn't have config file
@@ -402,15 +400,34 @@ def np_cache_weights_iterator(
             yield name, torch.from_numpy(param)

+
+def decrypt(fn, key):
+    raise NotImplementedError()
+
+
+def safetensors_encrypted_weights_iterator(
+    hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
+):
+    raise NotImplementedError()
+

 def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.

     If is_all_weights_sharded is True, it uses more optimize read by reading an
     entire file instead of reading each tensor one by one.
     """
+    if decryption_key:
+        yield from safetensors_encrypted_weights_iterator(
+            hf_weights_files, is_all_weights_sharded, decryption_key
+        )
+        return
+
     enable_tqdm = (
         not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
     )
@@ -420,13 +437,7 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        if not is_all_weights_sharded:
-            with safe_open(st_file, framework="pt") as f:
-                for name in f.keys():  # noqa: SIM118
-                    param = f.get_tensor(name)
-                    yield name, param
-        else:
-            result = load_file(st_file, device="cpu")
+        result = safetensors.torch.load_file(st_file, device="cpu")
         for name, param in result.items():
             yield name, param
...
-from .orchestrator import BatchedPenalizerOrchestrator
-from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
-from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
-from .penalizers.presence_penalty import BatchedPresencePenalizer
-from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
+from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer
+from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer
+from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator
+from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer

 __all__ = [
     "BatchedFrequencyPenalizer",
     "BatchedMinNewTokensPenalizer",
     "BatchedPresencePenalizer",
-    "BatchedRepetitionPenalizer",
     "BatchedPenalizerOrchestrator",
 ]
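
Side note on the hunk above: the penalizer modules are flattened and the repetition penalizer is dropped from this package, but the remaining classes stay re-exported from the package root via __all__. A minimal sketch of the downstream import (names taken from the __all__ list above):

# Sketch: import the re-exported penalizer classes from the package root.
from sglang.srt.sampling.penaltylib import (
    BatchedFrequencyPenalizer,
    BatchedMinNewTokensPenalizer,
    BatchedPenalizerOrchestrator,
    BatchedPresencePenalizer,
)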