Unverified commit e2fd2b9c authored by pansicheng, committed by GitHub

Simple prefetch policy (#8692)

parent 7490e3f6
......@@ -20,6 +20,8 @@ from sglang.bench_serving import (
sample_random_requests,
)
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
def parse_args():
parser = argparse.ArgumentParser(
......@@ -139,7 +141,7 @@ async def async_request_sglang_generate(
"""
Sends a streaming request to the server. Gathers text token-by-token.
"""
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {}
generated_text = ""
ttft = 0.0
......@@ -150,6 +152,8 @@ async def async_request_sglang_generate(
try:
async with session.post(url=url, json=payload, headers=headers) as response:
if response.status == 200:
prompt_tokens = 0
cached_tokens = 0
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
......@@ -168,6 +172,12 @@ async def async_request_sglang_generate(
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
prompt_tokens = (data.get("meta_info") or {}).get(
"prompt_tokens", 0
)
cached_tokens = (data.get("meta_info") or {}).get(
"cached_tokens", 0
)
# Decoding phase
else:
......@@ -179,6 +189,8 @@ async def async_request_sglang_generate(
output.generated_text = generated_text
output.success = True
output.latency = latency
output.prompt_len = prompt_tokens
output.cached_tokens = cached_tokens
else:
output.error = response.reason or ""
output.success = False
......@@ -201,6 +213,7 @@ def gen_payload(prompt, output_len):
"ignore_eos": True,
},
"stream": True,
"stream_options": {"include_usage": True},
"lora_path": "",
"return_logprob": False,
"logprob_start_len": -1,
......@@ -303,7 +316,12 @@ class WorkloadGenerator:
self.response_queue = queue.Queue()
self.pbar = tqdm(total=args.num_clients * args.num_rounds)
self.performance_metrics = {"ttft": [], "latency": []}
self.performance_metrics = {
"ttft": [],
"latency": [],
"prompt_len": [],
"cached_tokens": [],
}
async def handle_request(self, item):
try:
......@@ -360,6 +378,8 @@ class WorkloadGenerator:
self.client_records[client_id]["round"] += 1
self.performance_metrics["ttft"].append(response.ttft)
self.performance_metrics["latency"].append(response.latency)
self.performance_metrics["prompt_len"].append(response.prompt_len)
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
self.completed_requests += 1
if self.client_records[client_id]["round"] < args.num_rounds:
......@@ -416,6 +436,12 @@ class WorkloadGenerator:
len(self.performance_metrics["latency"]) // 2
],
"throughput": self.pbar.total / (self.finished_time - self.start_time),
"cache_hit_rate": (
0
if sum(self.performance_metrics["prompt_len"]) == 0
else sum(self.performance_metrics["cached_tokens"])
/ sum(self.performance_metrics["prompt_len"])
),
},
}
print("All requests completed")
......@@ -434,6 +460,7 @@ class WorkloadGenerator:
print(
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
)
print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
......
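The new cache_hit_rate summary field is simply the total cached tokens divided by the total prompt tokens over all completed requests, guarded against an empty run. A standalone sketch of the same arithmetic (the sample values are made up):

# Standalone sketch of the cache_hit_rate computation added to the summary;
# the lists mirror performance_metrics["prompt_len"] / ["cached_tokens"].
prompt_len = [512, 1024, 768]   # prompt tokens reported per request
cached_tokens = [0, 896, 640]   # tokens served from cache per request

total_prompt = sum(prompt_len)     # 2304
total_cached = sum(cached_tokens)  # 1536
cache_hit_rate = 0 if total_prompt == 0 else total_cached / total_prompt
print(f"Cache Hit Rate: {cache_hit_rate:.6f}")  # Cache Hit Rate: 0.666667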
......@@ -16,6 +16,7 @@ limitations under the License.
import logging
import math
import threading
import time
from queue import Empty, Full, PriorityQueue, Queue
from typing import TYPE_CHECKING, List, Optional
......@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
self._done_flag = False
self._lock = threading.Lock()
self.start_time = time.monotonic()
super().__init__(host_indices, token_ids, last_hash)
def increment(self, num_tokens: int):
......@@ -278,6 +281,12 @@ class HiCacheController:
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
self.prefetch_capacity_limit = int(
0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
)
# tracks the number of tokens locked by in-flight prefetches; updated by the main scheduler thread
self.prefetch_tokens_occupied = 0
# create a new communication group for synchronizing storage operations across TP workers
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
if self.tp_world_size > 1:
......@@ -525,7 +534,7 @@ class HiCacheController:
host_indices: torch.Tensor,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
) -> int:
) -> PrefetchOperation:
"""
Prefetch KV caches from storage backend to host memory.
"""
......@@ -586,11 +595,23 @@ class HiCacheController:
operation = self.prefetch_buffer.get(block=True, timeout=1)
if self.is_mooncake_backend():
self.mooncake_page_transfer(operation)
elif self.storage_backend_type == "hf3fs":
self.generic_page_transfer(operation, batch_size=128)
else:
self.generic_page_transfer(operation)
except Empty:
continue
def prefetch_rate_limit_check(self) -> bool:
"""
Rate limit the prefetching operations to avoid overwhelming the storage backend.
"""
# cancel prefetch if too much memory is occupied
if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
return False
# todo: more sophisticated rate limiting based on storage backend performance
return True
def prefetch_thread_func(self):
"""
Manage prefetching operations from storage backend to host memory.
......@@ -604,34 +625,36 @@ class HiCacheController:
if operation is None:
continue
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
storage_hit_count = 0
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size
],
last_hash,
)
# todo, more unified interface
if not self.is_mooncake_backend():
if not self.storage_backend.exists(last_hash):
break
hash_value.append(last_hash)
storage_hit_count += self.page_size
remaining_tokens -= self.page_size
if self.is_mooncake_backend():
# deferring to batch exists for mooncake store
exist_result = self.storage_backend.exists(hash_value)
storage_hit_count = (
sum(1 for v in exist_result.values() if v != 0) * self.page_size
)
if self.prefetch_rate_limit_check():
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size
],
last_hash,
)
# todo, more unified interface
if not self.is_mooncake_backend():
if not self.storage_backend.exists(last_hash):
break
hash_value.append(last_hash)
storage_hit_count += self.page_size
remaining_tokens -= self.page_size
if self.is_mooncake_backend():
# deferring to batch exists for mooncake store
exist_result = self.storage_backend.exists(hash_value)
storage_hit_count = (
sum(1 for v in exist_result.values() if v != 0)
* self.page_size
)
if self.tp_world_size > 1:
storage_hit_count_tensor = torch.tensor(
......@@ -750,6 +773,8 @@ class HiCacheController:
if self.is_mooncake_backend():
self.mooncake_page_backup(operation)
elif self.storage_backend_type == "hf3fs":
self.generic_page_backup(operation, batch_size=128)
else:
self.generic_page_backup(operation)
......
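A condensed, self-contained sketch of the admission logic added to HiCacheController: prefetches may lock at most 80% of the host-only headroom (host pool size minus device pool size), tracked by prefetch_tokens_occupied. The pool objects and numbers below are hypothetical stand-ins, not the real classes:

class _Pool:  # hypothetical stand-in for the host/device KV pools
    def __init__(self, size: int):
        self.size = size  # capacity in tokens

class PrefetchLimiter:  # mirrors the fields added to HiCacheController
    def __init__(self, mem_pool_host: _Pool, mem_pool_device: _Pool):
        # at most 80% of the host-only headroom may be locked by prefetches
        self.prefetch_capacity_limit = int(
            0.8 * (mem_pool_host.size - mem_pool_device.size)
        )
        # tokens currently locked by in-flight prefetches; incremented when a
        # prefetch is issued, decremented when it is revoked or terminated
        self.prefetch_tokens_occupied = 0

    def prefetch_rate_limit_check(self) -> bool:
        # refuse new prefetches once the token budget is exhausted
        return self.prefetch_tokens_occupied < self.prefetch_capacity_limit

limiter = PrefetchLimiter(_Pool(200_000), _Pool(50_000))  # limit = 120_000
limiter.prefetch_tokens_occupied += 100_000
print(limiter.prefetch_rate_limit_check())  # True
limiter.prefetch_tokens_occupied += 30_000
print(limiter.prefetch_rate_limit_check())  # False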
......@@ -619,6 +619,7 @@ class Scheduler(
),
hicache_mem_layout=server_args.hicache_mem_layout,
hicache_storage_backend=server_args.hicache_storage_backend,
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
)
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
......@@ -1572,7 +1573,10 @@ class Scheduler(
break
if self.enable_hicache_storage:
self.tree_cache.check_prefetch_progress(req.rid)
prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
if not prefetch_done:
# skip staging requests whose prefetch is still in progress
continue
req.init_next_round_input(self.tree_cache)
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
......
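check_prefetch_progress now returns a boolean, and the scheduler leaves requests whose prefetch has not terminated in the waiting queue instead of staging them. A toy illustration of that gating pattern (Req and the tree-cache object are hypothetical stand-ins):

class Req:  # hypothetical stand-in for the scheduler's request object
    def __init__(self, rid: str):
        self.rid = rid

class _ToyTreeCache:  # returns True once the prefetch for rid has terminated
    def __init__(self, done_rids):
        self.done_rids = set(done_rids)
    def check_prefetch_progress(self, rid: str) -> bool:
        return rid in self.done_rids

def stage_requests(waiting_queue, tree_cache, enable_hicache_storage=True):
    staged = []
    for req in waiting_queue:
        if enable_hicache_storage:
            prefetch_done = tree_cache.check_prefetch_progress(req.rid)
            if not prefetch_done:
                # prefetch still in flight (e.g. under wait_complete/timeout):
                # keep the request waiting and retry next scheduling round
                continue
        staged.append(req)
    return staged

queue = [Req("a"), Req("b")]
print([r.rid for r in stage_requests(queue, _ToyTreeCache({"a"}))])  # ['a']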
......@@ -2,11 +2,12 @@ import heapq
import logging
import threading
import time
from queue import Queue
from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
......@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
hicache_io_backend: str,
hicache_mem_layout: str,
hicache_storage_backend: Optional[str] = None,
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
):
if hicache_io_backend == "direct":
......@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
prefetch_threshold=self.prefetch_threshold,
)
self.prefetch_stop_policy = hicache_storage_prefetch_policy
# todo: customizable storage prefetch timeout
self.prefetch_timeout = 3 # seconds
logger.info(
f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
)
# record the nodes with ongoing write through
self.ongoing_write_through = {}
# record the node segments with ongoing load back
......@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
for _ in range(queue_size.item()):
req_id = self.cache_controller.prefetch_revoke_queue.get()
if req_id in self.ongoing_prefetch:
last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
else:
# the revoked operation already got terminated
pass
......@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
host_node.release_host()
del self.ongoing_backup[ack_id]
def check_prefetch_progress(self, req_id: str):
def can_terminate_prefetch(self, operation: PrefetchOperation):
can_terminate = True
if self.prefetch_stop_policy == "best_effort":
return can_terminate
completed = (
operation.completed_tokens == len(operation.hash_value) * self.page_size
)
if self.prefetch_stop_policy == "wait_complete":
can_terminate = completed
elif self.prefetch_stop_policy == "timeout":
can_terminate = completed or (
time.monotonic() - operation.start_time > self.prefetch_timeout
)
else:
# unknown prefetch stop policy, just return True
return True
if self.tp_world_size > 1:
can_terminate = torch.tensor(can_terminate, dtype=torch.int)
torch.distributed.all_reduce(
can_terminate,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
can_terminate = bool(can_terminate.item())
return can_terminate
def check_prefetch_progress(self, req_id: str) -> bool:
if req_id not in self.ongoing_prefetch:
# there is no ongoing prefetch for this request or it has been revoked
return
return True
# todo: more policies for prefetch progress, such as timeout
# the current policy prefetches on a best-effort basis and terminates once the request leaves the waiting queue
......@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
req_id
]
if not self.can_terminate_prefetch(operation):
return False
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
operation
)
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
min_completed_tokens = completed_tokens
if self.tp_world_size > 1:
if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
# synchronize TP workers so they make the same update to the HiRadix cache
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int
......@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
)
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
return True
def match_prefix(self, key: List[int], **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
......@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
host_indices,
operation,
)
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
node.last_access_time = time.monotonic()
......
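The three prefetch stop policies decide when an in-flight prefetch may be terminated and its completed pages committed: best_effort terminates immediately, wait_complete only once every hashed page has landed, and timeout on completion or after a 3-second deadline measured from PrefetchOperation.start_time. With tensor parallelism the per-rank decisions are combined with a MIN all-reduce, so a single rank that still wants to wait keeps the request queued on every rank. A condensed, non-distributed sketch of the decision (inputs are hypothetical):

import time

PREFETCH_TIMEOUT_S = 3  # the diff hardcodes 3 seconds (todo: configurable)

def can_terminate_prefetch(policy: str, completed_tokens: int,
                           expected_tokens: int, start_time: float) -> bool:
    # non-distributed sketch; the real code additionally MIN-reduces the
    # boolean across TP workers so all ranks terminate (or wait) together
    if policy == "best_effort":
        return True
    completed = completed_tokens == expected_tokens
    if policy == "wait_complete":
        return completed
    if policy == "timeout":
        return completed or (time.monotonic() - start_time > PREFETCH_TIMEOUT_S)
    # unknown policy: terminate immediately, matching the fallback in the diff
    return True

# a half-finished prefetch that started 5 seconds ago may be cut off:
print(can_terminate_prefetch("timeout", 2048, 4096, time.monotonic() - 5))  # True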
......@@ -96,6 +96,8 @@ class Hf3fsClient:
)
self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
self.shm_r.unlink()
self.shm_w.unlink()
self.rlock = threading.RLock()
self.wlock = threading.RLock()
......@@ -176,8 +178,6 @@ class Hf3fsClient:
del self.iov_w
self.shm_r.close()
self.shm_w.close()
self.shm_r.unlink()
self.shm_w.unlink()
def flush(self) -> None:
os.fsync(self.file)
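The Hf3fsClient change moves shm_r.unlink() / shm_w.unlink() from close() into __init__, right after the iovecs are created. Unlinking a POSIX shared-memory segment removes its name while existing mappings stay valid until they are closed, so the segments cannot be leaked if the process dies before close() runs. A minimal sketch of that create-map-unlink pattern with the standard library (the size is arbitrary):

from multiprocessing import shared_memory

# Create a block, unlink its name immediately, keep using the mapping.
shm = shared_memory.SharedMemory(create=True, size=1 << 20)
shm.unlink()               # name removed: a crash can no longer leak the segment
shm.buf[:5] = b"hello"     # the mapping itself is still usable
print(bytes(shm.buf[:5]))  # b'hello'
shm.close()                # release the mapping when done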
......@@ -203,6 +203,7 @@ class ServerArgs:
hicache_io_backend: str = "kernel"
hicache_mem_layout: str = "layer_first"
hicache_storage_backend: Optional[str] = None
hicache_storage_prefetch_policy: str = "best_effort"
# Double Sparsity
enable_double_sparsity: bool = False
......@@ -1626,6 +1627,13 @@ class ServerArgs:
default=ServerArgs.hicache_storage_backend,
help="The storage backend for hierarchical KV cache.",
)
parser.add_argument(
"--hicache-storage-prefetch-policy",
type=str,
choices=["best_effort", "wait_complete", "timeout"],
default=ServerArgs.hicache_storage_prefetch_policy,
help="Control when prefetching from the storage backend should stop.",
)
# Double Sparsity
parser.add_argument(
......
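The flag slots into the existing hierarchical-cache options. A hedged launch example (the model path, backend choice, and the other flags are placeholders; only --hicache-storage-prefetch-policy and its three choices come from this change):

python -m sglang.launch_server \
    --model-path meta-llama/Llama-3.1-8B-Instruct \
    --enable-hierarchical-cache \
    --hicache-storage-backend file \
    --hicache-storage-prefetch-policy timeout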