Commit 84967019 authored by Lianmin Zheng

[Misc] Fix metrics, weight update lock, request logging (#2543)

parent 7d672d27
import asyncio


class RWLock:
    """
    A Read-Write Lock for asyncio:
    - Multiple readers can hold the lock in parallel if no writer holds it.
    - A writer has exclusive access.
    """

    def __init__(self):
        self._readers = 0  # How many readers currently hold the lock
        self._writer_active = False
        self._lock = asyncio.Lock()  # Internal mutex to protect state
        # Conditions associated with _lock:
        self._readers_ok = asyncio.Condition(self._lock)  # Notify blocked readers
        self._writers_ok = asyncio.Condition(self._lock)  # Notify blocked writers
        # Expose two async context-manager helpers:
        self.reader_lock = self._ReaderLock(self)
        self.writer_lock = self._WriterLock(self)

    async def _acquire_reader(self):
        """
        Wait until there is no active writer.
        Then increment the count of active readers.
        """
        async with self._lock:
            # If a writer is active, wait until it's done.
            while self._writer_active:
                await self._readers_ok.wait()
            self._readers += 1

    async def _release_reader(self):
        """
        Decrement the count of active readers.
        If this was the last active reader, wake up a possible waiting writer.
        """
        async with self._lock:
            self._readers -= 1
            # If no more readers, a writer could proceed.
            if self._readers == 0:
                self._writers_ok.notify()

    async def _acquire_writer(self):
        """
        Wait until there is no active writer and no active readers.
        Then mark a writer as active.
        """
        async with self._lock:
            while self._writer_active or self._readers > 0:
                await self._writers_ok.wait()
            self._writer_active = True

    async def _release_writer(self):
        """
        Mark the writer as done and notify readers and writers.
        """
        async with self._lock:
            self._writer_active = False
            # Allow any waiting readers to proceed:
            self._readers_ok.notify_all()
            # Allow the next waiting writer to proceed:
            self._writers_ok.notify()

    class _ReaderLock:
        """
        A simple async context manager that acquires a reader lock
        on entering and releases it on exit.
        """

        def __init__(self, parent: "RWLock"):
            self._parent = parent

        async def __aenter__(self):
            await self._parent._acquire_reader()

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            await self._parent._release_reader()

    class _WriterLock:
        """
        A simple async context manager that acquires a writer lock
        on entering and releases it on exit.
        """

        def __init__(self, parent: "RWLock"):
            self._parent = parent

        async def __aenter__(self):
            await self._parent._acquire_writer()

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            await self._parent._release_writer()
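
Usage note (not part of the commit): the class above is the new `sglang.srt.aio_rwlock` module that the tokenizer manager imports later in this diff, and it is consumed through its two async context managers. A minimal standalone sketch, assuming a plain asyncio program; the coroutine names are illustrative only:

import asyncio

from sglang.srt.aio_rwlock import RWLock


async def serve_request(lock: RWLock, i: int):
    # Many requests can hold the reader lock concurrently.
    async with lock.reader_lock:
        await asyncio.sleep(0.1)  # simulate handling one request
        print(f"request {i} done")


async def update_weights(lock: RWLock):
    # The writer waits until all readers have released, then runs exclusively.
    async with lock.writer_lock:
        await asyncio.sleep(0.2)  # simulate a weight sync
        print("weights updated")


async def main():
    lock = RWLock()
    await asyncio.gather(
        *(serve_request(lock, i) for i in range(3)),
        update_weights(lock),
    )


asyncio.run(main())

This is the pattern adopted further down: generate_request holds reader_lock so normal traffic proceeds in parallel, while update_weights_from_disk and update_weights_from_distributed hold writer_lock so a weight sync waits for in-flight requests and blocks new ones, replacing the earlier sleep-and-poll workaround flagged by the old FIXME comment.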
@@ -124,8 +124,12 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
+        # Verify quantization
         self._verify_quantization()
 
+        # Multimodal attrs
+        self.image_token_id = getattr(self.hf_config, "image_token_id", None)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
...
@@ -18,11 +18,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    get_bool_env_var,
-    is_flashinfer_available,
-    should_use_tensor_core,
-)
+from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -731,3 +727,51 @@ def create_flashinfer_kv_indices_triton(
             mask=mask,
         )
         tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
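
For reference, a rough usage sketch of the relocated helper (not part of the commit). The head counts and dtypes are made-up values, and the printed results assume `SGLANG_FLASHINFER_USE_TENSOR_CORE` is unset and a flashinfer build without `_grouped_size_compiled_for_decode_kernels`:

import torch

# Hypothetical GQA configuration: 32 query heads sharing 4 KV heads, fp16 KV cache.
use_tc = should_use_tensor_core(
    kv_cache_dtype=torch.float16,
    num_attention_heads=32,
    num_kv_heads=4,
)
# GQA group size is 32 // 4 = 8 > 4, so the tensor-core decode path is preferred.
print(use_tc)  # True

# FP8 KV caches always take the tensor-core path (requires torch >= 2.1 for float8 dtypes).
print(should_use_tensor_core(torch.float8_e4m3fn, 32, 32))  # True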
@@ -479,8 +479,22 @@ class Req:
 
         return True
 
+    def reset_for_retract(self):
+        self.prefix_indices = []
+        self.last_node = None
+        self.extend_input_len = 0
+        self.is_retracted = True
+
+        # For incremental logprobs
+        # TODO: Fix the `logprob_start_len`
+        self.last_update_decode_tokens = 0
+        self.logprob_start_len = 10**9
+
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
+        return (
+            f"rid(n={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+        )
 
 
 bid = 0
@@ -894,15 +908,7 @@ class ScheduleBatch:
                 )
                 residual_size = max(0, residual_size)
                 self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
 
-                req.prefix_indices = []
-                req.last_node = None
-                req.extend_input_len = 0
-                req.is_retracted = True
-
-                # For incremental logprobs
-                req.last_update_decode_tokens = 0
-                req.logprob_start_len = 10**9
+                req.reset_for_retract()
 
         self.filter_batch(keep_indices=sorted_indices)
...
@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import List, Optional
+from typing import Callable, Dict, List, Optional, Tuple
 
 import psutil
 import setproctitle
@@ -260,7 +260,7 @@ class Scheduler:
         self.current_stream = torch.get_device_module(self.device).current_stream()
 
         # Session info
-        self.sessions = {}
+        self.sessions: Dict[str, Session] = {}
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
...
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 import fastapi
 import uvloop
@@ -30,6 +30,7 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
 
+from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_zmq_socket, kill_process_tree
+from sglang.srt.utils import (
+    dataclass_to_string_truncated,
+    get_zmq_socket,
+    kill_process_tree,
+)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -82,6 +87,9 @@ class ReqState:
     created_time: float
     first_token_time: Optional[float] = None
 
+    # For streaming output
+    last_output_offset: int = 0
+
 
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
@@ -120,6 +128,7 @@ class TokenizerManager:
         self.is_generation = self.model_config.is_generation
         self.context_len = self.model_config.context_len
+        self.image_token_id = self.model_config.image_token_id
 
         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -152,9 +161,12 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
 
-        # For update model weights
-        self.model_update_lock = asyncio.Lock()
-        self.model_update_result = None
+        # The event to notify the weight sync is finished.
+        self.model_update_lock = RWLock()
+        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
+            None
+        )
+        self.asyncio_tasks = set()
 
         # For session info
         self.session_futures = {}  # session_id -> asyncio event
@@ -181,9 +193,6 @@ class TokenizerManager:
         if self.to_create_loop:
             self.create_handle_loop()
 
-        while self.model_update_lock.locked():
-            await asyncio.sleep(0.001)
-
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
                 "This model does not appear to be an embedding model by default. "
@@ -191,17 +200,24 @@ class TokenizerManager:
             )
 
         obj.normalize_batch_and_arguments()
-        is_single = obj.is_single
 
-        if is_single:
-            tokenized_obj = await self._tokenize_one_request(obj)
-            self.send_to_scheduler.send_pyobj(tokenized_obj)
-            async for response in self._wait_one_response(obj, request, created_time):
-                yield response
-        else:
-            async for response in self._handle_batch_request(
-                obj, request, created_time
-            ):
-                yield response
+        if self.server_args.log_requests:
+            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+
+        async with self.model_update_lock.reader_lock:
+            is_single = obj.is_single
+            if is_single:
+                tokenized_obj = await self._tokenize_one_request(obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                async for response in self._wait_one_response(
+                    obj, request, created_time
+                ):
+                    yield response
+            else:
+                async for response in self._handle_batch_request(
+                    obj, request, created_time
+                ):
+                    yield response
 
     async def _tokenize_one_request(
         self,
@@ -215,7 +231,7 @@ class TokenizerManager:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
                     "input_embeds is provided while disable_radix_cache is False. "
-                    "Please add `--disable-radix-cach` when you launch the server "
+                    "Please add `--disable-radix-cache` when you launch the server "
                     "if you want to use input_embeds as inputs."
                 )
             input_embeds = obj.input_embeds
@@ -301,8 +317,8 @@ class TokenizerManager:
                 state.out_list = []
                 if state.finished:
                     if self.server_args.log_requests:
-                        # Log requests
-                        logger.info(f"in={obj}, out={out}")
+                        msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                        logger.info(msg)
                     del self.rid_to_state[obj.rid]
                     yield out
                     break
@@ -423,55 +439,52 @@ class TokenizerManager:
         self,
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()
 
         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
+        logger.info("Start update_weights. Load format=%s", obj.load_format)
 
-        if not self.model_update_lock.locked():
-            async with self.model_update_lock:
-                # wait for the previous generation requests to finish
-                for i in range(3):
-                    while len(self.rid_to_state) > 0:
-                        await asyncio.sleep(0.001)
-                    # FIXME: We add some sleep here to avoid some race conditions.
-                    # We can use a read-write lock as a better fix.
-                    await asyncio.sleep(0.01)
-                self.send_to_scheduler.send_pyobj(obj)
-                self.model_update_result = asyncio.Future()
-
-                if self.server_args.dp_size == 1:
-                    result = await self.model_update_result
-                    if result.success:
-                        self.server_args.model_path = obj.model_path
-                        self.server_args.load_format = obj.load_format
-                        self.model_path = obj.model_path
-                    return result.success, result.message
-                else:  # self.server_args.dp_size > 1
-                    self.model_update_tmp = []
-                    result = await self.model_update_result
-
-                    all_success = all([r.success for r in result])
-                    if all_success is True:
-                        self.server_args.model_path = obj.model_path
-                        self.server_args.load_format = obj.load_format
-                        self.model_path = obj.model_path
-                    all_message = [r.message for r in result]
-                    all_message = " | ".join(all_message)
-                    return all_success, all_message
-        else:
-            return False, "Another update is in progress. Please try again later."
+        if True:
+            # Hold the lock if it is not async. This means that weight sync
+            # cannot run while requests are in progress.
+            async with self.model_update_lock.writer_lock:
+                return await self._wait_for_model_update_from_disk(obj)
+
+    async def _wait_for_model_update_from_disk(
+        self, obj: UpdateWeightFromDiskReqInput
+    ) -> Tuple[bool, str, int]:
+        self.send_to_scheduler.send_pyobj(obj)
+        self.model_update_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.model_update_result
+            if result.success:
+                self.served_model_name = obj.model_path
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            return result.success, result.message
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp = []
+            result = await self.model_update_result
+
+            all_success = all([r.success for r in result])
+            if all_success is True:
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            all_message = [r.message for r in result]
+            all_message = " | ".join(all_message)
+            return all_success, all_message
 
     async def init_weights_update_group(
         self,
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
-    ) -> bool:
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()
         self.send_to_scheduler.send_pyobj(obj)
@@ -487,25 +500,22 @@ class TokenizerManager:
         self,
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
        if self.to_create_loop:
            self.create_handle_loop()
 
-        if not self.model_update_lock.locked():
-            async with self.model_update_lock:
-                self.send_to_scheduler.send_pyobj(obj)
-                self.parameter_update_result = asyncio.Future()
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be for update weights from distributed"
-                result = await self.parameter_update_result
-                return result.success, result.message
-        else:
-            logger.error("Another parameter update is in progress in tokenizer manager")
-            return (
-                False,
-                "Another parameter update is in progress. Please try again later.",
-            )
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            self.send_to_scheduler.send_pyobj(obj)
+            self.parameter_update_result: Awaitable[
+                UpdateWeightsFromDistributedReqOutput
+            ] = asyncio.Future()
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be for update weights from distributed"
+            result = await self.parameter_update_result
+            return result.success, result.message
 
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -564,11 +574,11 @@ class TokenizerManager:
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        loop.create_task(self.handle_loop())
+        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
 
         signal_handler = SignalHandler(self)
         loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-        loop.create_task(self.sigterm_watchdog())
+        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
...
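
The handle_loop and sigterm_watchdog tasks are now stored in self.asyncio_tasks because the asyncio event loop keeps only weak references to tasks; without a strong reference, a fire-and-forget background task can be garbage collected before it finishes. A generic sketch of that pattern (the names below are illustrative, not from the commit):

import asyncio

background_tasks = set()


async def watchdog():
    # A long-running background coroutine.
    while True:
        await asyncio.sleep(5)


def start_background(loop: asyncio.AbstractEventLoop):
    task = loop.create_task(watchdog())
    background_tasks.add(task)  # keep a strong reference so the task is not collected
    task.add_done_callback(background_tasks.discard)  # drop the reference once it finishes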
@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
     ):
         super().__init__(size, dtype, device)
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.layer_num = layer_num
+        self._create_buffers()
+
+    def _create_buffers(self):
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
 
+    def _clear_buffers(self):
+        del self.k_buffer
+        del self.v_buffer
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
 
-
 class MLATokenToKVPool(BaseTokenToKVPool):
     def __init__(
         self,
         size: int,
@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):
 
-
 class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
     def __init__(
         self,
         size: int,
...
@@ -311,6 +311,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
+        logger.error(f"Error: {e}")
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )
...
@@ -14,6 +14,7 @@
 """Common utilities."""
 
 import base64
+import dataclasses
 import ipaddress
 import itertools
 import json
@@ -1238,49 +1239,37 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
 
-def should_use_tensor_core(
-    kv_cache_dtype: torch.dtype,
-    num_attention_heads: int,
-    num_kv_heads: int,
-) -> bool:
-    ...  # (removed from utils.py; the body is identical to the function added to flashinfer_backend.py above)
+def dataclass_to_string_truncated(data, max_length=2048):
+    if isinstance(data, str):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return f'"{data[:half_length]} ... {data[-half_length:]}"'
+        else:
+            return f'"{data}"'
+    elif isinstance(data, (list, tuple)):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return str(data[:half_length]) + " ... " + str(data[-half_length:])
+        else:
+            return str(data)
+    elif isinstance(data, dict):
+        return (
+            "{"
+            + ", ".join(
+                f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+                for k, v in data.items()
+            )
+            + "}"
+        )
+    elif dataclasses.is_dataclass(data):
+        fields = dataclasses.fields(data)
+        return (
+            f"{data.__class__.__name__}("
+            + ", ".join(
                f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
+                for f in fields
+            )
+            + ")"
+        )
+    else:
+        return str(data)
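
A quick illustration of how dataclass_to_string_truncated shortens the request and response objects that the tokenizer manager now logs (not part of the commit; the dataclass and values below are invented for the example):

import dataclasses


@dataclasses.dataclass
class DummyReq:  # hypothetical request object, for illustration only
    rid: str
    text: str


req = DummyReq(rid="abc", text="x" * 5000)
# With max_length=16, the long string is reduced to its first and last 8 characters.
print(dataclass_to_string_truncated(req, max_length=16))
# DummyReq(rid="abc", text="xxxxxxxx ... xxxxxxxx")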
@@ -51,8 +51,10 @@ class TestEnableMetrics(unittest.TestCase):
         # Verify essential metrics are present
         essential_metrics = [
             "sglang:num_running_reqs",
+            "sglang:num_used_tokens",
             "sglang:token_usage",
             "sglang:gen_throughput",
+            "sglang:num_queue_reqs",
             "sglang:cache_hit_rate",
             "sglang:func_latency_seconds",
             "sglang:prompt_tokens_total",
...