Unverified Commit f3cbd245 authored by zixuanzhang226's avatar zixuanzhang226 Committed by GitHub
Browse files

feat: send kvmetrics from sglang scheduler (#6721)

parent 506a2d59
...@@ -115,13 +115,13 @@ class Engine(EngineBase): ...@@ -115,13 +115,13 @@ class Engine(EngineBase):
atexit.register(self.shutdown) atexit.register(self.shutdown)
# Allocate ports for inter-process communications # Allocate ports for inter-process communications
port_args = PortArgs.init_new(server_args) self.port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}") logger.info(f"{server_args=}")
# Launch subprocesses # Launch subprocesses
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses( tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args, server_args=server_args,
port_args=port_args, port_args=self.port_args,
) )
self.server_args = server_args self.server_args = server_args
self.tokenizer_manager = tokenizer_manager self.tokenizer_manager = tokenizer_manager
...@@ -130,7 +130,7 @@ class Engine(EngineBase): ...@@ -130,7 +130,7 @@ class Engine(EngineBase):
context = zmq.Context(2) context = zmq.Context(2)
self.send_to_rpc = get_zmq_socket( self.send_to_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, True context, zmq.DEALER, self.port_args.rpc_ipc_name, True
) )
def generate( def generate(
......
...@@ -182,6 +182,18 @@ class EmbeddingBatchResult: ...@@ -182,6 +182,18 @@ class EmbeddingBatchResult:
bid: int bid: int
class KvMetrics:
def __init__(self):
self.request_active_slots = None
self.request_total_slots = None
self.kv_active_blocks = None
self.kv_total_blocks = None
self.num_requests_waiting = None
self.gpu_cache_usage_perc = None
self.gpu_prefix_cache_hit_rate = None
self.data_parallel_rank = None
class IdleSleeper: class IdleSleeper:
""" """
In setups which have long inactivity periods it is desirable to reduce In setups which have long inactivity periods it is desirable to reduce
...@@ -223,6 +235,7 @@ class Scheduler( ...@@ -223,6 +235,7 @@ class Scheduler(
self.server_args = server_args self.server_args = server_args
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.pp_rank = pp_rank self.pp_rank = pp_rank
self.dp_rank = dp_rank
self.tp_size = server_args.tp_size self.tp_size = server_args.tp_size
self.pp_size = server_args.pp_size self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size self.dp_size = server_args.dp_size
...@@ -261,6 +274,9 @@ class Scheduler( ...@@ -261,6 +274,9 @@ class Scheduler(
self.send_to_tokenizer = get_zmq_socket( self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False context, zmq.PUSH, port_args.tokenizer_ipc_name, False
) )
self.send_metrics_from_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.metrics_ipc_name, False
)
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager # Directly send to the TokenizerManager
...@@ -286,6 +302,7 @@ class Scheduler( ...@@ -286,6 +302,7 @@ class Scheduler(
else: else:
self.recv_from_tokenizer = None self.recv_from_tokenizer = None
self.recv_from_rpc = None self.recv_from_rpc = None
self.send_metrics_from_scheduler = None
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None) self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
...@@ -1239,6 +1256,22 @@ class Scheduler( ...@@ -1239,6 +1256,22 @@ class Scheduler(
req.logprob_start_len = len(req.origin_input_ids) - 1 req.logprob_start_len = len(req.origin_input_ids) - 1
self._add_request_to_queue(req) self._add_request_to_queue(req)
def _emit_kv_metrics(self):
kv_metrics = KvMetrics()
kv_metrics.request_active_slots = self.stats.num_running_reqs
kv_metrics.request_total_slots = self.max_running_requests
kv_metrics.kv_active_blocks = int(
self.stats.token_usage * self.max_total_num_tokens
)
kv_metrics.kv_total_blocks = self.max_total_num_tokens
kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
if not self.send_metrics_from_scheduler.closed:
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
def log_prefill_stats( def log_prefill_stats(
self, self,
adder: PrefillAdder, adder: PrefillAdder,
...@@ -1291,6 +1324,7 @@ class Scheduler( ...@@ -1291,6 +1324,7 @@ class Scheduler(
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events() self._publish_kv_events()
def log_decode_stats( def log_decode_stats(
...@@ -1352,6 +1386,7 @@ class Scheduler( ...@@ -1352,6 +1386,7 @@ class Scheduler(
self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.spec_accept_length = spec_accept_length self.stats.spec_accept_length = spec_accept_length
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events() self._publish_kv_events()
def check_memory(self): def check_memory(self):
......
...@@ -1701,6 +1701,9 @@ class PortArgs: ...@@ -1701,6 +1701,9 @@ class PortArgs:
# The ipc filename for rpc call between Engine and Scheduler # The ipc filename for rpc call between Engine and Scheduler
rpc_ipc_name: str rpc_ipc_name: str
# The ipc filename for Scheduler to send metrics
metrics_ipc_name: str
@staticmethod @staticmethod
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
port = server_args.port + random.randint(100, 1000) port = server_args.port + random.randint(100, 1000)
...@@ -1720,6 +1723,7 @@ class PortArgs: ...@@ -1720,6 +1723,7 @@ class PortArgs:
detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
nccl_port=port, nccl_port=port,
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
) )
else: else:
# DP attention. Use TCP + port to handle both single-node and multi-node. # DP attention. Use TCP + port to handle both single-node and multi-node.
...@@ -1739,9 +1743,9 @@ class PortArgs: ...@@ -1739,9 +1743,9 @@ class PortArgs:
port_base = int(dist_init_port) + 1 port_base = int(dist_init_port) + 1
if dp_rank is None: if dp_rank is None:
# TokenizerManager to DataParallelController # TokenizerManager to DataParallelController
scheduler_input_port = port_base + 3 scheduler_input_port = port_base + 4
else: else:
scheduler_input_port = port_base + 3 + 1 + dp_rank scheduler_input_port = port_base + 4 + 1 + dp_rank
return PortArgs( return PortArgs(
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}", tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
...@@ -1749,6 +1753,7 @@ class PortArgs: ...@@ -1749,6 +1753,7 @@ class PortArgs:
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}", detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
nccl_port=port, nccl_port=port,
rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}", rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
) )
......
...@@ -76,7 +76,7 @@ class TestPortArgs(unittest.TestCase): ...@@ -76,7 +76,7 @@ class TestPortArgs(unittest.TestCase):
port_args = PortArgs.init_new(server_args, dp_rank=2) port_args = PortArgs.init_new(server_args, dp_rank=2)
print(f"{port_args=}") print(f"{port_args=}")
self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25007")) self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25008"))
self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://192.168.1.1:")) self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://192.168.1.1:")) self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment