Commit 84967019 authored by Lianmin Zheng

[Misc] Fix metrics, weight update lock, request logging (#2543)

parent 7d672d27
import asyncio
class RWLock:
"""
A Read-Write Lock for asyncio:
- Multiple readers can hold the lock in parallel if no writer holds it.
- A writer has exclusive access.
"""
def __init__(self):
self._readers = 0 # How many readers currently hold the lock
self._writer_active = False
self._lock = asyncio.Lock() # Internal mutex to protect state
# Conditions associated with _lock:
self._readers_ok = asyncio.Condition(self._lock) # Notify blocked readers
self._writers_ok = asyncio.Condition(self._lock) # Notify blocked writers
# Expose two async context-manager helpers:
self.reader_lock = self._ReaderLock(self)
self.writer_lock = self._WriterLock(self)
async def _acquire_reader(self):
"""
Wait until there is no active writer.
Then increment the count of active readers.
"""
async with self._lock:
# If a writer is active, wait until it's done.
while self._writer_active:
await self._readers_ok.wait()
self._readers += 1
async def _release_reader(self):
"""
Decrement the count of active readers.
If this was the last active reader, wake up a possible waiting writer.
"""
async with self._lock:
self._readers -= 1
# If no more readers, a writer could proceed.
if self._readers == 0:
self._writers_ok.notify()
async def _acquire_writer(self):
"""
Wait until there is no active writer and no active readers.
Then mark a writer as active.
"""
async with self._lock:
while self._writer_active or self._readers > 0:
await self._writers_ok.wait()
self._writer_active = True
async def _release_writer(self):
"""
Mark the writer as done and notify readers and writers.
"""
async with self._lock:
self._writer_active = False
# Allow any waiting readers to proceed:
self._readers_ok.notify_all()
# Allow next waiting writer to proceed:
self._writers_ok.notify()
class _ReaderLock:
"""
A simple async context manager that acquires a reader lock
on entering and releases it on exit.
"""
def __init__(self, parent: "RWLock"):
self._parent = parent
async def __aenter__(self):
await self._parent._acquire_reader()
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._parent._release_reader()
class _WriterLock:
"""
A simple async context manager that acquires a writer lock
on entering and releases it on exit.
"""
def __init__(self, parent: "RWLock"):
self._parent = parent
async def __aenter__(self):
await self._parent._acquire_writer()
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._parent._release_writer()
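For illustration only, not part of the commit: a minimal, self-contained sketch of how the two helpers are meant to be used, mirroring how the TokenizerManager below wraps request handling in reader_lock and weight updates in writer_lock. The names handle_request, update_weights, and state are hypothetical.

import asyncio

from sglang.srt.aio_rwlock import RWLock


async def main():
    lock = RWLock()
    state = {"weights_version": 0}

    async def handle_request(i: int):
        # Readers run concurrently as long as no writer holds the lock.
        async with lock.reader_lock:
            await asyncio.sleep(0.01)
            print(f"request {i} served with weights v{state['weights_version']}")

    async def update_weights():
        # The writer waits for in-flight readers to finish, then holds
        # the lock exclusively until the update is done.
        async with lock.writer_lock:
            state["weights_version"] += 1

    await asyncio.gather(*(handle_request(i) for i in range(4)), update_weights())


asyncio.run(main())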
@@ -124,8 +124,12 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size
        # Verify quantization
self._verify_quantization()
        # Multimodal attrs
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
@@ -18,11 +18,7 @@ import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
get_bool_env_var,
is_flashinfer_available,
should_use_tensor_core,
)
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
@@ -731,3 +727,51 @@ def create_flashinfer_kv_indices_triton(
mask=mask,
)
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
def should_use_tensor_core(
kv_cache_dtype: torch.dtype,
num_attention_heads: int,
num_kv_heads: int,
) -> bool:
"""
Determine whether to use tensor cores for attention computation.
Args:
kv_cache_dtype: Data type of the KV cache
num_attention_heads: Number of attention heads
num_kv_heads: Number of key/value heads
Returns:
bool: Whether to use tensor cores
"""
# Try to use environment variable first
env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
if env_override is not None:
return env_override.lower() == "true"
# Try to use _grouped_size_compiled_for_decode_kernels if available
# This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
try:
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
if not _grouped_size_compiled_for_decode_kernels(
num_attention_heads,
num_kv_heads,
):
return True
else:
return False
except (ImportError, AttributeError):
pass
# Calculate GQA group size
gqa_group_size = num_attention_heads // num_kv_heads
# Determine based on dtype and GQA group size
if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
return True
elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
return gqa_group_size > 4
else:
return False
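For illustration only, not part of the diff: a rough sketch of how the fallback heuristic resolves when the SGLANG_FLASHINFER_USE_TENSOR_CORE override is unset and the flashinfer probe raises ImportError, so the GQA-group-size branch decides. The import path is an assumption based on the surrounding hunk.

import torch

from sglang.srt.layers.attention.flashinfer_backend import (  # assumed module path
    should_use_tensor_core,
)

# 32 query heads over 8 KV heads gives a GQA group size of 4, so for an fp16
# KV cache this returns False (tensor cores only when the group size exceeds 4).
print(should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=8))

# An FP8 KV cache takes the tensor-core path regardless of the group size.
print(should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8))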
@@ -479,8 +479,22 @@ class Req:
return True
def reset_for_retract(self):
self.prefix_indices = []
self.last_node = None
self.extend_input_len = 0
self.is_retracted = True
# For incremental logprobs
# TODO: Fix the `logprob_start_len`
self.last_update_decode_tokens = 0
self.logprob_start_len = 10**9
def __repr__(self):
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
return (
f"rid(n={self.rid}, "
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
)
bid = 0
@@ -894,15 +908,7 @@ class ScheduleBatch:
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
req.prefix_indices = []
req.last_node = None
req.extend_input_len = 0
req.is_retracted = True
# For incremental logprobs
req.last_update_decode_tokens = 0
req.logprob_start_len = 10**9
req.reset_for_retract()
self.filter_batch(keep_indices=sorted_indices)
@@ -22,7 +22,7 @@ import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import List, Optional
from typing import Callable, Dict, List, Optional, Tuple
import psutil
import setproctitle
@@ -260,7 +260,7 @@ class Scheduler:
self.current_stream = torch.get_device_module(self.device).current_stream()
# Session info
self.sessions = {}
self.sessions: Dict[str, Session] = {}
# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -22,7 +22,7 @@ import signal
import sys
import time
import uuid
from typing import Any, Dict, List, Optional, Union
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
import fastapi
import uvloop
@@ -30,6 +30,7 @@ import zmq
import zmq.asyncio
from fastapi import BackgroundTasks
from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.srt.utils import (
dataclass_to_string_truncated,
get_zmq_socket,
kill_process_tree,
)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -82,6 +87,9 @@ class ReqState:
created_time: float
first_token_time: Optional[float] = None
# For streaming output
last_output_offset: int = 0
class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""
@@ -120,6 +128,7 @@ class TokenizerManager:
self.is_generation = self.model_config.is_generation
self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id
# Create image processor placeholder
self.image_processor = get_dummy_image_processor()
@@ -152,9 +161,12 @@ class TokenizerManager:
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}
# For update model weights
self.model_update_lock = asyncio.Lock()
self.model_update_result = None
# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None
)
self.asyncio_tasks = set()
# For session info
self.session_futures = {} # session_id -> asyncio event
@@ -181,9 +193,6 @@ class TokenizerManager:
if self.to_create_loop:
self.create_handle_loop()
while self.model_update_lock.locked():
await asyncio.sleep(0.001)
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
@@ -191,17 +200,24 @@ class TokenizerManager:
)
obj.normalize_batch_and_arguments()
is_single = obj.is_single
if is_single:
tokenized_obj = await self._tokenize_one_request(obj)
self.send_to_scheduler.send_pyobj(tokenized_obj)
async for response in self._wait_one_response(obj, request, created_time):
yield response
else:
async for response in self._handle_batch_request(
obj, request, created_time
):
yield response
if self.server_args.log_requests:
logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
async with self.model_update_lock.reader_lock:
is_single = obj.is_single
if is_single:
tokenized_obj = await self._tokenize_one_request(obj)
self.send_to_scheduler.send_pyobj(tokenized_obj)
async for response in self._wait_one_response(
obj, request, created_time
):
yield response
else:
async for response in self._handle_batch_request(
obj, request, created_time
):
yield response
async def _tokenize_one_request(
self,
@@ -215,7 +231,7 @@ class TokenizerManager:
if not self.server_args.disable_radix_cache:
raise ValueError(
"input_embeds is provided while disable_radix_cache is False. "
"Please add `--disable-radix-cach` when you launch the server "
"Please add `--disable-radix-cache` when you launch the server "
"if you want to use input_embeds as inputs."
)
input_embeds = obj.input_embeds
@@ -301,8 +317,8 @@ class TokenizerManager:
state.out_list = []
if state.finished:
if self.server_args.log_requests:
# Log requests
logger.info(f"in={obj}, out={out}")
msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
logger.info(msg)
del self.rid_to_state[obj.rid]
yield out
break
@@ -423,55 +439,52 @@ class TokenizerManager:
self,
obj: UpdateWeightFromDiskReqInput,
request: Optional[fastapi.Request] = None,
):
) -> Tuple[bool, str]:
if self.to_create_loop:
self.create_handle_loop()
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
logger.info("Start update_weights. Load format=%s", obj.load_format)
if not self.model_update_lock.locked():
async with self.model_update_lock:
# wait for the previous generation requests to finish
for i in range(3):
while len(self.rid_to_state) > 0:
await asyncio.sleep(0.001)
# FIXME: We add some sleep here to avoid some race conditions.
# We can use a read-write lock as a better fix.
await asyncio.sleep(0.01)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
if True:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
return await self._wait_for_model_update_from_disk(obj)
if self.server_args.dp_size == 1:
result = await self.model_update_result
if result.success:
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
return result.success, result.message
else: # self.server_args.dp_size > 1
self.model_update_tmp = []
result = await self.model_update_result
all_success = all([r.success for r in result])
if all_success is True:
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
all_message = [r.message for r in result]
all_message = " | ".join(all_message)
return all_success, all_message
else:
return False, "Another update is in progress. Please try again later."
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
) -> Tuple[bool, str, int]:
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
if self.server_args.dp_size == 1:
result = await self.model_update_result
if result.success:
self.served_model_name = obj.model_path
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
return result.success, result.message
else: # self.server_args.dp_size > 1
self.model_update_tmp = []
result = await self.model_update_result
all_success = all([r.success for r in result])
if all_success is True:
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
all_message = [r.message for r in result]
all_message = " | ".join(all_message)
return all_success, all_message
async def init_weights_update_group(
self,
obj: InitWeightsUpdateGroupReqInput,
request: Optional[fastapi.Request] = None,
) -> bool:
) -> Tuple[bool, str]:
if self.to_create_loop:
self.create_handle_loop()
self.send_to_scheduler.send_pyobj(obj)
@@ -487,25 +500,22 @@ class TokenizerManager:
self,
obj: UpdateWeightsFromDistributedReqInput,
request: Optional[fastapi.Request] = None,
):
) -> Tuple[bool, str]:
if self.to_create_loop:
self.create_handle_loop()
if not self.model_update_lock.locked():
async with self.model_update_lock:
self.send_to_scheduler.send_pyobj(obj)
self.parameter_update_result = asyncio.Future()
assert (
self.server_args.dp_size == 1
), "dp_size must be for update weights from distributed"
result = await self.parameter_update_result
return result.success, result.message
else:
logger.error("Another parameter update is in progress in tokenizer manager")
return (
False,
"Another parameter update is in progress. Please try again later.",
)
# This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
self.send_to_scheduler.send_pyobj(obj)
self.parameter_update_result: Awaitable[
UpdateWeightsFromDistributedReqOutput
] = asyncio.Future()
assert (
self.server_args.dp_size == 1
), "dp_size must be for update weights from distributed"
result = await self.parameter_update_result
return result.success, result.message
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -564,11 +574,11 @@ class TokenizerManager:
self.to_create_loop = False
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())
self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
signal_handler = SignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
loop.create_task(self.sigterm_watchdog())
self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
async def sigterm_watchdog(self):
while not self.gracefully_exit:
@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
device: str,
):
super().__init__(size, dtype, device)
self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
self._create_buffers()
def _create_buffers(self):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
torch.empty(
(size + 1, head_num, head_dim),
(self.size + 1, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=device,
device=self.device,
)
for _ in range(layer_num)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.empty(
(size + 1, head_num, head_dim),
(self.size + 1, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=device,
device=self.device,
)
for _ in range(layer_num)
for _ in range(self.layer_num)
]
def _clear_buffers(self):
del self.k_buffer
del self.v_buffer
def get_key_buffer(self, layer_id: int):
if self.store_dtype != self.dtype:
return self.k_buffer[layer_id].view(self.dtype)
@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
class MLATokenToKVPool(BaseTokenToKVPool):
def __init__(
self,
size: int,
@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):
class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
def __init__(
self,
size: int,
@@ -311,6 +311,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
logger.error(f"Error: {e}")
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@@ -14,6 +14,7 @@
"""Common utilities."""
import base64
import dataclasses
import ipaddress
import itertools
import json
@@ -1238,49 +1239,37 @@ def cuda_device_count_stateless() -> int:
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
def should_use_tensor_core(
kv_cache_dtype: torch.dtype,
num_attention_heads: int,
num_kv_heads: int,
) -> bool:
"""
Determine whether to use tensor cores for attention computation.
Args:
kv_cache_dtype: Data type of the KV cache
num_attention_heads: Number of attention heads
num_kv_heads: Number of key/value heads
Returns:
bool: Whether to use tensor cores
"""
# Try to use environment variable first
env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
if env_override is not None:
return env_override.lower() == "true"
# Try to use _grouped_size_compiled_for_decode_kernels if available
# This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
try:
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
if not _grouped_size_compiled_for_decode_kernels(
num_attention_heads,
num_kv_heads,
):
return True
        else:
            return False
    except (ImportError, AttributeError):
        pass

    # Calculate GQA group size
    gqa_group_size = num_attention_heads // num_kv_heads

    # Determine based on dtype and GQA group size
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
        return gqa_group_size > 4
    else:
        return False


def dataclass_to_string_truncated(data, max_length=2048):
    if isinstance(data, str):
        if len(data) > max_length:
            half_length = max_length // 2
            return f'"{data[:half_length]} ... {data[-half_length:]}"'
        else:
            return f'"{data}"'
    elif isinstance(data, (list, tuple)):
        if len(data) > max_length:
            half_length = max_length // 2
            return str(data[:half_length]) + " ... " + str(data[-half_length:])
        else:
            return str(data)
    elif isinstance(data, dict):
        return (
            "{"
            + ", ".join(
                f"{k}: {dataclass_to_string_truncated(v, max_length)}"
                for k, v in data.items()
            )
            + "}"
        )
    elif dataclasses.is_dataclass(data):
        fields = dataclasses.fields(data)
        return (
            f"{data.__class__.__name__}("
            + ", ".join(
                f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
                for f in fields
            )
            + ")"
        )
    else:
        return str(data)
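An illustrative sketch, not part of the diff, of the new logging helper that TokenizerManager uses for request logging above; the DummyReq dataclass is hypothetical. Long strings keep only their first and last max_length // 2 characters, and dataclasses are expanded field by field with each value truncated recursively.

import dataclasses

from sglang.srt.utils import dataclass_to_string_truncated


@dataclasses.dataclass
class DummyReq:
    rid: str
    text: str


print(dataclass_to_string_truncated("a" * 10_000, max_length=8))
# -> "aaaa ... aaaa"  (only the first and last 4 characters survive)

print(dataclass_to_string_truncated(DummyReq(rid="abc", text="b" * 10_000), max_length=8))
# -> DummyReq(rid="abc", text="bbbb ... bbbb")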
@@ -51,8 +51,10 @@ class TestEnableMetrics(unittest.TestCase):
# Verify essential metrics are present
essential_metrics = [
"sglang:num_running_reqs",
"sglang:num_used_tokens",
"sglang:token_usage",
"sglang:gen_throughput",
"sglang:num_queue_reqs",
"sglang:cache_hit_rate",
"sglang:func_latency_seconds",
"sglang:prompt_tokens_total",