Commit ac238727 authored by Lianmin Zheng's avatar Lianmin Zheng

Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: default avatarSangBin Cho <rkooo567@gmail.com>
Co-authored-by: default avatardhou-xai <dhou@x.ai>
Co-authored-by: default avatarHanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
@@ -35,12 +35,12 @@ class SessionReqNode:
         for req_node in self.childs:
             req_node.clear(req_dict)

-        if self.req.finished_reason == None:
+        if self.req.finished_reason is None:
             self.req.to_abort = True
         del req_dict[self.req.rid]

     def abort(self):
-        if self.req.finished_reason == None:
+        if self.req.finished_reason is None:
             self.req.to_abort = True

     def __str__(self):
@@ -132,6 +132,10 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
             custom_logit_processor=req.custom_logit_processor,
+            stream=req.stream,
+            return_logprob=req.return_logprob,
+            top_logprobs_num=req.top_logprobs_num,
+            token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
...
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import asyncio import asyncio
import copy import copy
import dataclasses import dataclasses
import json
import logging import logging
import os import os
import pickle import pickle
...@@ -24,9 +25,21 @@ import sys ...@@ -24,9 +25,21 @@ import sys
import threading import threading
import time import time
import uuid import uuid
from collections import deque
from datetime import datetime from datetime import datetime
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union from typing import (
Any,
Awaitable,
Deque,
Dict,
Generic,
List,
Optional,
Tuple,
TypeVar,
Union,
)
import fastapi import fastapi
import uvloop import uvloop
...@@ -44,6 +57,7 @@ from sglang.srt.managers.image_processor import ( ...@@ -44,6 +57,7 @@ from sglang.srt.managers.image_processor import (
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchEmbeddingOut, BatchEmbeddingOut,
BatchMultimodalOut,
BatchStrOut, BatchStrOut,
BatchTokenIDOut, BatchTokenIDOut,
CloseSessionReqInput, CloseSessionReqInput,
...@@ -51,18 +65,25 @@ from sglang.srt.managers.io_struct import ( ...@@ -51,18 +65,25 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput, EmbeddingReqInput,
FlushCacheReq, FlushCacheReq,
GenerateReqInput, GenerateReqInput,
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput, GetWeightsByNameReqInput,
GetWeightsByNameReqOutput, GetWeightsByNameReqOutput,
HealthCheckOutput,
InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput, InitWeightsUpdateGroupReqOutput,
OpenSessionReqInput, OpenSessionReqInput,
OpenSessionReqOutput, OpenSessionReqOutput,
ProfileReq, ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput, ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput, ResumeMemoryOccupationReqOutput,
SessionParams, SessionParams,
SetInternalStateReq,
SetInternalStateReqOutput,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqInput,
@@ -98,7 +119,10 @@ class ReqState:
     # For metrics
     created_time: float
-    first_token_time: Optional[float] = None
+    finished_time: float = 0.0
+    first_token_time: float = 0.0
+    last_time: float = 0.0
+    last_completion_tokens: int = 1

     # For streaming output
     last_output_offset: int = 0
...@@ -113,11 +137,10 @@ class TokenizerManager: ...@@ -113,11 +137,10 @@ class TokenizerManager:
port_args: PortArgs, port_args: PortArgs,
): ):
# Parse args # Parse args
self.server_args = server_args self.server_args = server_args
self.enable_metrics = server_args.enable_metrics self.enable_metrics = server_args.enable_metrics
self.log_requests = server_args.log_requests self.log_requests = server_args.log_requests
self.log_requests_level = 0 self.log_requests_level = server_args.log_requests_level
# Init inter-process communication # Init inter-process communication
context = zmq.asyncio.Context(2) context = zmq.asyncio.Context(2)
...@@ -143,6 +166,7 @@ class TokenizerManager: ...@@ -143,6 +166,7 @@ class TokenizerManager:
) )
self.is_generation = self.model_config.is_generation self.is_generation = self.model_config.is_generation
self.is_image_gen = self.model_config.is_image_gen
self.context_len = self.model_config.context_len self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id self.image_token_id = self.model_config.image_token_id
...@@ -178,9 +202,12 @@ class TokenizerManager: ...@@ -178,9 +202,12 @@ class TokenizerManager:
# Store states # Store states
self.no_create_loop = False self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {} self.rid_to_state: Dict[str, ReqState] = {}
self.gracefully_exit = False
self.last_receive_tstamp = 0
self.dump_requests_folder = "" # By default do not dump self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000 self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = [] self.dump_request_list: List[Tuple] = []
self.log_request_metadata = self.get_log_request_metadata()
# The event to notify the weight sync is finished. # The event to notify the weight sync is finished.
self.model_update_lock = RWLock() self.model_update_lock = RWLock()
...@@ -192,8 +219,19 @@ class TokenizerManager: ...@@ -192,8 +219,19 @@ class TokenizerManager:
# For session info # For session info
self.session_futures = {} # session_id -> asyncio event self.session_futures = {} # session_id -> asyncio event
# Others # Set after scheduler is initialized
self.gracefully_exit = False self.max_req_input_len = None
# Metrics
if self.enable_metrics:
self.metrics_collector = TokenizerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
# Communicators
self.init_weights_update_group_communicator = _Communicator( self.init_weights_update_group_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
...@@ -212,22 +250,26 @@ class TokenizerManager: ...@@ -212,22 +250,26 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator( self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
# Set after scheduler is initialized self.start_profile_communicator = _Communicator(
self.max_req_input_len = None self.send_to_scheduler, server_args.dp_size
)
# Metrics self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
if self.enable_metrics: self.get_internal_state_communicator = _Communicator(
self.metrics_collector = TokenizerMetricsCollector( self.send_to_scheduler, server_args.dp_size
labels={ )
"model_name": self.server_args.served_model_name, self.set_internal_state_communicator = _Communicator(
# TODO: Add lora name/path in the future, self.send_to_scheduler, server_args.dp_size
}, )
)
self._result_dispatcher = TypeBasedDispatcher( self._result_dispatcher = TypeBasedDispatcher(
[ [
( (
(BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut), (
BatchStrOut,
BatchEmbeddingOut,
BatchTokenIDOut,
BatchMultimodalOut,
),
self._handle_batch_output, self._handle_batch_output,
), ),
(OpenSessionReqOutput, self._handle_open_session_req_output), (OpenSessionReqOutput, self._handle_open_session_req_output),
...@@ -259,6 +301,19 @@ class TokenizerManager: ...@@ -259,6 +301,19 @@ class TokenizerManager:
ResumeMemoryOccupationReqOutput, ResumeMemoryOccupationReqOutput,
self.resume_memory_occupation_communicator.handle_recv, self.resume_memory_occupation_communicator.handle_recv,
), ),
(
ProfileReqOutput,
self.start_profile_communicator.handle_recv,
),
(
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None),
] ]
) )
...@@ -280,9 +335,9 @@ class TokenizerManager: ...@@ -280,9 +335,9 @@ class TokenizerManager:
obj.normalize_batch_and_arguments() obj.normalize_batch_and_arguments()
if self.log_requests: if self.log_requests:
max_length = 2048 if self.log_requests_level == 0 else 1 << 30 max_length, skip_names, _ = self.log_request_metadata
logger.info( logger.info(
f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
) )
async with self.model_update_lock.reader_lock: async with self.model_update_lock.reader_lock:
...@@ -336,6 +391,7 @@ class TokenizerManager: ...@@ -336,6 +391,7 @@ class TokenizerManager:
return_logprob = obj.return_logprob return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num top_logprobs_num = obj.top_logprobs_num
token_ids_logprob = obj.token_ids_logprob
session_params = ( session_params = (
SessionParams(**obj.session_params) if obj.session_params else None SessionParams(**obj.session_params) if obj.session_params else None
) )
...@@ -378,6 +434,7 @@ class TokenizerManager: ...@@ -378,6 +434,7 @@ class TokenizerManager:
return_logprob, return_logprob,
logprob_start_len, logprob_start_len,
top_logprobs_num, top_logprobs_num,
token_ids_logprob,
obj.stream, obj.stream,
lora_path=obj.lora_path, lora_path=obj.lora_path,
input_embeds=input_embeds, input_embeds=input_embeds,
...@@ -401,8 +458,7 @@ class TokenizerManager: ...@@ -401,8 +458,7 @@ class TokenizerManager:
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
created_time: Optional[float] = None, created_time: Optional[float] = None,
): ):
event = asyncio.Event() state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
state = ReqState([], False, event, obj, created_time=created_time)
self.rid_to_state[obj.rid] = state self.rid_to_state[obj.rid] = state
self.send_to_scheduler.send_pyobj(tokenized_obj) self.send_to_scheduler.send_pyobj(tokenized_obj)
...@@ -420,7 +476,10 @@ class TokenizerManager: ...@@ -420,7 +476,10 @@ class TokenizerManager:
except asyncio.TimeoutError: except asyncio.TimeoutError:
if request is not None and await request.is_disconnected(): if request is not None and await request.is_disconnected():
self.abort_request(obj.rid) self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}") raise ValueError(
"Request is disconnected from the client side. "
f"Abort request {obj.rid}"
)
continue continue
out = state.out_list[-1] out = state.out_list[-1]
...@@ -428,8 +487,11 @@ class TokenizerManager: ...@@ -428,8 +487,11 @@ class TokenizerManager:
state.out_list = [] state.out_list = []
if state.finished: if state.finished:
if self.log_requests: if self.log_requests:
max_length = 2048 if self.log_requests_level == 0 else 1 << 30 max_length, skip_names, out_skip_names = self.log_request_metadata
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" if self.model_config.is_multimodal_gen:
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
else:
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
logger.info(msg) logger.info(msg)
del self.rid_to_state[obj.rid] del self.rid_to_state[obj.rid]
...@@ -452,7 +514,10 @@ class TokenizerManager: ...@@ -452,7 +514,10 @@ class TokenizerManager:
else: else:
if request is not None and await request.is_disconnected(): if request is not None and await request.is_disconnected():
self.abort_request(obj.rid) self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}") raise ValueError(
"Request is disconnected from the client side. "
f"Abort request {obj.rid}"
)
async def _handle_batch_request( async def _handle_batch_request(
self, self,
@@ -543,12 +608,25 @@ class TokenizerManager:
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)

-    def start_profile(self):
-        req = ProfileReq.START_PROFILE
-        self.send_to_scheduler.send_pyobj(req)
+    async def start_profile(
+        self,
+        output_dir: Optional[str] = None,
+        num_steps: Optional[int] = None,
+        activities: Optional[List[str]] = None,
+    ):
+        req = ProfileReq(
+            type=ProfileReqType.START_PROFILE,
+            output_dir=output_dir,
+            num_steps=num_steps,
+            activities=activities,
+        )
+        result = (await self.start_profile_communicator(req))[0]
+        if not result.success:
+            raise RuntimeError(result.message)
+        return result

     def stop_profile(self):
-        req = ProfileReq.STOP_PROFILE
+        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
         self.send_to_scheduler.send_pyobj(req)
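For orientation, a minimal sketch of how the new asynchronous profiling API could be driven. The `tokenizer_manager` handle, output path, and activity names below are illustrative assumptions, not part of this diff.

# Hypothetical caller; assumes an initialized TokenizerManager named `tokenizer_manager`.
async def profile_a_few_steps(tokenizer_manager):
    # start_profile() now round-trips through start_profile_communicator and
    # raises RuntimeError if the scheduler reports failure.
    await tokenizer_manager.start_profile(
        output_dir="/tmp/sglang_profile",   # illustrative path
        num_steps=10,
        activities=["CPU", "GPU"],          # illustrative activity names
    )
    # ... issue some requests here ...
    tokenizer_manager.stop_profile()        # still fire-and-forget, as above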
async def update_weights_from_disk( async def update_weights_from_disk(
...@@ -581,7 +659,7 @@ class TokenizerManager: ...@@ -581,7 +659,7 @@ class TokenizerManager:
self.server_args.model_path = obj.model_path self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format self.server_args.load_format = obj.load_format
self.model_path = obj.model_path self.model_path = obj.model_path
return result.success, result.message return result.success, result.message, result.num_paused_requests
else: # self.server_args.dp_size > 1 else: # self.server_args.dp_size > 1
self.model_update_tmp = [] self.model_update_tmp = []
result = await self.model_update_result result = await self.model_update_result
...@@ -593,7 +671,8 @@ class TokenizerManager: ...@@ -593,7 +671,8 @@ class TokenizerManager:
self.model_path = obj.model_path self.model_path = obj.model_path
all_message = [r.message for r in result] all_message = [r.message for r in result]
all_message = " | ".join(all_message) all_message = " | ".join(all_message)
return all_success, all_message all_paused_requests = [r.num_paused_requests for r in result]
return all_success, all_message, all_paused_requests
async def init_weights_update_group( async def init_weights_update_group(
self, self,
...@@ -688,6 +767,54 @@ class TokenizerManager: ...@@ -688,6 +767,54 @@ class TokenizerManager:
): ):
await self.send_to_scheduler.send_pyobj(obj) await self.send_to_scheduler.send_pyobj(obj)
async def get_internal_state(self) -> Dict[Any, Any]:
req = GetInternalStateReq()
res: List[GetInternalStateReqOutput] = (
await self.get_internal_state_communicator(req)
)
return res[0].internal_state
async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
res: List[SetInternalStateReqOutput] = (
await self.set_internal_state_communicator(obj)
)
return res[0]
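As a rough usage sketch (the `tokenizer_manager` handle is assumed, not shown in this diff), the new helpers fan a request out to every DP rank through the communicator and surface the first rank's reply:

# Hypothetical usage of the new internal-state accessors.
async def print_scheduler_state(tokenizer_manager):
    # Sends GetInternalStateReq to all dp_size schedulers; returns the first reply's dict.
    state = await tokenizer_manager.get_internal_state()
    print(state)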
def get_log_request_metadata(self):
max_length = None
skip_names = None
out_skip_names = None
if self.log_requests:
if self.log_requests_level == 0:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
elif self.log_requests_level == 1:
max_length = 2048
elif self.log_requests_level == 2:
max_length = 1 << 30
else:
raise ValueError(
f"Invalid --log-requests-level: {self.log_requests_level=}"
)
return max_length, skip_names, out_skip_names
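To summarize the branches above in one place, a small sketch of what each --log-requests-level maps to; the tuple unpacking mirrors the Receive/Finish logging calls earlier in this file.

# level 0 -> no truncation, but skip bulky fields (text, input_ids, image_data, audio_data, ...)
# level 1 -> truncate every logged field to 2048 characters
# level 2 -> log everything in full
max_length, skip_names, out_skip_names = self.get_log_request_metadata()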
def configure_logging(self, obj: ConfigureLoggingReq): def configure_logging(self, obj: ConfigureLoggingReq):
if obj.log_requests is not None: if obj.log_requests is not None:
self.log_requests = obj.log_requests self.log_requests = obj.log_requests
...@@ -698,6 +825,7 @@ class TokenizerManager: ...@@ -698,6 +825,7 @@ class TokenizerManager:
if obj.dump_requests_threshold is not None: if obj.dump_requests_threshold is not None:
self.dump_requests_threshold = obj.dump_requests_threshold self.dump_requests_threshold = obj.dump_requests_threshold
logging.info(f"Config logging: {obj=}") logging.info(f"Config logging: {obj=}")
self.log_request_metadata = self.get_log_request_metadata()
def create_abort_task(self, obj: GenerateReqInput): def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected. # Abort the request if the client is disconnected.
...@@ -762,15 +890,20 @@ class TokenizerManager: ...@@ -762,15 +890,20 @@ class TokenizerManager:
while True: while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj() recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._result_dispatcher(recv_obj) self._result_dispatcher(recv_obj)
self.last_receive_tstamp = time.time()
def _handle_batch_output( def _handle_batch_output(
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] self,
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
],
): ):
for i, rid in enumerate(recv_obj.rids): for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None) state = self.rid_to_state.get(rid, None)
if state is None: if state is None:
continue continue
# Build meta_info and return value
meta_info = { meta_info = {
"id": rid, "id": rid,
"finish_reason": recv_obj.finished_reasons[i], "finish_reason": recv_obj.finished_reasons[i],
...@@ -781,14 +914,12 @@ class TokenizerManager: ...@@ -781,14 +914,12 @@ class TokenizerManager:
self.convert_logprob_style( self.convert_logprob_style(
meta_info, meta_info,
state.obj.top_logprobs_num, state.obj.top_logprobs_num,
state.obj.token_ids_logprob,
state.obj.return_text_in_logprobs, state.obj.return_text_in_logprobs,
recv_obj, recv_obj,
i, i,
) )
if self.server_args.speculative_algorithm:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
if not isinstance(recv_obj, BatchEmbeddingOut): if not isinstance(recv_obj, BatchEmbeddingOut):
meta_info.update( meta_info.update(
{ {
...@@ -806,10 +937,20 @@ class TokenizerManager: ...@@ -806,10 +937,20 @@ class TokenizerManager:
"meta_info": meta_info, "meta_info": meta_info,
} }
elif isinstance(recv_obj, BatchTokenIDOut): elif isinstance(recv_obj, BatchTokenIDOut):
if self.server_args.stream_output and state.obj.stream:
output_token_ids = recv_obj.output_ids[i][
state.last_output_offset :
]
state.last_output_offset = len(recv_obj.output_ids[i])
else:
output_token_ids = recv_obj.output_ids[i]
out_dict = { out_dict = {
"token_ids": recv_obj.output_ids[i], "output_ids": output_token_ids,
"meta_info": meta_info, "meta_info": meta_info,
} }
elif isinstance(recv_obj, BatchMultimodalOut):
raise NotImplementedError()
else: else:
assert isinstance(recv_obj, BatchEmbeddingOut) assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = { out_dict = {
...@@ -817,10 +958,17 @@ class TokenizerManager: ...@@ -817,10 +958,17 @@ class TokenizerManager:
"meta_info": meta_info, "meta_info": meta_info,
} }
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reasons[i] is not None state.finished = recv_obj.finished_reasons[i] is not None
if state.finished:
if self.server_args.speculative_algorithm:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
state.finished_time = time.time()
meta_info["e2e_latency"] = state.finished_time - state.created_time
state.out_list.append(out_dict)
state.event.set() state.event.set()
# Log metrics and dump
if self.enable_metrics and state.obj.log_metrics: if self.enable_metrics and state.obj.log_metrics:
self.collect_metrics(state, recv_obj, i) self.collect_metrics(state, recv_obj, i)
if self.dump_requests_folder and state.finished and state.obj.log_metrics: if self.dump_requests_folder and state.finished and state.obj.log_metrics:
...@@ -830,6 +978,7 @@ class TokenizerManager: ...@@ -830,6 +978,7 @@ class TokenizerManager:
self, self,
meta_info: dict, meta_info: dict,
top_logprobs_num: int, top_logprobs_num: int,
token_ids_logprob: List[int],
return_text_in_logprobs: bool, return_text_in_logprobs: bool,
recv_obj: BatchStrOut, recv_obj: BatchStrOut,
recv_obj_index: int, recv_obj_index: int,
...@@ -857,6 +1006,20 @@ class TokenizerManager: ...@@ -857,6 +1006,20 @@ class TokenizerManager:
return_text_in_logprobs, return_text_in_logprobs,
) )
if token_ids_logprob is not None:
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
recv_obj.input_token_ids_logprobs_val[recv_obj_index],
recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["output_token_ids_logprobs"] = (
self.detokenize_top_logprobs_tokens(
recv_obj.output_token_ids_logprobs_val[recv_obj_index],
recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
)
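For orientation, a rough sketch of the shape these two fields take in `meta_info` when a client requests `token_ids_logprob`; the concrete values and the (logprob, token_id, text) triple layout are assumptions based on the surrounding detokenize helpers, not guaranteed by this diff.

# Hypothetical example: token_ids_logprob = [42], return_text_in_logprobs = True.
meta_info_example = {
    # one inner list per prompt position, holding (logprob, token_id, token_text)
    "input_token_ids_logprobs": [[(-1.9, 42, "foo")], [(-0.4, 42, "foo")]],
    # one inner list per generated token
    "output_token_ids_logprobs": [[(-2.3, 42, "foo")]],
}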
def detokenize_logprob_tokens( def detokenize_logprob_tokens(
self, self,
token_logprobs_val: List[float], token_logprobs_val: List[float],
@@ -900,34 +1063,30 @@ class TokenizerManager:
                 else 0
             )

-        if state.first_token_time is None:
-            state.first_token_time = time.time()
+        if state.first_token_time == 0.0:
+            state.first_token_time = state.last_time = time.time()
+            state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
                 state.first_token_time - state.created_time
             )
         else:
-            if completion_tokens >= 2:
-                # Compute time_per_output_token for the streaming case
-                self.metrics_collector.observe_time_per_output_token(
-                    (time.time() - state.first_token_time) / (completion_tokens - 1)
+            num_new_tokens = completion_tokens - state.last_completion_tokens
+            if num_new_tokens:
+                new_time = time.time()
+                interval = new_time - state.last_time
+                self.metrics_collector.observe_inter_token_latency(
+                    interval,
+                    num_new_tokens,
                 )
+                state.last_time = new_time
+                state.last_completion_tokens = completion_tokens

         if state.finished:
             self.metrics_collector.observe_one_finished_request(
-                recv_obj.prompt_tokens[i], completion_tokens
-            )
-            self.metrics_collector.observe_e2e_request_latency(
-                time.time() - state.created_time
+                recv_obj.prompt_tokens[i],
+                completion_tokens,
+                state.finished_time - state.created_time,
             )
-            # Compute time_per_output_token for the non-streaming case
-            if (
-                hasattr(state.obj, "stream")
-                and not state.obj.stream
-                and completion_tokens >= 1
-            ):
-                self.metrics_collector.observe_time_per_output_token(
-                    (time.time() - state.created_time) / completion_tokens
-                )
def dump_requests(self, state: ReqState, out_dict: dict): def dump_requests(self, state: ReqState, out_dict: dict):
self.dump_request_list.append( self.dump_request_list.append(
@@ -996,22 +1155,38 @@ T = TypeVar("T")

 class _Communicator(Generic[T]):
+    """Note: The communicator now only runs up to 1 in-flight request at any time."""
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
-        self._result_future: Optional[asyncio.Future] = None
+        self._result_event: Optional[asyncio.Event] = None
         self._result_values: Optional[List[T]] = None
+        self._ready_queue: Deque[asyncio.Future] = deque()

     async def __call__(self, obj):
-        self._sender.send_pyobj(obj)
-        self._result_future = asyncio.Future()
+        ready_event = asyncio.Event()
+        if self._result_event is not None or len(self._ready_queue) > 0:
+            self._ready_queue.append(ready_event)
+            await ready_event.wait()
+            assert self._result_event is None
+            assert self._result_values is None
+
+        if obj:
+            self._sender.send_pyobj(obj)
+        self._result_event = asyncio.Event()
         self._result_values = []
-        await self._result_future
+        await self._result_event.wait()
         result_values = self._result_values
-        self._result_future = self._result_values = None
+        self._result_event = self._result_values = None
+
+        if len(self._ready_queue) > 0:
+            self._ready_queue.popleft().set()
+
         return result_values

     def handle_recv(self, recv_obj: T):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
-            self._result_future.set_result(None)
+            self._result_event.set()
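A small sketch of the single-in-flight behavior documented above. The scenario (two concurrent callers against a dp_size-2 scheduler) is hypothetical, and `send_to_scheduler` stands for the ZMQ socket used elsewhere in this file.

# Hypothetical: fan_out=2 means each call waits for two handle_recv() results.
communicator = _Communicator(send_to_scheduler, fan_out=2)

async def get_state_twice(req):
    # Launched concurrently, the second call parks on _ready_queue until the
    # first call has collected both replies, so at most one request is in flight.
    first, second = await asyncio.gather(communicator(req), communicator(req))
    return first, second   # each is a list of 2 replies (one per DP rank)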
...@@ -15,10 +15,13 @@ ...@@ -15,10 +15,13 @@
import logging import logging
import threading import threading
from typing import Optional from typing import Optional, Tuple
import torch
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput, GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqInput,
...@@ -159,7 +162,7 @@ class TpModelWorker: ...@@ -159,7 +162,7 @@ class TpModelWorker:
model_worker_batch: ModelWorkerBatch, model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None, launch_done: Optional[threading.Event] = None,
skip_sample: bool = False, skip_sample: bool = False,
): ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch) logits_output = self.model_runner.forward(forward_batch)
if launch_done: if launch_done:
......
...@@ -175,7 +175,7 @@ class TpModelWorkerClient: ...@@ -175,7 +175,7 @@ class TpModelWorkerClient:
logits_output.next_token_logprobs.tolist() logits_output.next_token_logprobs.tolist()
) )
if logits_output.input_token_logprobs is not None: if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = ( logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist() logits_output.input_token_logprobs.tolist()
) )
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
         model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
             sampling_info,
             sampling_info_done=threading.Event(),
-            scaling_penalties=sampling_info.scaling_penalties,
-            linear_penalties=sampling_info.linear_penalties,
+            penalizer_orchestrator=None,
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
...
...@@ -2,7 +2,9 @@ from __future__ import annotations ...@@ -2,7 +2,9 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled.""" """Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
...@@ -12,7 +14,7 @@ if TYPE_CHECKING: ...@@ -12,7 +14,7 @@ if TYPE_CHECKING:
class ChunkCacheEntry: class ChunkCacheEntry:
def __init__(self, rid, value): def __init__(self, rid: str, value: torch.Tensor):
self.rid = rid self.rid = rid
self.value = value self.value = value
...@@ -24,6 +26,7 @@ class ChunkCache(BasePrefixCache): ...@@ -24,6 +26,7 @@ class ChunkCache(BasePrefixCache):
self.disable = True self.disable = True
self.req_to_token_pool = req_to_token_pool self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool self.token_to_kv_pool = token_to_kv_pool
self.entries: Dict[str, ChunkCacheEntry] = {}
self.reset() self.reset()
@@ -53,11 +56,8 @@ class ChunkCache(BasePrefixCache):
         if req.rid in self.entries:
             del self.entries[req.rid]

-    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
-        if token_ids is None:
-            token_id_len = len(req.fill_ids)
-        else:
-            token_id_len = len(token_ids)
+    def cache_unfinished_req(self, req: Req):
+        token_id_len = len(req.fill_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :token_id_len
def evictable_size(self): def evictable_size(self):
return 0 return 0
def pretty_print(self):
return ""
def protected_size(self): def protected_size(self):
return 0 return 0
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# ============================================================================== # ==============================================================================
"""Utilities for Prometheus Metrics Collection.""" """Utilities for Prometheus Metrics Collection."""
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Union from typing import Dict, Union
...@@ -35,19 +36,20 @@ class SchedulerMetricsCollector: ...@@ -35,19 +36,20 @@ class SchedulerMetricsCollector:
from prometheus_client import Gauge from prometheus_client import Gauge
self.labels = labels self.labels = labels
self.last_log_time = time.time()
self.num_running_reqs = Gauge( self.num_running_reqs = Gauge(
name="sglang:num_running_reqs", name="sglang:num_running_reqs",
documentation="The number of running requests.", documentation="The number of running requests.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="sum", multiprocess_mode="mostrecent",
) )
self.num_used_tokens = Gauge( self.num_used_tokens = Gauge(
name="sglang:num_used_tokens", name="sglang:num_used_tokens",
documentation="The number of used tokens.", documentation="The number of used tokens.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="sum", multiprocess_mode="mostrecent",
) )
self.token_usage = Gauge( self.token_usage = Gauge(
...@@ -61,14 +63,14 @@ class SchedulerMetricsCollector: ...@@ -61,14 +63,14 @@ class SchedulerMetricsCollector:
name="sglang:gen_throughput", name="sglang:gen_throughput",
documentation="The generation throughput (token/s).", documentation="The generation throughput (token/s).",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="sum", multiprocess_mode="mostrecent",
) )
self.num_queue_reqs = Gauge( self.num_queue_reqs = Gauge(
name="sglang:num_queue_reqs", name="sglang:num_queue_reqs",
documentation="The number of requests in the waiting queue.", documentation="The number of requests in the waiting queue.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="sum", multiprocess_mode="mostrecent",
) )
self.cache_hit_rate = Gauge( self.cache_hit_rate = Gauge(
...@@ -97,6 +99,7 @@ class SchedulerMetricsCollector: ...@@ -97,6 +99,7 @@ class SchedulerMetricsCollector:
self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
self._log_gauge(self.spec_accept_length, stats.spec_accept_length) self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
self.last_log_time = time.time()
class TokenizerMetricsCollector: class TokenizerMetricsCollector:
...@@ -130,12 +133,15 @@ class TokenizerMetricsCollector: ...@@ -130,12 +133,15 @@ class TokenizerMetricsCollector:
labelnames=labels.keys(), labelnames=labels.keys(),
buckets=[ buckets=[
0.1, 0.1,
0.25, 0.3,
0.5, 0.5,
0.75, 0.7,
0.9,
1, 1,
2, 2,
5, 4,
6,
8,
10, 10,
20, 20,
40, 40,
...@@ -151,24 +157,56 @@ class TokenizerMetricsCollector: ...@@ -151,24 +157,56 @@ class TokenizerMetricsCollector:
documentation="Histogram of time per output token in seconds.", documentation="Histogram of time per output token in seconds.",
labelnames=labels.keys(), labelnames=labels.keys(),
buckets=[ buckets=[
0.002,
0.005, 0.005,
0.01, 0.010,
0.020,
0.030,
0.040,
0.050,
0.060,
0.070,
0.080,
0.090,
0.100,
0.150,
0.200,
0.300,
0.400,
0.600,
0.800,
1.000,
2.000,
],
)
self.histogram_inter_token_latency_seconds = Histogram(
name="sglang:inter_token_latency_seconds",
documentation="Histogram of inter-token latency in seconds.",
labelnames=labels.keys(),
buckets=[
0.002,
0.004,
0.006,
0.008,
0.010,
0.015, 0.015,
0.02, 0.020,
0.025, 0.025,
0.03, 0.030,
0.04, 0.035,
0.05, 0.040,
0.050,
0.075, 0.075,
0.1, 0.100,
0.15, 0.150,
0.2, 0.200,
0.3, 0.300,
0.4, 0.400,
0.5, 0.500,
0.75, 0.750,
1.0, 1.000,
2.5, 2.000,
], ],
) )
...@@ -178,8 +216,9 @@ class TokenizerMetricsCollector: ...@@ -178,8 +216,9 @@ class TokenizerMetricsCollector:
labelnames=labels.keys(), labelnames=labels.keys(),
buckets=[ buckets=[
0.1, 0.1,
0.25, 0.2,
0.5, 0.4,
0.8,
1, 1,
2, 2,
5, 5,
...@@ -188,28 +227,161 @@ class TokenizerMetricsCollector: ...@@ -188,28 +227,161 @@ class TokenizerMetricsCollector:
40, 40,
60, 60,
80, 80,
100,
150,
200,
250,
300,
350,
500,
1000,
],
)
self.histogram_prefill_prealloc_duration = Histogram(
name="sglang:prefill_prealloc_duration_seconds",
documentation="Histogram of prefill prealloc duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
1,
2,
4,
6,
8,
10,
20,
40,
60,
80,
120, 120,
160, 160,
], ],
) )
self.histogram_prefill_queue_duration = Histogram(
name="sglang:prefill_queue_duration_seconds",
documentation="Histogram of prefill queue duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
2,
4,
8,
16,
64,
],
)
self.histogram_prefill_forward_duration = Histogram(
name="sglang:prefill_forward_duration_seconds",
documentation="Histogram of prefill forward duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
2,
4,
8,
16,
64,
],
)
self.histogram_prefill_transfer_duration = Histogram(
name="sglang:prefill_transfer_duration_seconds",
documentation="Histogram of prefill transfer duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.050,
0.100,
0.150,
0.200,
0.300,
0.400,
0.500,
1.000,
2.000,
],
)
self.histogram_decode_prealloc_duration = Histogram(
name="sglang:decode_prealloc_duration_seconds",
documentation="Histogram of decode prealloc duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
2,
4,
8,
16,
64,
],
)
self.histogram_decode_queue_duration = Histogram(
name="sglang:decode_queue_duration_seconds",
documentation="Histogram of decode queue duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
2,
4,
8,
16,
64,
],
)
     def _log_histogram(self, histogram, data: Union[int, float]) -> None:
         histogram.labels(**self.labels).observe(data)

-    def _log_counter(self, counter, data: Union[int, float]) -> None:
-        # Convenience function for logging to counter.
-        counter.labels(**self.labels).inc(data)
-
-    def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+    def observe_one_finished_request(
+        self,
+        prompt_tokens: int,
+        generation_tokens: int,
+        e2e_latency: float,
+    ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
+        self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+        if generation_tokens >= 1:
+            self.histogram_time_per_output_token.labels(**self.labels).observe(
+                e2e_latency / generation_tokens
+            )

-    def observe_time_to_first_token(self, value: Union[float, int]):
-        self._log_histogram(self.histogram_time_to_first_token, value)
+    def observe_time_to_first_token(self, value: float):
+        self.histogram_time_to_first_token.labels(**self.labels).observe(value)

-    def observe_time_per_output_token(self, value: Union[float, int]):
-        self._log_histogram(self.histogram_time_per_output_token, value)
-
-    def observe_e2e_request_latency(self, value: Union[float, int]):
-        self._log_histogram(self.histogram_e2e_request_latency, value)
+    def observe_inter_token_latency(self, interval: float, num_new_tokens: int):
+        adjusted_interval = interval / num_new_tokens
+
+        # A faster version of Histogram.observe() that records multiple observations at once.
+        # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
+        his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
+        his._sum.inc(interval)
+
+        for i, bound in enumerate(his._upper_bounds):
+            if adjusted_interval <= bound:
+                his._buckets[i].inc(num_new_tokens)
+                break
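For clarity, the batched bucket update above is equivalent to the following naive loop, kept only as an illustration; it performs num_new_tokens separate observations instead of a single _sum/_buckets update.

def observe_inter_token_latency_naive(self, interval: float, num_new_tokens: int):
    # Same end state as the fast path: _sum grows by `interval`, and the bucket
    # containing interval / num_new_tokens grows by `num_new_tokens`.
    per_token = interval / num_new_tokens
    for _ in range(num_new_tokens):
        self.histogram_inter_token_latency_seconds.labels(**self.labels).observe(per_token)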
@@ -109,11 +109,15 @@ def set_torch_compile_config():

 def get_batch_sizes_to_capture(model_runner: ModelRunner):
     server_args = model_runner.server_args
     capture_bs = server_args.cuda_graph_bs
+
     if capture_bs is None:
-        if server_args.disable_cuda_graph_padding:
-            capture_bs = list(range(1, 33)) + [64, 128]
+        if server_args.speculative_algorithm is None:
+            if server_args.disable_cuda_graph_padding:
+                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+            else:
+                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         else:
-            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+            capture_bs = list(range(1, 33))

     if is_hip_:
         capture_bs += [i * 8 for i in range(21, 33)]
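For reference, a quick sketch of the candidate batch-size lists the new branches produce, before the later max-bs filtering; this assumes a CUDA (non-HIP) build and default server args.

# speculative_algorithm is None and padding is enabled (default):
default_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]        # 1, 2, 4, 8, 16, ..., 160
# speculative_algorithm is None and --disable-cuda-graph-padding:
no_padding_bs = list(range(1, 33)) + [64, 96, 128, 160]       # every bs in 1..32 plus a few large sizes
# speculative decoding enabled:
spec_bs = list(range(1, 33))                                  # 1..32 only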
...@@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): ...@@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
) )
) )
) )
capture_bs = [ capture_bs = [
bs bs
for bs in capture_bs for bs in capture_bs
...@@ -385,9 +390,6 @@ class CudaGraphRunner: ...@@ -385,9 +390,6 @@ class CudaGraphRunner:
run_once() run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
torch.cuda.synchronize() torch.cuda.synchronize()
self.model_runner.tp_group.barrier() self.model_runner.tp_group.barrier()
...@@ -401,12 +403,11 @@ class CudaGraphRunner: ...@@ -401,12 +403,11 @@ class CudaGraphRunner:
global_graph_memory_pool = graph.pool() global_graph_memory_pool = graph.pool()
return graph, out return graph, out
def replay(self, forward_batch: ForwardBatch): def recapture_if_needed(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None # If the capture_hidden_mode changes, we need to recapture the graph
hidden_mode_from_spec_info = getattr( hidden_mode_from_spec_info = getattr(
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
) )
# If the capture_hidden_mode changes, we need to recapture the graph
if ( if (
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
and self.capture_hidden_mode != CaptureHiddenMode.FULL and self.capture_hidden_mode != CaptureHiddenMode.FULL
...@@ -420,6 +421,9 @@ class CudaGraphRunner: ...@@ -420,6 +421,9 @@ class CudaGraphRunner:
self.capture_hidden_mode = hidden_mode_from_spec_info self.capture_hidden_mode = hidden_mode_from_spec_info
self.capture() self.capture()
def replay(self, forward_batch: ForwardBatch):
self.recapture_if_needed(forward_batch)
raw_bs = forward_batch.batch_size raw_bs = forward_batch.batch_size
raw_num_token = raw_bs * self.num_tokens_per_bs raw_num_token = raw_bs * self.num_tokens_per_bs
......
...@@ -31,7 +31,7 @@ from __future__ import annotations ...@@ -31,7 +31,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
import triton import triton
...@@ -46,7 +46,8 @@ if TYPE_CHECKING: ...@@ -46,7 +46,8 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
class ForwardMode(IntEnum): class ForwardMode(IntEnum):
...@@ -112,7 +113,9 @@ class ForwardMode(IntEnum): ...@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
class CaptureHiddenMode(IntEnum): class CaptureHiddenMode(IntEnum):
NULL = auto() NULL = auto()
# Capture hidden states of all tokens.
FULL = auto() FULL = auto()
# Capture the hidden state of the last token.
LAST = auto() LAST = auto()
def need_capture(self): def need_capture(self):
...@@ -148,6 +151,7 @@ class ForwardBatch: ...@@ -148,6 +151,7 @@ class ForwardBatch:
# For logprob # For logprob
return_logprob: bool = False return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None top_logprobs_nums: Optional[List[int]] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# Position information # Position information
positions: torch.Tensor = None positions: torch.Tensor = None
...@@ -160,6 +164,7 @@ class ForwardBatch: ...@@ -160,6 +164,7 @@ class ForwardBatch:
extend_prefix_lens_cpu: Optional[List[int]] = None extend_prefix_lens_cpu: Optional[List[int]] = None
extend_seq_lens_cpu: Optional[List[int]] = None extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
# For multimodal # For multimodal
image_inputs: Optional[List[ImageInputs]] = None image_inputs: Optional[List[ImageInputs]] = None
...@@ -190,10 +195,13 @@ class ForwardBatch: ...@@ -190,10 +195,13 @@ class ForwardBatch:
can_run_dp_cuda_graph: bool = False can_run_dp_cuda_graph: bool = False
# Speculative decoding # Speculative decoding
spec_info: SpecInfo = None spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
spec_algorithm: SpeculativeAlgorithm = None spec_algorithm: SpeculativeAlgorithm = None
capture_hidden_mode: CaptureHiddenMode = None capture_hidden_mode: CaptureHiddenMode = None
# For padding
padded_static_len: int = -1 # -1 if not padded
# For Qwen2-VL # For Qwen2-VL
mrope_positions: torch.Tensor = None mrope_positions: torch.Tensor = None
...@@ -203,8 +211,13 @@ class ForwardBatch: ...@@ -203,8 +211,13 @@ class ForwardBatch:
batch: ModelWorkerBatch, batch: ModelWorkerBatch,
model_runner: ModelRunner, model_runner: ModelRunner,
): ):
device = model_runner.device device = model_runner.device
extend_input_logprob_token_ids_gpu = None
if batch.extend_input_logprob_token_ids is not None:
extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
)
ret = cls( ret = cls(
forward_mode=batch.forward_mode, forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens), batch_size=len(batch.seq_lens),
...@@ -220,6 +233,7 @@ class ForwardBatch: ...@@ -220,6 +233,7 @@ class ForwardBatch:
seq_lens_sum=batch.seq_lens_sum, seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob, return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
global_num_tokens=batch.global_num_tokens, global_num_tokens=batch.global_num_tokens,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths, lora_paths=batch.lora_paths,
...@@ -231,6 +245,7 @@ class ForwardBatch: ...@@ -231,6 +245,7 @@ class ForwardBatch:
spec_info=batch.spec_info, spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode, capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds, input_embeds=batch.input_embeds,
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
) )
if ret.global_num_tokens is not None: if ret.global_num_tokens is not None:
...@@ -341,6 +356,7 @@ class ForwardBatch: ...@@ -341,6 +356,7 @@ class ForwardBatch:
) )
batch.image_inputs[i].mrope_position_delta = mrope_position_delta batch.image_inputs[i].mrope_position_delta = mrope_position_delta
mrope_positions_list[i] = mrope_positions mrope_positions_list[i] = mrope_positions
self.mrope_positions = torch.concat( self.mrope_positions = torch.concat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list], [torch.tensor(pos, device=device) for pos in mrope_positions_list],
axis=1, axis=1,
...@@ -379,7 +395,7 @@ def compute_position_kernel( ...@@ -379,7 +395,7 @@ def compute_position_kernel(
extend_seq_lens, extend_seq_lens,
): ):
BLOCK_SIZE: tl.constexpr = 512 BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0) pid = tl.program_id(0).to(tl.int64)
prefix_len = tl.load(extend_prefix_lens + pid) prefix_len = tl.load(extend_prefix_lens + pid)
seq_len = tl.load(extend_seq_lens + pid) seq_len = tl.load(extend_seq_lens + pid)
......
...@@ -13,9 +13,12 @@ ...@@ -13,9 +13,12 @@
# ============================================================================== # ==============================================================================
"""ModelRunner runs the forward passes of the models.""" """ModelRunner runs the forward passes of the models."""
import collections
import datetime
import gc import gc
import json import json
import logging import logging
import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
...@@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner ...@@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
...@@ -73,10 +77,15 @@ from sglang.srt.utils import ( ...@@ -73,10 +77,15 @@ from sglang.srt.utils import (
set_cpu_offload_max_bytes, set_cpu_offload_max_bytes,
set_cuda_arch, set_cuda_arch,
) )
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
class ModelRunner: class ModelRunner:
"""ModelRunner runs the forward passes of the models.""" """ModelRunner runs the forward passes of the models."""
...@@ -180,9 +189,13 @@ class ModelRunner: ...@@ -180,9 +189,13 @@ class ModelRunner:
"enable_dp_attention": server_args.enable_dp_attention, "enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe, "enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device, "device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla, "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"disable_radix_cache": server_args.disable_radix_cache, "disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
} }
) )
...@@ -199,6 +212,18 @@ class ModelRunner: ...@@ -199,6 +212,18 @@ class ModelRunner:
self.sampler = Sampler() self.sampler = Sampler()
self.load_model() self.load_model()
# Handle the case where some of models don't finish loading.
try:
dist.monitored_barrier(
group=get_tp_group().cpu_group,
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
wait_all_ranks=True,
)
except RuntimeError:
raise ValueError(
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
# Apply torchao quantization # Apply torchao quantization
torchao_applied = getattr(self.model, "torchao_applied", False) torchao_applied = getattr(self.model, "torchao_applied", False)
# In layered loading, torchao may have been applied # In layered loading, torchao may have been applied
...@@ -625,6 +650,9 @@ class ModelRunner: ...@@ -625,6 +650,9 @@ class ModelRunner:
4096, 4096,
) )
if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
if not self.spec_algorithm.is_none(): if not self.spec_algorithm.is_none():
if self.is_draft_worker: if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size self.max_total_num_tokens = self.server_args.draft_runner_cache_size
...@@ -655,6 +683,7 @@ class ModelRunner: ...@@ -655,6 +683,7 @@ class ModelRunner:
device=self.device, device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver, enable_memory_saver=self.server_args.enable_memory_saver,
) )
if ( if (
self.model_config.attention_arch == AttentionArch.MLA self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla and not self.server_args.disable_mla
...@@ -758,9 +787,13 @@ class ModelRunner: ...@@ -758,9 +787,13 @@ class ModelRunner:
return return
tic = time.time() tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.") logger.info(
f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
self.cuda_graph_runner = CudaGraphRunner(self) self.cuda_graph_runner = CudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") logger.info(
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def apply_torch_tp(self): def apply_torch_tp(self):
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
...@@ -820,11 +853,10 @@ class ModelRunner: ...@@ -820,11 +853,10 @@ class ModelRunner:
else: else:
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
def sample( def _preprocess_logits(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
) -> torch.Tensor: ):
# Apply logit bias # Apply logit bias
sampling_info = forward_batch.sampling_info
if sampling_info.sampling_info_done: if sampling_info.sampling_info_done:
# Overlap mode: the function update_regex_vocab_mask was executed # Overlap mode: the function update_regex_vocab_mask was executed
# in process_batch_result of the last batch. # in process_batch_result of the last batch.
...@@ -833,15 +865,77 @@ class ModelRunner: ...@@ -833,15 +865,77 @@ class ModelRunner:
else: else:
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass. # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info.update_regex_vocab_mask() sampling_info.update_regex_vocab_mask()
sampling_info.update_penalties()
sampling_info.apply_logits_bias(logits_output.next_token_logits) sampling_info.apply_logits_bias(logits_output.next_token_logits)
def update_output_logprobs(
self,
logits_output: LogitsProcessorOutput,
sampling_info: SamplingBatchInfo,
top_logprobs_nums: List[int],
token_ids_logprobs: List[int],
next_token_ids: torch.Tensor,
*,
num_tokens_per_req: List[int],
):
"""Update the logits_output's output logprob based on next_token_ids
Args:
logits_output: The logits output from the model forward
sampling_info: Sampling info for logprob calculation
top_logprobs_nums: Number of logprobs per request.
next_token_ids: Next token ids.
num_tokens_per_req: The number of tokens per request.
Returns:
A list of next_token_ids
"""
self._preprocess_logits(logits_output, sampling_info)
# We should repeat top_logprobs_nums to match num_tokens_per_req.
top_logprobs_nums_repeat_interleaved = []
token_ids_logprobs_repeat_interleaved = []
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
self.sampler(
logits_output,
sampling_info,
True,
top_logprobs_nums_repeat_interleaved,
token_ids_logprobs_repeat_interleaved,
batch_next_token_ids=next_token_ids,
)
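A toy illustration of the repeat-interleave performed above when logprobs are returned with chunked prefill; the numbers are made up.

# Two requests in this chunked-prefill batch: the first contributes 3 token rows
# to logits_output, the second contributes 1.
top_logprobs_nums = [5, 2]
num_tokens_per_req = [3, 1]
expanded = []
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
    expanded.extend([num] * num_tokens)
assert expanded == [5, 5, 5, 2]   # one entry per token row passed to the sampler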
def sample(
self,
logits_output: LogitsProcessorOutput,
forward_batch: ForwardBatch,
) -> torch.Tensor:
"""Sample and compute logprobs and update logits_output.
Args:
logits_output: The logits output from the model forward
forward_batch: The forward batch that generates logits_output
Returns:
A list of next_token_ids
"""
# For duplex models with multiple output streams.
if isinstance(logits_output, tuple):
return torch.stack(
[self.sample(values, forward_batch) for values in logits_output],
axis=-1,
)
self._preprocess_logits(logits_output, forward_batch.sampling_info)
# Sample the next tokens # Sample the next tokens
next_token_ids = self.sampler( next_token_ids = self.sampler(
logits_output, logits_output,
sampling_info, forward_batch.sampling_info,
forward_batch.return_logprob, forward_batch.return_logprob,
forward_batch.top_logprobs_nums, forward_batch.top_logprobs_nums,
forward_batch.token_ids_logprobs,
) )
return next_token_ids return next_token_ids
......
...@@ -25,10 +25,10 @@ import filelock ...@@ -25,10 +25,10 @@ import filelock
import gguf import gguf
import huggingface_hub.constants import huggingface_hub.constants
import numpy as np import numpy as np
import safetensors.torch
import torch import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.load_config import LoadConfig
...@@ -62,7 +62,6 @@ enable_hf_transfer() ...@@ -62,7 +62,6 @@ enable_hf_transfer()
class DisabledTqdm(tqdm): class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True) super().__init__(*args, **kwargs, disable=True)
...@@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file( ...@@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file(
) )
# check if the tensors are the same # check if the tensors are the same
reloaded = load_file(sf_filename) reloaded = safetensors.torch.load_file(sf_filename)
for k in loaded: for k in loaded:
pt_tensor = loaded[k] pt_tensor = loaded[k]
sf_tensor = reloaded[k] sf_tensor = reloaded[k]
...@@ -133,7 +132,6 @@ def convert_bin_to_safetensor_file( ...@@ -133,7 +132,6 @@ def convert_bin_to_safetensor_file(
def get_quant_config( def get_quant_config(
model_config: ModelConfig, load_config: LoadConfig model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig: ) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization) quant_cls = get_quantization_config(model_config.quantization)
# GGUF doesn't have config file # GGUF doesn't have config file
...@@ -402,15 +400,34 @@ def np_cache_weights_iterator( ...@@ -402,15 +400,34 @@ def np_cache_weights_iterator(
yield name, torch.from_numpy(param) yield name, torch.from_numpy(param)
def decrypt(fn, key):
raise NotImplementedError()
def safetensors_encrypted_weights_iterator(
hf_weights_files: List[str],
is_all_weights_sharded: bool = False,
decryption_key: Optional[str] = None,
):
raise NotImplementedError()
def safetensors_weights_iterator( def safetensors_weights_iterator(
hf_weights_files: List[str], hf_weights_files: List[str],
is_all_weights_sharded: bool = False, is_all_weights_sharded: bool = False,
decryption_key: Optional[str] = None,
) -> Generator[Tuple[str, torch.Tensor], None, None]: ) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files. """Iterate over the weights in the model safetensor files.
If is_all_weights_sharded is True, it uses a more optimized read by reading an If is_all_weights_sharded is True, it uses a more optimized read by reading an
entire file instead of reading each tensor one by one. entire file instead of reading each tensor one by one.
""" """
if decryption_key:
yield from safetensors_encrypted_weights_iterator(
hf_weights_files, is_all_weights_sharded, decryption_key
)
return
enable_tqdm = ( enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
) )
...@@ -420,15 +437,9 @@ def safetensors_weights_iterator( ...@@ -420,15 +437,9 @@ def safetensors_weights_iterator(
disable=not enable_tqdm, disable=not enable_tqdm,
bar_format=_BAR_FORMAT, bar_format=_BAR_FORMAT,
): ):
if not is_all_weights_sharded: result = safetensors.torch.load_file(st_file, device="cpu")
with safe_open(st_file, framework="pt") as f: for name, param in result.items():
for name in f.keys(): # noqa: SIM118 yield name, param
param = f.get_tensor(name)
yield name, param
else:
result = load_file(st_file, device="cpu")
for name, param in result.items():
yield name, param
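For readers unfamiliar with the whole-file read path kept above, a minimal standalone sketch of writing and then reading a safetensors file with safetensors.torch (the temp path and tensor are arbitrary):

import os
import tempfile
import torch
import safetensors.torch

path = os.path.join(tempfile.mkdtemp(), "demo.safetensors")
safetensors.torch.save_file({"w": torch.ones(2, 2)}, path)

# Whole-file read, mirroring the iterator above.
for name, param in safetensors.torch.load_file(path, device="cpu").items():
    print(name, tuple(param.shape))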
def pt_weights_iterator( def pt_weights_iterator(
......
from .orchestrator import BatchedPenalizerOrchestrator from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator
from .penalizers.presence_penalty import BatchedPresencePenalizer from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
__all__ = [ __all__ = [
"BatchedFrequencyPenalizer", "BatchedFrequencyPenalizer",
"BatchedMinNewTokensPenalizer", "BatchedMinNewTokensPenalizer",
"BatchedPresencePenalizer", "BatchedPresencePenalizer",
"BatchedRepetitionPenalizer",
"BatchedPenalizerOrchestrator", "BatchedPenalizerOrchestrator",
] ]
from typing import List
import torch import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs from sglang.srt.sampling.penaltylib.orchestrator import (
BatchedPenalizerOrchestrator,
_BatchedPenalizer,
)
class BatchedFrequencyPenalizer(_BatchedPenalizer): class BatchedFrequencyPenalizer(_BatchedPenalizer):
...@@ -10,8 +11,9 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer): ...@@ -10,8 +11,9 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
Frequency penalizer penalizes tokens based on their frequency in the output. Frequency penalizer penalizes tokens based on their frequency in the output.
""" """
frequency_penalties: torch.Tensor = None def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
cumulated_frequency_penalties: torch.Tensor = None self.orchestrator = orchestrator
self._is_prepared = False
def _is_required(self) -> bool: def _is_required(self) -> bool:
return any( return any(
...@@ -20,14 +22,10 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer): ...@@ -20,14 +22,10 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
) )
def _prepare(self): def _prepare(self):
self.cumulated_frequency_penalties = ( self.cumulated_frequency_penalties = torch.zeros(
torch.tensor( (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
data=[0.0 for _ in self.orchestrator.reqs()], dtype=torch.float32,
dtype=torch.float32, device=self.orchestrator.device,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
) )
self.frequency_penalties = ( self.frequency_penalties = (
...@@ -39,33 +37,26 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer): ...@@ -39,33 +37,26 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
dtype=torch.float32, dtype=torch.float32,
device=self.orchestrator.device, device=self.orchestrator.device,
) )
.unsqueeze_(1) ).unsqueeze_(1)
.expand_as(self.cumulated_frequency_penalties)
)
def _teardown(self):
self.frequency_penalties = None
self.cumulated_frequency_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs): def _cumulate_output_tokens(self, output_ids: torch.Tensor):
self.cumulated_frequency_penalties += ( self.cumulated_frequency_penalties.scatter_add_(
self.frequency_penalties * output_ids.occurrence_count() dim=1,
index=output_ids.unsqueeze(1),
src=self.frequency_penalties,
) )
def _apply(self, logits: torch.Tensor) -> torch.Tensor: def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits -= self.cumulated_frequency_penalties logits.sub_(self.cumulated_frequency_penalties)
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): def _filter(self, keep_indices: torch.Tensor):
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep] self.frequency_penalties = self.frequency_penalties[keep_indices]
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[ self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
indices_tensor_to_keep keep_indices
] ]
def _merge(self, their: "BatchedFrequencyPenalizer"): def _merge(self, their: "BatchedFrequencyPenalizer"):
print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
self.frequency_penalties = torch.cat( self.frequency_penalties = torch.cat(
[self.frequency_penalties, their.frequency_penalties], dim=0 [self.frequency_penalties, their.frequency_penalties], dim=0
) )
......
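The rewritten frequency penalizer accumulates the penalty directly into a (batch, vocab) tensor with scatter_add_ instead of tracking occurrence counts. A toy example of that accumulation and of how the penalty is finally subtracted from the logits (sizes and values are made up):

import torch

batch_size, vocab_size = 2, 5
frequency_penalties = torch.tensor([[0.5], [1.0]])   # (batch, 1), one value per request
cumulated = torch.zeros(batch_size, vocab_size)
output_ids = torch.tensor([3, 1])                    # token each request just produced

# Called once per decoding step; repeated tokens keep adding their penalty.
cumulated.scatter_add_(dim=1, index=output_ids.unsqueeze(1), src=frequency_penalties)
cumulated.scatter_add_(dim=1, index=output_ids.unsqueeze(1), src=frequency_penalties)

logits = torch.zeros(batch_size, vocab_size)
logits.sub_(cumulated)   # frequency penalty is subtracted in place
print(logits)            # token 3 of request 0 at -1.0, token 1 of request 1 at -2.0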
from typing import List
import torch import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs from sglang.srt.sampling.penaltylib.orchestrator import (
BatchedPenalizerOrchestrator,
_BatchedPenalizer,
)
class BatchedMinNewTokensPenalizer(_BatchedPenalizer): class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
...@@ -10,9 +11,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer): ...@@ -10,9 +11,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
Min new tokens penalizer penalizes tokens based on the length of the output. Min new tokens penalizer penalizes tokens based on the length of the output.
""" """
min_new_tokens: torch.Tensor = None def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
stop_token_penalties: torch.Tensor = None self.orchestrator = orchestrator
len_output_tokens: torch.Tensor = None self._is_prepared = False
def _is_required(self) -> bool: def _is_required(self) -> bool:
return any( return any(
...@@ -47,7 +48,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer): ...@@ -47,7 +48,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
padding_value=self.orchestrator.vocab_size, padding_value=self.orchestrator.vocab_size,
) )
self.stop_token_penalties = torch.zeros( self.stop_token_penalties = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), size=(len(self.orchestrator.reqs()), self.orchestrator.vocab_size + 1),
dtype=torch.float32, dtype=torch.float32,
device=self.orchestrator.device, device=self.orchestrator.device,
).scatter_add_( ).scatter_add_(
...@@ -64,31 +65,22 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer): ...@@ -64,31 +65,22 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
] ]
self.len_output_tokens = torch.zeros( self.len_output_tokens = torch.zeros(
size=(self.orchestrator.batch_size(), 1), size=(len(self.orchestrator.reqs()), 1),
dtype=torch.int32, dtype=torch.int32,
device=self.orchestrator.device, device=self.orchestrator.device,
) )
def _teardown(self): def _cumulate_output_tokens(self, output_ids: torch.Tensor):
self.min_new_tokens = None
self.stop_token_penalties = None
self.len_output_tokens = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
self.len_output_tokens += 1 self.len_output_tokens += 1
def _apply(self, logits: torch.Tensor) -> torch.Tensor: def _apply(self, logits: torch.Tensor):
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits) mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
logits[mask] += self.stop_token_penalties[mask] logits[mask] += self.stop_token_penalties[mask]
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): def _filter(self, keep_indices: torch.Tensor):
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep] self.min_new_tokens = self.min_new_tokens[keep_indices]
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep] self.stop_token_penalties = self.stop_token_penalties[keep_indices]
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep] self.len_output_tokens = self.len_output_tokens[keep_indices]
def _merge(self, their: "BatchedMinNewTokensPenalizer"): def _merge(self, their: "BatchedMinNewTokensPenalizer"):
self.min_new_tokens = torch.cat( self.min_new_tokens = torch.cat(
......
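The min-new-tokens penalizer keeps a row of -inf penalties on each request's stop tokens and applies it only while the request has produced fewer than min_new_tokens outputs. A toy version of the masking step (all tensors are made up):

import torch

vocab_size, stop_token = 6, 5
min_new_tokens = torch.tensor([[3], [1]])      # per-request minimum
len_output_tokens = torch.tensor([[1], [2]])   # tokens generated so far
stop_token_penalties = torch.zeros(2, vocab_size)
stop_token_penalties[:, stop_token] = float("-inf")

logits = torch.zeros(2, vocab_size)
mask = (len_output_tokens < min_new_tokens).expand_as(logits)
logits[mask] += stop_token_penalties[mask]
print(logits)  # request 0 still cannot emit the stop token; request 1 can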
from __future__ import annotations
import abc import abc
import dataclasses from typing import TYPE_CHECKING, Set, Type
from typing import List, Set, Type, Union
import torch import torch
if TYPE_CHECKING:
@dataclasses.dataclass from sglang.srt.managers.schedule_batch import ScheduleBatch
class _ReqLike:
origin_input_ids: List[int]
@dataclasses.dataclass
class _BatchLike:
reqs: List[_ReqLike]
def batch_size(self):
return len(self.reqs)
class BatchedPenalizerOrchestrator: class BatchedPenalizerOrchestrator:
def __init__( def __init__(
self, self,
vocab_size: int, vocab_size: int,
batch: _BatchLike, batch: ScheduleBatch,
device: str, penalizers: Set[Type["_BatchedPenalizer"]],
Penalizers: Set[Type["_BatchedPenalizer"]],
): ):
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.batch = batch self.batch = batch
self.device = device self.device = batch.device
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers} self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
is_required = False is_required = False
for penalizer in self.penalizers.values(): for penalizer in self.penalizers.values():
...@@ -37,31 +27,9 @@ class BatchedPenalizerOrchestrator: ...@@ -37,31 +27,9 @@ class BatchedPenalizerOrchestrator:
is_required |= pen_is_required is_required |= pen_is_required
self.is_required = is_required self.is_required = is_required
input_ids = [
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
for req in self.reqs()
]
if self.is_required:
self.cumulate_input_tokens(input_ids=input_ids)
def reqs(self): def reqs(self):
return self.batch.reqs return self.batch.reqs
def batch_size(self):
return self.batch.batch_size()
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
"""
Feed the input tokens to the penalizers.
Args:
input_ids (List[torch.Tensor]): The input tokens.
"""
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
for penalizer in self.penalizers.values():
penalizer.cumulate_input_tokens(input_ids=token_ids)
def cumulate_output_tokens(self, output_ids: torch.Tensor): def cumulate_output_tokens(self, output_ids: torch.Tensor):
""" """
Feed the output tokens to the penalizers. Feed the output tokens to the penalizers.
...@@ -69,13 +37,8 @@ class BatchedPenalizerOrchestrator: ...@@ -69,13 +37,8 @@ class BatchedPenalizerOrchestrator:
Args: Args:
output_ids (torch.Tensor): The output tokens. output_ids (torch.Tensor): The output tokens.
""" """
if not self.is_required:
return
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
for penalizer in self.penalizers.values(): for penalizer in self.penalizers.values():
penalizer.cumulate_output_tokens(output_ids=token_ids) penalizer.cumulate_output_tokens(output_ids=output_ids)
def apply(self, logits: torch.Tensor) -> torch.Tensor: def apply(self, logits: torch.Tensor) -> torch.Tensor:
""" """
...@@ -88,48 +51,33 @@ class BatchedPenalizerOrchestrator: ...@@ -88,48 +51,33 @@ class BatchedPenalizerOrchestrator:
Returns: Returns:
torch.Tensor: The logits after applying the penalizers. torch.Tensor: The logits after applying the penalizers.
""" """
if not self.is_required:
return
for penalizer in self.penalizers.values(): for penalizer in self.penalizers.values():
logits = penalizer.apply(logits) penalizer.apply(logits)
return logits
def filter( def filter(self, keep_indices: torch.Tensor):
self,
indices_to_keep: List[int],
indices_tensor_to_keep: torch.Tensor = None,
):
""" """
Filter the penalizers based on the indices to keep in the batch. Filter the penalizers based on the indices to keep in the batch.
Args: Args:
indices_to_keep (List[int]): List of indices to keep in the batch. keep_indices (torch.Tensor): Tensor of indices to keep in the batch.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
""" """
if not self.is_required: if not self.is_required:
return return
empty_indices = len(indices_to_keep) == 0 if len(keep_indices) == 0:
self.is_required = False
for penalizer in self.penalizers.values():
penalizer.teardown()
return
is_required = False is_required = False
for penalizer in self.penalizers.values(): for penalizer in self.penalizers.values():
tmp_is_required = penalizer.is_required() tmp_is_required = penalizer.is_required()
is_required = is_required or tmp_is_required is_required |= tmp_is_required
if not tmp_is_required or empty_indices: if tmp_is_required:
penalizer.teardown() penalizer.filter(keep_indices=keep_indices)
else: else:
# create tensor index only when it's needed penalizer.teardown()
if indices_tensor_to_keep is None:
indices_tensor_to_keep = torch.tensor(
indices_to_keep, dtype=torch.int32, device=self.device
)
penalizer.filter(
indices_to_keep=indices_to_keep,
indices_tensor_to_keep=indices_tensor_to_keep,
)
self.is_required = is_required self.is_required = is_required
def merge(self, their: "BatchedPenalizerOrchestrator"): def merge(self, their: "BatchedPenalizerOrchestrator"):
...@@ -146,75 +94,9 @@ class BatchedPenalizerOrchestrator: ...@@ -146,75 +94,9 @@ class BatchedPenalizerOrchestrator:
if not self.is_required and not their.is_required: if not self.is_required and not their.is_required:
return return
self.is_required |= their.is_required self.is_required = True
for Penalizer, their_penalizer in their.penalizers.items(): for penalizer, their_penalizer in their.penalizers.items():
if Penalizer not in self.penalizers: self.penalizers[penalizer].merge(their_penalizer)
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
self.penalizers[Penalizer].merge(their_penalizer)
class _TokenIDs:
"""
A class that wraps token IDs to provide additional utility functions to penalizers.
Attributes:
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
cached_counts (torch.Tensor): The cached occurrence count tensor.
"""
def __init__(
self,
orchestrator: BatchedPenalizerOrchestrator,
token_ids: Union[torch.Tensor, List[torch.Tensor]],
):
self.orchestrator = orchestrator
self.token_ids = token_ids
self.cached_counts = None
def occurrence_count(self) -> torch.Tensor:
"""
Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
Returns:
torch.Tensor: The occurrence count tensor.
"""
if self.cached_counts is not None:
return self.cached_counts
token_ids = self.token_ids
if isinstance(token_ids, list):
# TODO: optimize this part
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
sequences=token_ids,
batch_first=True,
padding_value=self.orchestrator.vocab_size,
)
self.cached_counts = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
dtype=torch.int64,
device=self.orchestrator.device,
).scatter_add_(
dim=1,
index=padded_token_ids,
src=torch.ones_like(padded_token_ids),
)[
:, : self.orchestrator.vocab_size
]
else:
# TODO: optimize this part. We do not need to create this big tensor every time.
# We can directly apply the results on the logits.
self.cached_counts = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
device=self.orchestrator.device,
)
self.cached_counts[
torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
] = 1
return self.cached_counts
class _BatchedPenalizer(abc.ABC): class _BatchedPenalizer(abc.ABC):
...@@ -222,10 +104,6 @@ class _BatchedPenalizer(abc.ABC): ...@@ -222,10 +104,6 @@ class _BatchedPenalizer(abc.ABC):
An abstract class for a batched penalizer. An abstract class for a batched penalizer.
""" """
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
self.orchestrator = orchestrator
self._is_prepared = False
def is_prepared(self) -> bool: def is_prepared(self) -> bool:
return self._is_prepared return self._is_prepared
...@@ -233,51 +111,40 @@ class _BatchedPenalizer(abc.ABC): ...@@ -233,51 +111,40 @@ class _BatchedPenalizer(abc.ABC):
return self._is_required() return self._is_required()
def prepare(self): def prepare(self):
if not self.is_prepared(): if not self._is_prepared:
self._prepare() self._prepare()
self._is_prepared = True self._is_prepared = True
def prepare_if_required(self): def prepare_if_required(self):
if self.is_required(): if self._is_required():
self.prepare() self.prepare()
return True return True
else: else:
return False return False
def teardown(self): def teardown(self):
if self.is_prepared(): self._is_prepared = False
self._teardown()
self._is_prepared = False
def cumulate_input_tokens(self, input_ids: _TokenIDs):
if not self.is_prepared():
return
self._cumulate_input_tokens(input_ids=input_ids)
def cumulate_output_tokens(self, output_ids: _TokenIDs): def cumulate_output_tokens(self, output_ids: torch.Tensor):
if not self.is_prepared(): if not self._is_prepared:
return return
self._cumulate_output_tokens(output_ids=output_ids) self._cumulate_output_tokens(output_ids=output_ids)
def apply(self, logits: torch.Tensor) -> torch.Tensor: def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.is_prepared(): if not self._is_prepared:
return logits return
return self._apply(logits=logits) self._apply(logits=logits)
def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): def filter(self, keep_indices: torch.Tensor):
if not self.is_prepared(): if not self._is_prepared:
return return
self._filter( self._filter(keep_indices=keep_indices)
indices_to_keep=indices_to_keep,
indices_tensor_to_keep=indices_tensor_to_keep,
)
def merge(self, their: "_BatchedPenalizer"): def merge(self, their: "_BatchedPenalizer"):
if not self.is_prepared() and not their.is_prepared(): if not self._is_prepared and not their._is_prepared:
return return
self.prepare() self.prepare()
...@@ -300,23 +167,7 @@ class _BatchedPenalizer(abc.ABC): ...@@ -300,23 +167,7 @@ class _BatchedPenalizer(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
def _teardown(self): def _cumulate_output_tokens(self, output_ids: torch.Tensor):
"""
Tear down the penalizer.
Usually, this is where the penalizer frees its tensors.
"""
pass
@abc.abstractmethod
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
"""
Cumulate the input tokens.
Orchestrator will call this function to feed the input tokens to the penalizer.
"""
pass
@abc.abstractmethod
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
""" """
Cumulate the output tokens. Cumulate the output tokens.
Orchestrator will call this function to feed the output tokens to the penalizer. Orchestrator will call this function to feed the output tokens to the penalizer.
...@@ -332,7 +183,7 @@ class _BatchedPenalizer(abc.ABC): ...@@ -332,7 +183,7 @@ class _BatchedPenalizer(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): def _filter(self, keep_indices: torch.Tensor):
""" """
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch. Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
""" """
......
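After this refactor the orchestrator passes a single device tensor of indices down to every penalizer when requests finish, and each penalizer simply index-selects its state. A short illustration of that filtering semantics with made-up state:

import torch

state = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # per-request penalizer state
keep_indices = torch.tensor([0, 2])                          # requests 1 and 3 finished
state = state[keep_indices]
print(state.shape)  # torch.Size([2, 3])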
from typing import List
import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.utils import get_compiler_backend
@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
logits[:] = torch.where(
logits > 0,
logits / scaling_penalties,
logits * scaling_penalties,
)
class BatchedRepetitionPenalizer(_BatchedPenalizer):
"""
Repetition penalizer penalizes tokens based on their repetition in the input and output.
"""
repetition_penalties: torch.Tensor = None
cumulated_repetition_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.repetition_penalty != 1.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_repetition_penalties = (
torch.tensor(
data=[1.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.repetition_penalties = (
torch.tensor(
data=[
req.sampling_params.repetition_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_repetition_penalties)
)
def _teardown(self):
self.repetition_penalties = None
self.cumulated_repetition_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
mask = input_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedRepetitionPenalizer"):
self.repetition_penalties = torch.cat(
[self.repetition_penalties, their.repetition_penalties], dim=0
)
self.cumulated_repetition_penalties = torch.cat(
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
dim=0,
)
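The deleted repetition penalizer above applied a multiplicative penalty: positive logits are divided by the penalty and negative logits multiplied by it (the torch.where in apply_scaling_penalties). A small numeric restatement of that rule, using made-up values:

import torch

logits = torch.tensor([[2.0, -2.0]])
penalty = torch.tensor([[1.2, 1.2]])
scaled = torch.where(logits > 0, logits / penalty, logits * penalty)
print(scaled)  # tensor([[ 1.6667, -2.4000]]) -- both moved toward lower probability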
from typing import List
import torch import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs from sglang.srt.sampling.penaltylib.orchestrator import (
BatchedPenalizerOrchestrator,
_BatchedPenalizer,
)
class BatchedPresencePenalizer(_BatchedPenalizer): class BatchedPresencePenalizer(_BatchedPenalizer):
...@@ -10,8 +11,9 @@ class BatchedPresencePenalizer(_BatchedPenalizer): ...@@ -10,8 +11,9 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
Presence penalizer penalizes tokens based on their presence in the output. Presence penalizer penalizes tokens based on their presence in the output.
""" """
presence_penalties: torch.Tensor = None def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
cumulated_presence_penalties: torch.Tensor = None self.orchestrator = orchestrator
self._is_prepared = False
def _is_required(self) -> bool: def _is_required(self) -> bool:
return any( return any(
...@@ -20,14 +22,10 @@ class BatchedPresencePenalizer(_BatchedPenalizer): ...@@ -20,14 +22,10 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
) )
def _prepare(self): def _prepare(self):
self.cumulated_presence_penalties = ( self.cumulated_presence_penalties = torch.zeros(
torch.tensor( (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
data=[0.0 for _ in self.orchestrator.reqs()], dtype=torch.float32,
dtype=torch.float32, device=self.orchestrator.device,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
) )
self.presence_penalties = ( self.presence_penalties = (
...@@ -39,32 +37,26 @@ class BatchedPresencePenalizer(_BatchedPenalizer): ...@@ -39,32 +37,26 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
dtype=torch.float32, dtype=torch.float32,
device=self.orchestrator.device, device=self.orchestrator.device,
) )
.unsqueeze_(1) ).unsqueeze_(1)
.expand_as(self.cumulated_presence_penalties)
)
def _teardown(self):
self.presence_penalties = None
self.cumulated_presence_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs): def _cumulate_output_tokens(self, output_ids: torch.Tensor):
pass self.cumulated_presence_penalties.scatter_(
dim=1,
def _cumulate_output_tokens(self, output_ids: _TokenIDs): index=output_ids.unsqueeze(1),
mask = output_ids.occurrence_count() > 0 src=self.presence_penalties,
self.cumulated_presence_penalties[mask] = self.presence_penalties[mask] )
def _apply(self, logits: torch.Tensor) -> torch.Tensor: def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits -= self.cumulated_presence_penalties logits.sub_(self.cumulated_presence_penalties)
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): def _filter(self, keep_indices: torch.Tensor):
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep] self.presence_penalties = self.presence_penalties[keep_indices]
self.cumulated_presence_penalties = self.cumulated_presence_penalties[ self.cumulated_presence_penalties = self.cumulated_presence_penalties[
indices_tensor_to_keep keep_indices
] ]
def _merge(self, their: "BatchedPresencePenalizer"): def _merge(self, their: "BatchedPresencePenalizer"):
print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
self.presence_penalties = torch.cat( self.presence_penalties = torch.cat(
[self.presence_penalties, their.presence_penalties], dim=0 [self.presence_penalties, their.presence_penalties], dim=0
) )
......
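The presence and frequency penalizers now differ only in the scatter op: scatter_ writes the penalty once no matter how often a token repeats, while scatter_add_ accumulates it per occurrence. A toy contrast with made-up values:

import torch

penalty = torch.tensor([[0.5]])
presence = torch.zeros(1, 4)
frequency = torch.zeros(1, 4)
token = torch.tensor([[2]])

for _ in range(2):  # the same token is sampled twice
    presence.scatter_(dim=1, index=token, src=penalty)       # stays at 0.5
    frequency.scatter_add_(dim=1, index=token, src=penalty)  # grows to 1.0

print(presence, frequency)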
...@@ -9,9 +9,6 @@ import torch ...@@ -9,9 +9,6 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
apply_scaling_penalties,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -22,49 +19,45 @@ if TYPE_CHECKING: ...@@ -22,49 +19,45 @@ if TYPE_CHECKING:
@dataclasses.dataclass @dataclasses.dataclass
class SamplingBatchInfo: class SamplingBatchInfo:
# Batched sampling params # Basic batched sampling params
temperatures: torch.Tensor temperatures: torch.Tensor
top_ps: torch.Tensor top_ps: torch.Tensor
top_ks: torch.Tensor top_ks: torch.Tensor
min_ps: torch.Tensor min_ps: torch.Tensor
# All requests use greedy sampling # Whether all requests use greedy sampling
is_all_greedy: bool is_all_greedy: bool
# Dispatch in CUDA graph # Whether any request needs min_p sampling
need_min_p_sampling: bool need_min_p_sampling: bool
# Whether any request has custom logit processor # Masking tensors for grammar-guided structured outputs
has_custom_logit_processor: bool
# Bias Tensors
vocab_size: int vocab_size: int
grammars: Optional[List] = None grammars: Optional[List] = None
sampling_info_done: Optional[threading.Event] = None
logit_bias: torch.Tensor = None
vocab_mask: Optional[torch.Tensor] = None vocab_mask: Optional[torch.Tensor] = None
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
# An event used for overlap schedule
sampling_info_done: Optional[threading.Event] = None
# Penalizer # Penalizer
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
linear_penalties: Optional[torch.Tensor] = None linear_penalty: torch.Tensor = None
scaling_penalties: Optional[torch.Tensor] = None
# Device # Whether any request has custom logit processor
device: str = "cuda" has_custom_logit_processor: bool = False
# Custom parameters
# Custom Parameters
custom_params: Optional[List[Optional[Dict[str, Any]]]] = None custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
# Custom logit processor
# Custom Logit Processor
custom_logit_processor: Optional[ custom_logit_processor: Optional[
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]] Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
] = None ] = None
# Device
device: str = "cuda"
@classmethod @classmethod
def from_schedule_batch( def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
):
reqs = batch.reqs reqs = batch.reqs
device = batch.device device = batch.device
temperatures = ( temperatures = (
...@@ -118,106 +111,60 @@ class SamplingBatchInfo: ...@@ -118,106 +111,60 @@ class SamplingBatchInfo:
merged_custom_logit_processor = None merged_custom_logit_processor = None
custom_params = None custom_params = None
ret = cls(
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
has_custom_logit_processor=has_custom_logit_processor,
vocab_size=vocab_size,
device=device,
custom_params=custom_params,
custom_logit_processor=merged_custom_logit_processor,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
if enable_overlap_schedule:
# TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
# so it is kind of tricky to make it work with overlap scheduler.
# It requires correcly updating the penalty logits before the sampling and syncing the events.
# We will support them later.
penalizers = {
penaltylib.BatchedMinNewTokensPenalizer,
}
if (
any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
):
logger.warning(
"frequency_penalty, presence_penalty, and repetition_penalty are not supported "
"when using the default overlap scheduler. They will be ignored. "
"Please add `--disable-overlap` when launching the server if you need these features. "
"The speed will be slower in that case."
)
else:
penalizers = {
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
penaltylib.BatchedPresencePenalizer,
penaltylib.BatchedRepetitionPenalizer,
}
# Each penalizer will do nothing if it evaluates itself as not required by looking at # Each penalizer will do nothing if it evaluates itself as not required by looking at
# the sampling_params of the requests (see {_is_required()} of each penalizer). So this # the sampling_params of the requests (see {_is_required()} of each penalizer). So this
# should not add hefty computation overhead other than simple checks. # should not add hefty computation overhead other than simple checks.
# #
# While we choose not to even create the class instances if they are not required, this # While we can choose not to even create the class instances if they are not required, this
# could add additional complexity to the {ScheduleBatch} class, especially since we need to # could add additional complexity to the {ScheduleBatch} class, especially since we need to
# handle {filter_batch()} and {merge_batch()} cases as well. # handle {filter_batch()} and {merge_batch()} cases as well.
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator( penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size, vocab_size=vocab_size,
batch=batch, batch=batch,
device=batch.device, penalizers={
Penalizers=penalizers, penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
penaltylib.BatchedPresencePenalizer,
},
) )
# Handle logit bias but only allocate when needed ret = cls(
ret.logit_bias = None temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
penalizer_orchestrator=penalizer_orchestrator,
has_custom_logit_processor=has_custom_logit_processor,
custom_params=custom_params,
custom_logit_processor=merged_custom_logit_processor,
device=device,
)
return ret return ret
def __len__(self): def __len__(self):
return len(self.temperatures) return len(self.temperatures)
def update_penalties(self):
self.scaling_penalties = None
self.linear_penalties = None
for penalizer in self.penalizer_orchestrator.penalizers.values():
if not penalizer.is_prepared():
continue
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
self.scaling_penalties = penalizer.cumulated_repetition_penalties
else:
if self.linear_penalties is None:
bs = self.penalizer_orchestrator.batch.batch_size()
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self): def update_regex_vocab_mask(self):
if not self.grammars: if not self.grammars:
self.vocab_mask = None self.vocab_mask = None
self.apply_mask = None self.apply_mask_func = None
return return
# find a grammar from the list # Find a grammar from the list
first_grammar = next(grammar for grammar in self.grammars if grammar) first_grammar = next(grammar for grammar in self.grammars if grammar)
# maybe we can reuse the existing mask? # TODO(lianmin): Maybe we can reuse the existing mask?
self.vocab_mask = first_grammar.allocate_vocab_mask( self.vocab_mask = first_grammar.allocate_vocab_mask(
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
batch_size=len(self.temperatures), batch_size=len(self.temperatures),
device=self.device, device=self.device,
) )
self.apply_mask = first_grammar.apply_vocab_mask # force to use static method self.apply_mask_func = (
first_grammar.apply_vocab_mask
) # force to use static method
# Apply the mask # Apply the mask
for i, grammar in enumerate(self.grammars): for i, grammar in enumerate(self.grammars):
...@@ -227,35 +174,56 @@ class SamplingBatchInfo: ...@@ -227,35 +174,56 @@ class SamplingBatchInfo:
# Move the mask to the device if needed # Move the mask to the device if needed
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device) self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): def update_penalties(self):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices) if self.penalizer_orchestrator.is_required:
self.linear_penalty = torch.zeros(
(len(self.temperatures), self.vocab_size),
dtype=torch.float32,
device=self.temperatures.device,
)
self.penalizer_orchestrator.apply(self.linear_penalty)
else:
self.linear_penalty = None
def apply_logits_bias(self, logits: torch.Tensor):
if self.linear_penalty is not None:
# Used in the overlap mode
logits.add_(self.linear_penalty)
if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
# Used in the non-overlap mode
self.penalizer_orchestrator.apply(logits)
if self.vocab_mask is not None:
self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
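The overlap-mode trick in update_penalties/apply_logits_bias is to run the additive penalizers against a zero tensor ahead of time and later fold the captured result into the real logits with a single add_. A self-contained sketch of that idea; toy_penalizer is a made-up stand-in for the orchestrator's apply():

import torch

def toy_penalizer(logits: torch.Tensor) -> None:
    logits[:, 3] -= 1.5  # pretend token 3 is penalized for every request

linear_penalty = torch.zeros(2, 5)
toy_penalizer(linear_penalty)       # capture the additive effect early (overlap mode)

real_logits = torch.randn(2, 5)
expected = real_logits.clone()
toy_penalizer(expected)             # what the non-overlap path would compute

real_logits.add_(linear_penalty)    # same result, applied as one in-place add
assert torch.allclose(real_logits, expected)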
def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
self.penalizer_orchestrator.filter(keep_indices_device)
if self.has_custom_logit_processor: if self.has_custom_logit_processor:
self._filter_batch_custom_logit_processor(unfinished_indices, new_indices) self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device)
for item in [ for item in [
"temperatures", "temperatures",
"top_ps", "top_ps",
"top_ks", "top_ks",
"min_ps", "min_ps",
"logit_bias",
]: ]:
value = getattr(self, item, None) value = getattr(self, item, None)
if value is not None: # logit_bias can be None setattr(self, item, value[keep_indices_device])
setattr(self, item, value[new_indices])
def _filter_batch_custom_logit_processor( def _filter_batch_custom_logit_processor(
self, unfinished_indices: List[int], new_indices: torch.Tensor self, keep_indices: List[int], keep_indices_device: torch.Tensor
): ):
"""Filter the custom logit processor and custom params""" """Filter the custom logit processor and custom params"""
self.custom_logit_processor = { self.custom_logit_processor = {
k: (p, mask[new_indices]) k: (p, mask[keep_indices_device])
for k, (p, mask) in self.custom_logit_processor.items() for k, (p, mask) in self.custom_logit_processor.items()
if any( if torch.any(
mask[new_indices] mask[keep_indices_device]
) # ignore the custom logit processor whose mask is all False ) # ignore the custom logit processor whose mask is all False
} }
self.custom_params = [self.custom_params[i] for i in unfinished_indices] self.custom_params = [self.custom_params[i] for i in keep_indices]
# If the custom logit processor is an empty dict, set the flag to False, # If the custom logit processor is an empty dict, set the flag to False,
# and set the custom logit processor and custom params to None. # and set the custom logit processor and custom params to None.
...@@ -264,31 +232,6 @@ class SamplingBatchInfo: ...@@ -264,31 +232,6 @@ class SamplingBatchInfo:
self.custom_params = None self.custom_params = None
self.has_custom_logit_processor = False self.has_custom_logit_processor = False
@staticmethod
def merge_bias_tensor(
lhs: torch.Tensor,
rhs: torch.Tensor,
bs1: int,
bs2: int,
device: str,
default: int = 0,
):
# bias tensor can be None
if lhs is not None or rhs is not None:
shape, dtype = None, None
if lhs is not None:
shape, dtype = lhs.shape[1:], lhs.dtype
else:
shape, dtype = rhs.shape[1:], rhs.dtype
with torch.dtype(dtype):
if lhs is None:
lhs = torch.empty((bs1, *shape), device=device).fill_(default)
if rhs is None:
rhs = torch.empty((bs2, *shape), device=device).fill_(default)
return torch.cat([lhs, rhs])
return None
@staticmethod @staticmethod
def merge_custom_logit_processor( def merge_custom_logit_processor(
lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]], lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
...@@ -332,11 +275,6 @@ class SamplingBatchInfo: ...@@ -332,11 +275,6 @@ class SamplingBatchInfo:
def merge_batch(self, other: "SamplingBatchInfo"): def merge_batch(self, other: "SamplingBatchInfo"):
self.penalizer_orchestrator.merge(other.penalizer_orchestrator) self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
# Merge the logit bias tensor
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
# Merge the custom logit processors and custom params lists # Merge the custom logit processors and custom params lists
if self.has_custom_logit_processor or other.has_custom_logit_processor: if self.has_custom_logit_processor or other.has_custom_logit_processor:
# Merge the custom logit processors # Merge the custom logit processors
...@@ -370,22 +308,5 @@ class SamplingBatchInfo: ...@@ -370,22 +308,5 @@ class SamplingBatchInfo:
other_val = getattr(other, item, None) other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val])) setattr(self, item, torch.concat([self_val, other_val]))
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy self.is_all_greedy &= other.is_all_greedy
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling self.need_min_p_sampling |= other.need_min_p_sampling
def apply_logits_bias(self, logits: torch.Tensor):
# Apply logit_bias
if self.logit_bias is not None:
logits.add_(self.logit_bias)
# min-token, presence, frequency
if self.linear_penalties is not None:
logits.add_(self.linear_penalties)
# repetition
if self.scaling_penalties is not None:
apply_scaling_penalties(logits, self.scaling_penalties)
# Apply regex vocab_mask
if self.vocab_mask is not None:
self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
...@@ -15,15 +15,21 @@ ...@@ -15,15 +15,21 @@
import argparse import argparse
import dataclasses import dataclasses
import json
import logging import logging
import os
import random import random
import subprocess
import tempfile import tempfile
import uuid
from pathlib import Path
from typing import List, Optional from typing import List, Optional
import torch import torch
from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import ( from sglang.srt.utils import (
create_checksum,
get_amdgpu_memory_capacity, get_amdgpu_memory_capacity,
get_hpu_memory_capacity, get_hpu_memory_capacity,
get_nvgpu_memory_capacity, get_nvgpu_memory_capacity,
...@@ -43,12 +49,13 @@ class ServerArgs: ...@@ -43,12 +49,13 @@ class ServerArgs:
model_path: str model_path: str
tokenizer_path: Optional[str] = None tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto" tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False
load_format: str = "auto" load_format: str = "auto"
trust_remote_code: bool = True trust_remote_code: bool = False
dtype: str = "auto" dtype: str = "auto"
kv_cache_dtype: str = "auto" kv_cache_dtype: str = "auto"
quantization_param_path: nullable_str = None
quantization: Optional[str] = None quantization: Optional[str] = None
quantization_param_path: nullable_str = None
context_length: Optional[int] = None context_length: Optional[int] = None
device: str = "cuda" device: str = "cuda"
served_model_name: Optional[str] = None served_model_name: Optional[str] = None
...@@ -67,7 +74,7 @@ class ServerArgs: ...@@ -67,7 +74,7 @@ class ServerArgs:
max_total_tokens: Optional[int] = None max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384 max_prefill_tokens: int = 16384
schedule_policy: str = "lpm" schedule_policy: str = "fcfs"
schedule_conservativeness: float = 1.0 schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0 cpu_offload_gb: int = 0
prefill_only_one_req: bool = False prefill_only_one_req: bool = False
...@@ -88,6 +95,7 @@ class ServerArgs: ...@@ -88,6 +95,7 @@ class ServerArgs:
log_level: str = "info" log_level: str = "info"
log_level_http: Optional[str] = None log_level_http: Optional[str] = None
log_requests: bool = False log_requests: bool = False
log_requests_level: int = 0
show_time_cost: bool = False show_time_cost: bool = False
enable_metrics: bool = False enable_metrics: bool = False
decode_log_interval: int = 40 decode_log_interval: int = 40
...@@ -123,11 +131,13 @@ class ServerArgs: ...@@ -123,11 +131,13 @@ class ServerArgs:
grammar_backend: Optional[str] = "outlines" grammar_backend: Optional[str] = "outlines"
# Speculative decoding # Speculative decoding
speculative_draft_model_path: Optional[str] = None
speculative_algorithm: Optional[str] = None speculative_algorithm: Optional[str] = None
speculative_draft_model_path: Optional[str] = None
speculative_num_steps: int = 5 speculative_num_steps: int = 5
speculative_eagle_topk: int = 8 speculative_eagle_topk: int = 4
speculative_num_draft_tokens: int = 64 speculative_num_draft_tokens: int = 8
speculative_accept_threshold_single: float = 1.0
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None speculative_token_map: Optional[str] = None
# Double Sparsity # Double Sparsity
...@@ -169,6 +179,12 @@ class ServerArgs: ...@@ -169,6 +179,12 @@ class ServerArgs:
enable_hierarchical_cache: bool = False enable_hierarchical_cache: bool = False
enable_flashinfer_mla: bool = False enable_flashinfer_mla: bool = False
flashinfer_mla_disable_ragged: bool = False flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False
def __post_init__(self): def __post_init__(self):
# Set missing default values # Set missing default values
...@@ -266,10 +282,10 @@ class ServerArgs: ...@@ -266,10 +282,10 @@ class ServerArgs:
self.speculative_algorithm == "EAGLE" self.speculative_algorithm == "EAGLE"
or self.speculative_algorithm == "NEXTN" or self.speculative_algorithm == "NEXTN"
): ):
self.disable_overlap_schedule = True
self.prefill_only_one_req = True self.prefill_only_one_req = True
self.disable_cuda_graph_padding = True self.disable_cuda_graph_padding = True
self.disable_radix_cache = True self.disable_radix_cache = True
self.disable_overlap_schedule = True
self.chunked_prefill_size = -1 self.chunked_prefill_size = -1
logger.info( logger.info(
f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding." f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
...@@ -377,15 +393,6 @@ class ServerArgs: ...@@ -377,15 +393,6 @@ class ServerArgs:
choices=["auto", "fp8_e5m2", "fp8_e4m3"], choices=["auto", "fp8_e5m2", "fp8_e4m3"],
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.', help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
) )
parser.add_argument(
"--quantization-param-path",
type=nullable_str,
default=None,
help="Path to the JSON file containing the KV cache "
"scaling factors. This should generally be supplied, when "
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
"default to 1.0, which may cause accuracy issues. ",
)
parser.add_argument( parser.add_argument(
"--quantization", "--quantization",
type=str, type=str,
...@@ -404,6 +411,15 @@ class ServerArgs: ...@@ -404,6 +411,15 @@ class ServerArgs:
], ],
help="The quantization method.", help="The quantization method.",
) )
parser.add_argument(
"--quantization-param-path",
type=nullable_str,
default=None,
help="Path to the JSON file containing the KV cache "
"scaling factors. This should generally be supplied, when "
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
"default to 1.0, which may cause accuracy issues. ",
)
parser.add_argument( parser.add_argument(
"--context-length", "--context-length",
type=int, type=int,
...@@ -578,7 +594,14 @@ class ServerArgs: ...@@ -578,7 +594,14 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--log-requests", "--log-requests",
action="store_true", action="store_true",
help="Log the inputs and outputs of all requests.", help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
)
parser.add_argument(
"--log-requests-level",
type=int,
default=0,
help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
choices=[0, 1, 2],
) )
parser.add_argument( parser.add_argument(
"--show-time-cost", "--show-time-cost",
...@@ -742,16 +765,28 @@ class ServerArgs: ...@@ -742,16 +765,28 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--speculative-eagle-topk", "--speculative-eagle-topk",
type=int, type=int,
help="The number of token sampled from draft model in eagle2 each step.", help="The number of tokens sampled from the draft model in eagle2 each step.",
choices=[1, 2, 4, 8], choices=[1, 2, 4, 8],
default=ServerArgs.speculative_eagle_topk, default=ServerArgs.speculative_eagle_topk,
) )
parser.add_argument( parser.add_argument(
"--speculative-num-draft-tokens", "--speculative-num-draft-tokens",
type=int, type=int,
help="The number of token sampled from draft model in Speculative Decoding.", help="The number of tokens sampled from the draft model in Speculative Decoding.",
default=ServerArgs.speculative_num_draft_tokens, default=ServerArgs.speculative_num_draft_tokens,
) )
parser.add_argument(
"--speculative-accept-threshold-single",
type=float,
help="Accept a draft token if its probability in the target model is greater than this threshold.",
default=ServerArgs.speculative_accept_threshold_single,
)
parser.add_argument(
"--speculative-accept-threshold-acc",
type=float,
help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
default=ServerArgs.speculative_accept_threshold_acc,
)
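A quick worked example of the relaxation described in the --speculative-accept-threshold-acc help text, with made-up numbers: a draft token's acceptance probability becomes min(1, p / threshold_acc).

p = 0.3                 # probability of the draft token under the target model
threshold_acc = 0.5
accept_prob = min(1.0, p / threshold_acc)
print(accept_prob)      # 0.6 -- twice as likely to be accepted as with threshold 1.0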
parser.add_argument( parser.add_argument(
"--speculative-token-map", "--speculative-token-map",
type=str, type=str,
...@@ -949,6 +984,35 @@ class ServerArgs: ...@@ -949,6 +984,35 @@ class ServerArgs:
help="Enable hierarchical cache", help="Enable hierarchical cache",
) )
# Server warmups
parser.add_argument(
"--warmups",
type=str,
required=False,
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
"will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
)
# Debug tensor dumps
parser.add_argument(
"--debug-tensor-dump-output-folder",
type=str,
default=ServerArgs.debug_tensor_dump_output_folder,
help="The output folder for dumping tensors.",
)
parser.add_argument(
"--debug-tensor-dump-input-file",
type=str,
default=ServerArgs.debug_tensor_dump_input_file,
help="The input filename for dumping tensors",
)
parser.add_argument(
"--debug-tensor-dump-inject",
type=str,
default=ServerArgs.debug_tensor_dump_inject,
help="Inject the outputs from jax as the input of every layer.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size args.tp_size = args.tensor_parallel_size
......
...@@ -32,13 +32,15 @@ import socket ...@@ -32,13 +32,15 @@ import socket
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import threading
import time import time
import warnings import warnings
from functools import lru_cache from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from io import BytesIO from io import BytesIO
from multiprocessing import Pool
from multiprocessing.reduction import ForkingPickler from multiprocessing.reduction import ForkingPickler
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
import numpy as np import numpy as np
import psutil import psutil
...@@ -480,6 +482,10 @@ def assert_pkg_version(pkg: str, min_version: str, message: str): ...@@ -480,6 +482,10 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None): def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
"""Kill the process and all its child processes.""" """Kill the process and all its child processes."""
# Remove sigchld handler to avoid spammy logs.
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
if parent_pid is None: if parent_pid is None:
parent_pid = os.getpid() parent_pid = os.getpid()
include_parent = False include_parent = False
...@@ -499,17 +505,14 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N ...@@ -499,17 +505,14 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
pass pass
if include_parent: if include_parent:
if parent_pid == os.getpid(): try:
sys.exit(0) itself.kill()
else:
try:
itself.kill()
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
# so we send an additional signal to kill them. # so we send an additional signal to kill them.
itself.send_signal(signal.SIGQUIT) itself.send_signal(signal.SIGQUIT)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
pass pass
def monkey_patch_p2p_access_check(): def monkey_patch_p2p_access_check():
...@@ -1215,7 +1218,11 @@ def cuda_device_count_stateless() -> int: ...@@ -1215,7 +1218,11 @@ def cuda_device_count_stateless() -> int:
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None)) return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
def dataclass_to_string_truncated(data, max_length=2048): def dataclass_to_string_truncated(
data, max_length=2048, skip_names: Optional[Set[str]] = None
):
if skip_names is None:
skip_names = set()
if isinstance(data, str): if isinstance(data, str):
if len(data) > max_length: if len(data) > max_length:
half_length = max_length // 2 half_length = max_length // 2
...@@ -1234,6 +1241,7 @@ def dataclass_to_string_truncated(data, max_length=2048): ...@@ -1234,6 +1241,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
+ ", ".join( + ", ".join(
f"'{k}': {dataclass_to_string_truncated(v, max_length)}" f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
for k, v in data.items() for k, v in data.items()
if k not in skip_names
) )
+ "}" + "}"
) )
...@@ -1244,6 +1252,7 @@ def dataclass_to_string_truncated(data, max_length=2048): ...@@ -1244,6 +1252,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
+ ", ".join( + ", ".join(
f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
for f in fields for f in fields
if f.name not in skip_names
) )
+ ")" + ")"
) )
...@@ -1322,9 +1331,9 @@ def pyspy_dump_schedulers(): ...@@ -1322,9 +1331,9 @@ def pyspy_dump_schedulers():
result = subprocess.run( result = subprocess.run(
cmd, shell=True, capture_output=True, text=True, check=True cmd, shell=True, capture_output=True, text=True, check=True
) )
logger.info(f"Profile for PID {pid}:\n{result.stdout}") logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}") logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}")
def kill_itself_when_parent_died(): def kill_itself_when_parent_died():
...@@ -1448,6 +1457,10 @@ def launch_dummy_health_check_server(host, port): ...@@ -1448,6 +1457,10 @@ def launch_dummy_health_check_server(host, port):
) )
def create_checksum(directory: str):
raise NotImplementedError()
def set_cuda_arch(): def set_cuda_arch():
if is_flashinfer_available(): if is_flashinfer_available():
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
......
import logging
from typing import List
import numpy as np
import tqdm
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
logger = logging.getLogger(__file__)
_warmup_registry = {}
def warmup(name: str) -> callable:
def decorator(fn: callable):
_warmup_registry[name] = fn
return fn
return decorator
async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager):
for warmup_name in warmup_names:
if warmup_name not in _warmup_registry:
logger.warning(f"Could not find custom warmup {warmup_name}")
continue
logger.info(f"Running warmup {warmup_name}")
await _warmup_registry[warmup_name](tokenizer_manager)
@warmup("voice_chat")
async def voice_chat(tokenizer_manager: TokenizerManager):
# This warms up the fused_moe Triton kernels and caches them.
# If we don't do this, real-time inference for voice chat breaks.
for i in tqdm.trange(1, 512):
size = i * 4
generate_req_input = GenerateReqInput(
input_ids=(np.random.randint(2**16, size=[size])).tolist(),
sampling_params={
"max_new_tokens": 30,
"temperature": 0.8,
"stop_token_ids": [1],
"min_p": 0.0,
},
)
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
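As a hedged usage sketch, a hypothetical custom warmup could be registered through the same decorator and selected at launch with --warmups=short_prompt; the name and prompt are invented, and the snippet reuses this module's imports.

@warmup("short_prompt")
async def short_prompt(tokenizer_manager: TokenizerManager):
    # Send one tiny greedy request so kernels and caches are primed.
    req = GenerateReqInput(
        text="hello",
        sampling_params={"max_new_tokens": 8, "temperature": 0.0},
    )
    await tokenizer_manager.generate_request(req, None).__anext__()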