Unverified Commit f84db115 authored by pansicheng's avatar pansicheng Committed by GitHub
Browse files

Add storage read/write bandwidth logs to monitor kvcache performance (#9965)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent efb0de2c
...@@ -33,6 +33,7 @@ from sglang.srt.distributed import ( ...@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
get_attention_dp_rank,
get_attention_tp_rank, get_attention_tp_rank,
get_attention_tp_size, get_attention_tp_size,
is_dp_attention_enabled, is_dp_attention_enabled,
...@@ -402,9 +403,11 @@ class HiCacheController: ...@@ -402,9 +403,11 @@ class HiCacheController:
if is_dp_attention_enabled(): if is_dp_attention_enabled():
self.tp_rank = get_attention_tp_rank() self.tp_rank = get_attention_tp_rank()
self.tp_size = get_attention_tp_size() self.tp_size = get_attention_tp_size()
self.dp_rank = get_attention_dp_rank()
else: else:
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.dp_rank = 0
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool. # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool) is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
...@@ -885,7 +888,7 @@ class HiCacheController: ...@@ -885,7 +888,7 @@ class HiCacheController:
if not self.backup_skip: if not self.backup_skip:
self._page_backup(operation) self._page_backup(operation)
self.ack_backup_queue.put(operation.id) self.ack_backup_queue.put(operation)
except Empty: except Empty:
continue continue
...@@ -623,6 +623,7 @@ class Scheduler( ...@@ -623,6 +623,7 @@ class Scheduler(
hicache_write_policy=server_args.hicache_write_policy, hicache_write_policy=server_args.hicache_write_policy,
hicache_io_backend=server_args.hicache_io_backend, hicache_io_backend=server_args.hicache_io_backend,
hicache_mem_layout=server_args.hicache_mem_layout, hicache_mem_layout=server_args.hicache_mem_layout,
enable_metrics=self.enable_metrics,
hicache_storage_backend=server_args.hicache_storage_backend, hicache_storage_backend=server_args.hicache_storage_backend,
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy, hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
model_name=server_args.served_model_name, model_name=server_args.served_model_name,
......
...@@ -128,6 +128,9 @@ class HiCacheStorage(ABC): ...@@ -128,6 +128,9 @@ class HiCacheStorage(ABC):
return i return i
return len(keys) return len(keys)
def get_stats(self):
    """Return a snapshot of backend I/O statistics, or ``None``.

    The base storage backend collects no metrics, so this default
    implementation reports nothing. Concrete backends (e.g. the HF3FS
    backend) override this to return a ``StorageMetrics`` snapshot that
    is consumed by ``StorageMetricsCollector.log_storage_metrics``.
    """
    return None
class HiCacheFile(HiCacheStorage): class HiCacheFile(HiCacheStorage):
......
...@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import ( ...@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
MLATokenToKVPoolHost, MLATokenToKVPoolHost,
) )
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.metrics.collector import StorageMetricsCollector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache): ...@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
hicache_write_policy: str, hicache_write_policy: str,
hicache_io_backend: str, hicache_io_backend: str,
hicache_mem_layout: str, hicache_mem_layout: str,
enable_metrics: bool,
hicache_storage_backend: Optional[str] = None, hicache_storage_backend: Optional[str] = None,
hicache_storage_prefetch_policy: Optional[str] = "best_effort", hicache_storage_prefetch_policy: Optional[str] = "best_effort",
model_name: Optional[str] = None, model_name: Optional[str] = None,
...@@ -73,6 +75,8 @@ class HiRadixCache(RadixCache): ...@@ -73,6 +75,8 @@ class HiRadixCache(RadixCache):
self.tp_group = tp_cache_group self.tp_group = tp_cache_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
self.enable_storage = hicache_storage_backend is not None self.enable_storage = hicache_storage_backend is not None
self.enable_storage_metrics = self.enable_storage and enable_metrics
# todo: customizable storage prefetch threshold and timeout # todo: customizable storage prefetch threshold and timeout
self.prefetch_threshold = 256 self.prefetch_threshold = 256
self.prefetch_timeout = 3 # seconds self.prefetch_timeout = 3 # seconds
...@@ -92,6 +96,14 @@ class HiRadixCache(RadixCache): ...@@ -92,6 +96,14 @@ class HiRadixCache(RadixCache):
model_name=model_name, model_name=model_name,
storage_backend_extra_config=storage_backend_extra_config, storage_backend_extra_config=storage_backend_extra_config,
) )
if self.enable_storage_metrics:
# TODO: support pp
labels = {
"storage_backend": hicache_storage_backend,
"tp_rank": self.cache_controller.tp_rank,
"dp_rank": self.cache_controller.dp_rank,
}
self.metrics_collector = StorageMetricsCollector(labels=labels)
# record the nodes with ongoing write through # record the nodes with ongoing write through
self.ongoing_write_through = {} self.ongoing_write_through = {}
...@@ -379,6 +391,10 @@ class HiRadixCache(RadixCache): ...@@ -379,6 +391,10 @@ class HiRadixCache(RadixCache):
self.loading_check() self.loading_check()
if self.enable_storage: if self.enable_storage:
self.drain_storage_control_queues() self.drain_storage_control_queues()
if self.enable_storage_metrics:
self.metrics_collector.log_storage_metrics(
self.cache_controller.storage_backend.get_stats()
)
def drain_storage_control_queues(self): def drain_storage_control_queues(self):
""" """
...@@ -414,10 +430,13 @@ class HiRadixCache(RadixCache): ...@@ -414,10 +430,13 @@ class HiRadixCache(RadixCache):
# process backup acks # process backup acks
for _ in range(n_backup): for _ in range(n_backup):
ack_id = cc.ack_backup_queue.get() operation = cc.ack_backup_queue.get()
ack_id = operation.id
entry = self.ongoing_backup.pop(ack_id, None) entry = self.ongoing_backup.pop(ack_id, None)
if entry is not None: if entry is not None:
entry.release_host() entry.release_host()
if self.enable_storage_metrics:
self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
# release host memory # release host memory
host_indices_list = [] host_indices_list = []
...@@ -515,6 +534,11 @@ class HiRadixCache(RadixCache): ...@@ -515,6 +534,11 @@ class HiRadixCache(RadixCache):
del self.ongoing_prefetch[req_id] del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids) self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
if self.enable_storage_metrics:
self.metrics_collector.log_prefetched_tokens(
min_completed_tokens - matched_length
)
return True return True
def match_prefix(self, key: List[int], **kwargs): def match_prefix(self, key: List[int], **kwargs):
......
...@@ -5,6 +5,7 @@ import logging ...@@ -5,6 +5,7 @@ import logging
import os import os
import signal import signal
import threading import threading
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import wraps from functools import wraps
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
...@@ -13,6 +14,7 @@ import torch ...@@ -13,6 +14,7 @@ import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
from sglang.srt.metrics.collector import StorageMetrics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -135,6 +137,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -135,6 +137,7 @@ class HiCacheHF3FS(HiCacheStorage):
self.file_size = file_size self.file_size = file_size
self.numjobs = numjobs self.numjobs = numjobs
self.bytes_per_page = bytes_per_page self.bytes_per_page = bytes_per_page
self.gb_per_page = bytes_per_page / (1 << 30)
self.entries = entries self.entries = entries
self.dtype = dtype self.dtype = dtype
self.metadata_client = metadata_client self.metadata_client = metadata_client
...@@ -174,6 +177,11 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -174,6 +177,11 @@ class HiCacheHF3FS(HiCacheStorage):
signal.signal(signal.SIGTERM, lambda sig, frame: self.close()) signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
signal.signal(signal.SIGQUIT, lambda sig, frame: self.close()) signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
self.prefetch_pgs = []
self.backup_pgs = []
self.prefetch_bandwidth = []
self.backup_bandwidth = []
@staticmethod @staticmethod
def from_env_config( def from_env_config(
bytes_per_page: int, bytes_per_page: int,
...@@ -308,6 +316,8 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -308,6 +316,8 @@ class HiCacheHF3FS(HiCacheStorage):
for _ in range(len(batch_indices)) for _ in range(len(batch_indices))
] ]
start_time = time.perf_counter()
futures = [ futures = [
self.executor.submit( self.executor.submit(
self.clients[self.ac.next()].batch_read, self.clients[self.ac.next()].batch_read,
...@@ -318,6 +328,13 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -318,6 +328,13 @@ class HiCacheHF3FS(HiCacheStorage):
] ]
read_results = [result for future in futures for result in future.result()] read_results = [result for future in futures for result in future.result()]
end_time = time.perf_counter()
ionum = len(batch_indices)
self.prefetch_pgs.append(ionum)
self.prefetch_bandwidth.append(
ionum / (end_time - start_time) * self.gb_per_page
)
results = [None] * len(keys) results = [None] * len(keys)
for batch_index, file_result, read_result in zip( for batch_index, file_result, read_result in zip(
batch_indices, file_results, read_results batch_indices, file_results, read_results
...@@ -345,6 +362,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -345,6 +362,7 @@ class HiCacheHF3FS(HiCacheStorage):
[target_sizes] if target_sizes is not None else None, [target_sizes] if target_sizes is not None else None,
) )
@synchronized()
def batch_set( def batch_set(
self, self,
keys: List[str], keys: List[str],
...@@ -374,6 +392,8 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -374,6 +392,8 @@ class HiCacheHF3FS(HiCacheStorage):
assert value.is_contiguous() assert value.is_contiguous()
file_values.append(value) file_values.append(value)
start_time = time.perf_counter()
futures = [ futures = [
self.executor.submit( self.executor.submit(
self.clients[self.ac.next()].batch_write, self.clients[self.ac.next()].batch_write,
...@@ -388,6 +408,11 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -388,6 +408,11 @@ class HiCacheHF3FS(HiCacheStorage):
for result in future.result() for result in future.result()
] ]
end_time = time.perf_counter()
ionum = len(batch_indices)
self.backup_pgs.append(ionum)
self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
written_keys_to_confirm = [] written_keys_to_confirm = []
results = [index[0] for index in indices] results = [index[0] for index in indices]
for batch_index, write_result in zip(batch_indices, write_results): for batch_index, write_result in zip(batch_indices, write_results):
...@@ -439,3 +464,16 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -439,3 +464,16 @@ class HiCacheHF3FS(HiCacheStorage):
except Exception as e: except Exception as e:
logger.error(f"close HiCacheHF3FS: {e}") logger.error(f"close HiCacheHF3FS: {e}")
logger.info("close HiCacheHF3FS") logger.info("close HiCacheHF3FS")
@synchronized()
def get_stats(self):
    """Drain accumulated per-batch I/O counters into a fresh snapshot.

    Copies the page counts and bandwidth samples (GB/s) recorded since
    the previous call into a new ``StorageMetrics``, then resets the
    local accumulators. Runs under the instance lock so concurrent
    batch reads/writes cannot interleave with the drain.
    """
    snapshot = StorageMetrics()
    # (destination, source) pairs: snapshot lists <- local accumulators.
    pairs = (
        (snapshot.prefetch_pgs, self.prefetch_pgs),
        (snapshot.backup_pgs, self.backup_pgs),
        (snapshot.prefetch_bandwidth, self.prefetch_bandwidth),
        (snapshot.backup_bandwidth, self.backup_bandwidth),
    )
    for dst, src in pairs:
        dst.extend(src)
    # Reset only after everything is copied.
    for _, src in pairs:
        src.clear()
    return snapshot
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
"""Utilities for Prometheus Metrics Collection.""" """Utilities for Prometheus Metrics Collection."""
import time import time
from dataclasses import dataclass from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
...@@ -559,3 +559,105 @@ class TokenizerMetricsCollector: ...@@ -559,3 +559,105 @@ class TokenizerMetricsCollector:
def observe_one_aborted_request(self): def observe_one_aborted_request(self):
self.num_aborted_requests_total.labels(**self.labels).inc(1) self.num_aborted_requests_total.labels(**self.labels).inc(1)
@dataclass
class StorageMetrics:
    """Per-interval snapshot of storage-backend I/O activity.

    Each list holds one sample per batch I/O operation observed since
    the previous snapshot was drained via the backend's ``get_stats``.
    """

    # Pages transferred per prefetch (storage read) batch.
    prefetch_pgs: List[int] = field(default_factory=list)
    # Pages transferred per backup (storage write) batch.
    backup_pgs: List[int] = field(default_factory=list)
    # Read bandwidth per prefetch batch, in GB/s.
    prefetch_bandwidth: List[float] = field(default_factory=list)
    # Write bandwidth per backup batch, in GB/s.
    backup_bandwidth: List[float] = field(default_factory=list)
class StorageMetricsCollector:
    """Prometheus collector for hierarchical-cache storage I/O metrics.

    Exposes two counters — tokens prefetched from storage and tokens
    backed up to storage — plus histograms of per-batch page counts and
    bandwidth (GB/s), all tagged with the supplied label set.
    """

    def __init__(
        self,
        labels: Dict[str, str],
    ):
        # Imported lazily so prometheus_client is only required when
        # metrics collection is actually enabled.
        from prometheus_client import Counter, Histogram

        self.labels = labels
        label_names = labels.keys()

        self.prefetched_tokens_total = Counter(
            name="sglang:prefetched_tokens_total",
            documentation="Number of prefetched prompt tokens.",
            labelnames=label_names,
        )
        self.backuped_tokens_total = Counter(
            name="sglang:backuped_tokens_total",
            documentation="Number of backuped tokens.",
            labelnames=label_names,
        )

        # Bucket edges: pages per batch, and bandwidth in GB/s.
        io_buckets = [1, 5, 10, 50, 100]
        bandwidth_buckets = [0.1, 0.5, 1, 5, 10, 50, 100]

        self.histogram_prefetch_pgs = Histogram(
            name="sglang:prefetch_pgs",
            documentation="Histogram of prefetch pages of batches.",
            labelnames=label_names,
            buckets=io_buckets,
        )
        self.histogram_backup_pgs = Histogram(
            name="sglang:backup_pgs",
            documentation="Histogram of backup pages of batches.",
            labelnames=label_names,
            buckets=io_buckets,
        )
        self.histogram_prefetch_bandwidth = Histogram(
            name="sglang:prefetch_bandwidth",
            documentation="Histogram of prefetch bandwidth in GB/s.",
            labelnames=label_names,
            buckets=bandwidth_buckets,
        )
        self.histogram_backup_bandwidth = Histogram(
            name="sglang:backup_bandwidth",
            documentation="Histogram of backup bandwidth in GB/s.",
            labelnames=label_names,
            buckets=bandwidth_buckets,
        )

    def log_prefetched_tokens(self, prefetched_tokens: int):
        """Increment the prefetched-token counter; non-positive counts are ignored."""
        if prefetched_tokens > 0:
            self.prefetched_tokens_total.labels(**self.labels).inc(prefetched_tokens)

    def log_backuped_tokens(self, backuped_tokens: int):
        """Increment the backed-up-token counter; non-positive counts are ignored."""
        if backuped_tokens > 0:
            self.backuped_tokens_total.labels(**self.labels).inc(backuped_tokens)

    def _log_histogram(self, histogram, data: Union[int, float]):
        # Observe one sample on the histogram child for our label set.
        histogram.labels(**self.labels).observe(data)

    def log_storage_metrics(self, storage_metrics: Optional[StorageMetrics] = None):
        """Record every sample from a StorageMetrics snapshot; None is a no-op."""
        if storage_metrics is None:
            return
        assert isinstance(storage_metrics, StorageMetrics)
        # Pair each histogram with the matching sample series.
        series = (
            (self.histogram_prefetch_pgs, storage_metrics.prefetch_pgs),
            (self.histogram_backup_pgs, storage_metrics.backup_pgs),
            (self.histogram_prefetch_bandwidth, storage_metrics.prefetch_bandwidth),
            (self.histogram_backup_bandwidth, storage_metrics.backup_bandwidth),
        )
        for histogram, samples in series:
            for sample in samples:
                self._log_histogram(histogram, sample)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment