Commit ac238727 authored by Lianmin Zheng's avatar Lianmin Zheng
Browse files

Support penalty in overlap mode; return logprob with chunked prefill; improve...


Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: default avatarSangBin Cho <rkooo567@gmail.com>
Co-authored-by: default avatardhou-xai <dhou@x.ai>
Co-authored-by: default avatarHanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
......@@ -35,12 +35,12 @@ class SessionReqNode:
for req_node in self.childs:
req_node.clear(req_dict)
if self.req.finished_reason == None:
if self.req.finished_reason is None:
self.req.to_abort = True
del req_dict[self.req.rid]
def abort(self):
if self.req.finished_reason == None:
if self.req.finished_reason is None:
self.req.to_abort = True
def __str__(self):
......@@ -132,6 +132,10 @@ class Session:
lora_path=req.lora_path,
session_id=self.session_id,
custom_logit_processor=req.custom_logit_processor,
stream=req.stream,
return_logprob=req.return_logprob,
top_logprobs_num=req.top_logprobs_num,
token_ids_logprob=req.token_ids_logprob,
)
if last_req is not None:
new_req.image_inputs = last_req.image_inputs
......
......@@ -15,10 +15,13 @@
import logging
import threading
from typing import Optional
from typing import Optional, Tuple
import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
......@@ -159,7 +162,7 @@ class TpModelWorker:
model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None,
skip_sample: bool = False,
):
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
if launch_done:
......
......@@ -175,7 +175,7 @@ class TpModelWorkerClient:
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
......@@ -188,8 +188,7 @@ class TpModelWorkerClient:
model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
sampling_info,
sampling_info_done=threading.Event(),
scaling_penalties=sampling_info.scaling_penalties,
linear_penalties=sampling_info.linear_penalties,
penalizer_orchestrator=None,
)
# A cuda stream sync here to avoid the cuda illegal memory access error.
......
......@@ -2,7 +2,9 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
......@@ -12,7 +14,7 @@ if TYPE_CHECKING:
class ChunkCacheEntry:
def __init__(self, rid, value):
def __init__(self, rid: str, value: torch.Tensor):
self.rid = rid
self.value = value
......@@ -24,6 +26,7 @@ class ChunkCache(BasePrefixCache):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.entries: Dict[str, ChunkCacheEntry] = {}
self.reset()
......@@ -53,11 +56,8 @@ class ChunkCache(BasePrefixCache):
if req.rid in self.entries:
del self.entries[req.rid]
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
if token_ids is None:
def cache_unfinished_req(self, req: Req):
token_id_len = len(req.fill_ids)
else:
token_id_len = len(token_ids)
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_id_len
......@@ -86,5 +86,8 @@ class ChunkCache(BasePrefixCache):
def evictable_size(self):
return 0
def pretty_print(self):
return ""
def protected_size(self):
return 0
This diff is collapsed.
......@@ -109,11 +109,15 @@ def set_torch_compile_config():
def get_batch_sizes_to_capture(model_runner: ModelRunner):
server_args = model_runner.server_args
capture_bs = server_args.cuda_graph_bs
if capture_bs is None:
if server_args.speculative_algorithm is None:
if server_args.disable_cuda_graph_padding:
capture_bs = list(range(1, 33)) + [64, 128]
capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
else:
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
else:
capture_bs = list(range(1, 33))
if is_hip_:
capture_bs += [i * 8 for i in range(21, 33)]
......@@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
)
)
)
capture_bs = [
bs
for bs in capture_bs
......@@ -388,9 +393,6 @@ class CudaGraphRunner:
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
global global_graph_memory_pool
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
out = run_once()
......@@ -401,12 +403,11 @@ class CudaGraphRunner:
global_graph_memory_pool = graph.pool()
return graph, out
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
def recapture_if_needed(self, forward_batch: ForwardBatch):
# If the capture_hidden_mode changes, we need to recapture the graph
hidden_mode_from_spec_info = getattr(
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
)
# If the capture_hidden_mode changes, we need to recapture the graph
if (
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
and self.capture_hidden_mode != CaptureHiddenMode.FULL
......@@ -420,6 +421,9 @@ class CudaGraphRunner:
self.capture_hidden_mode = hidden_mode_from_spec_info
self.capture()
def replay(self, forward_batch: ForwardBatch):
self.recapture_if_needed(forward_batch)
raw_bs = forward_batch.batch_size
raw_num_token = raw_bs * self.num_tokens_per_bs
......
......@@ -31,7 +31,7 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Union
import torch
import triton
......@@ -46,7 +46,8 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
class ForwardMode(IntEnum):
......@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
class CaptureHiddenMode(IntEnum):
NULL = auto()
# Capture hidden states of all tokens.
FULL = auto()
# Capture a hidden state of the last token.
LAST = auto()
def need_capture(self):
......@@ -148,6 +151,7 @@ class ForwardBatch:
# For logprob
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# Position information
positions: torch.Tensor = None
......@@ -160,6 +164,7 @@ class ForwardBatch:
extend_prefix_lens_cpu: Optional[List[int]] = None
extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
# For multimodal
image_inputs: Optional[List[ImageInputs]] = None
......@@ -190,10 +195,13 @@ class ForwardBatch:
can_run_dp_cuda_graph: bool = False
# Speculative decoding
spec_info: SpecInfo = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
spec_algorithm: SpeculativeAlgorithm = None
capture_hidden_mode: CaptureHiddenMode = None
# For padding
padded_static_len: int = -1 # -1 if not padded
# For Qwen2-VL
mrope_positions: torch.Tensor = None
......@@ -203,8 +211,13 @@ class ForwardBatch:
batch: ModelWorkerBatch,
model_runner: ModelRunner,
):
device = model_runner.device
extend_input_logprob_token_ids_gpu = None
if batch.extend_input_logprob_token_ids is not None:
extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
)
ret = cls(
forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens),
......@@ -220,6 +233,7 @@ class ForwardBatch:
seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
global_num_tokens=batch.global_num_tokens,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
......@@ -231,6 +245,7 @@ class ForwardBatch:
spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds,
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
)
if ret.global_num_tokens is not None:
......@@ -341,6 +356,7 @@ class ForwardBatch:
)
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
mrope_positions_list[i] = mrope_positions
self.mrope_positions = torch.concat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
axis=1,
......@@ -379,7 +395,7 @@ def compute_position_kernel(
extend_seq_lens,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)
pid = tl.program_id(0).to(tl.int64)
prefix_len = tl.load(extend_prefix_lens + pid)
seq_len = tl.load(extend_seq_lens + pid)
......
......@@ -25,10 +25,10 @@ import filelock
import gguf
import huggingface_hub.constants
import numpy as np
import safetensors.torch
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from sglang.srt.configs.load_config import LoadConfig
......@@ -62,7 +62,6 @@ enable_hf_transfer()
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
......@@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file(
)
# check if the tensors are the same
reloaded = load_file(sf_filename)
reloaded = safetensors.torch.load_file(sf_filename)
for k in loaded:
pt_tensor = loaded[k]
sf_tensor = reloaded[k]
......@@ -133,7 +132,6 @@ def convert_bin_to_safetensor_file(
def get_quant_config(
model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# GGUF doesn't have config file
......@@ -402,15 +400,34 @@ def np_cache_weights_iterator(
yield name, torch.from_numpy(param)
def decrypt(fn, key):
raise NotImplementedError()
def safetensors_encrypted_weights_iterator(
hf_weights_files: List[str],
is_all_weights_sharded: bool = False,
decryption_key: Optional[str] = None,
):
raise NotImplementedError()
def safetensors_weights_iterator(
hf_weights_files: List[str],
is_all_weights_sharded: bool = False,
decryption_key: Optional[str] = None,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files.
If is_all_weights_sharded is True, it uses more optimize read by reading an
entire file instead of reading each tensor one by one.
"""
if decryption_key:
yield from safetensors_encrypted_weights_iterator(
hf_weights_files, is_all_weights_sharded, decryption_key
)
return
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
......@@ -420,13 +437,7 @@ def safetensors_weights_iterator(
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
if not is_all_weights_sharded:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
else:
result = load_file(st_file, device="cpu")
result = safetensors.torch.load_file(st_file, device="cpu")
for name, param in result.items():
yield name, param
......
from .orchestrator import BatchedPenalizerOrchestrator
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
from .penalizers.presence_penalty import BatchedPresencePenalizer
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer
from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer
from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator
from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer
__all__ = [
"BatchedFrequencyPenalizer",
"BatchedMinNewTokensPenalizer",
"BatchedPresencePenalizer",
"BatchedRepetitionPenalizer",
"BatchedPenalizerOrchestrator",
]
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment