Unverified commit e2fd2b9c, authored by pansicheng and committed by GitHub

Simple prefetch policy (#8692)

parent 7490e3f6
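
This change set adds a configurable stop policy for HiCache storage prefetching (a new --hicache-storage-prefetch-policy flag with choices best_effort, wait_complete, and timeout), a capacity-based rate limit on in-flight prefetch tokens, hf3fs-specific batched page transfer and backup, and cache hit rate reporting in the benchmark client. As a reading aid only, not part of the diff, the sketch below restates the termination rule implemented by the new HiRadixCache.can_terminate_prefetch method, simplified to a single process; the real method additionally all-reduces the decision across tensor-parallel workers, and the parameter names here are illustrative rather than the actual signature.

import time

# Simplified restatement of the prefetch stop policies added in this commit.
def can_terminate_prefetch(policy: str, completed_tokens: int,
                           expected_tokens: int, start_time: float,
                           prefetch_timeout: float = 3.0) -> bool:
    if policy == "best_effort":
        # terminate as soon as the scheduler polls, keep whatever was fetched
        return True
    completed = completed_tokens == expected_tokens
    if policy == "wait_complete":
        # only terminate once every hashed page has been fetched
        return completed
    if policy == "timeout":
        # terminate when complete, or once the operation has run too long
        return completed or (time.monotonic() - start_time > prefetch_timeout)
    # unknown policy: fall back to terminating immediately
    return True
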
@@ -20,6 +20,8 @@ from sglang.bench_serving import (
     sample_random_requests,
 )
 
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
@@ -139,7 +141,7 @@ async def async_request_sglang_generate(
     """
     Sends a streaming request to the server. Gathers text token-by-token.
     """
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         headers = {}
         generated_text = ""
         ttft = 0.0
@@ -150,6 +152,8 @@ async def async_request_sglang_generate(
        try:
            async with session.post(url=url, json=payload, headers=headers) as response:
                if response.status == 200:
+                    prompt_tokens = 0
+                    cached_tokens = 0
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
@@ -168,6 +172,12 @@ async def async_request_sglang_generate(
                            if ttft == 0.0:
                                ttft = time.perf_counter() - st
                                output.ttft = ttft
+                            prompt_tokens = (data.get("meta_info") or {}).get(
+                                "prompt_tokens", 0
+                            )
+                            cached_tokens = (data.get("meta_info") or {}).get(
+                                "cached_tokens", 0
+                            )
 
                        # Decoding phase
                        else:
@@ -179,6 +189,8 @@ async def async_request_sglang_generate(
                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
+                    output.prompt_len = prompt_tokens
+                    output.cached_tokens = cached_tokens
                else:
                    output.error = response.reason or ""
                    output.success = False
@@ -201,6 +213,7 @@ def gen_payload(prompt, output_len):
            "ignore_eos": True,
        },
        "stream": True,
+        "stream_options": {"include_usage": True},
        "lora_path": "",
        "return_logprob": False,
        "logprob_start_len": -1,
@@ -303,7 +316,12 @@ class WorkloadGenerator:
        self.response_queue = queue.Queue()
        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
-        self.performance_metrics = {"ttft": [], "latency": []}
+        self.performance_metrics = {
+            "ttft": [],
+            "latency": [],
+            "prompt_len": [],
+            "cached_tokens": [],
+        }
 
    async def handle_request(self, item):
        try:
@@ -360,6 +378,8 @@ class WorkloadGenerator:
            self.client_records[client_id]["round"] += 1
            self.performance_metrics["ttft"].append(response.ttft)
            self.performance_metrics["latency"].append(response.latency)
+            self.performance_metrics["prompt_len"].append(response.prompt_len)
+            self.performance_metrics["cached_tokens"].append(response.cached_tokens)
            self.completed_requests += 1
 
            if self.client_records[client_id]["round"] < args.num_rounds:
@@ -416,6 +436,12 @@ class WorkloadGenerator:
                    len(self.performance_metrics["latency"]) // 2
                ],
                "throughput": self.pbar.total / (self.finished_time - self.start_time),
+                "cache_hit_rate": (
+                    0
+                    if sum(self.performance_metrics["prompt_len"]) == 0
+                    else sum(self.performance_metrics["cached_tokens"])
+                    / sum(self.performance_metrics["prompt_len"])
+                ),
            },
        }
 
        print("All requests completed")
@@ -434,6 +460,7 @@ class WorkloadGenerator:
        print(
            f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
        )
+        print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
        log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
...
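
The cache_hit_rate summary added above reduces the per-request prompt_tokens and cached_tokens reported in meta_info to a single ratio. The helper below is a standalone restatement of that computation, not code from this diff, and the numbers in the usage example are made up.

from typing import Dict, List

def cache_hit_rate(metrics: Dict[str, List[int]]) -> float:
    # fraction of prompt tokens served from cache, guarding against empty runs
    total_prompt = sum(metrics["prompt_len"])
    if total_prompt == 0:
        return 0.0
    return sum(metrics["cached_tokens"]) / total_prompt

# Hypothetical example: three requests of 1024 prompt tokens with 512, 768,
# and 0 cached tokens respectively give a hit rate of about 0.4167.
print(cache_hit_rate({"prompt_len": [1024, 1024, 1024],
                      "cached_tokens": [512, 768, 0]}))
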
@@ -16,6 +16,7 @@ limitations under the License.
 import logging
 import math
 import threading
+import time
 from queue import Empty, Full, PriorityQueue, Queue
 from typing import TYPE_CHECKING, List, Optional
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
        self._done_flag = False
        self._lock = threading.Lock()
 
+        self.start_time = time.monotonic()
+
        super().__init__(host_indices, token_ids, last_hash)
 
    def increment(self, num_tokens: int):
@@ -278,6 +281,12 @@ class HiCacheController:
            self.enable_storage = True
            # todo: threshold policy for prefetching
            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            self.prefetch_capacity_limit = int(
+                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+            )
+            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
+            self.prefetch_tokens_occupied = 0
+
        # create a new communication group for synchronizing storage operations across TP workers
        self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
        if self.tp_world_size > 1:
@@ -525,7 +534,7 @@ class HiCacheController:
        host_indices: torch.Tensor,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
-    ) -> int:
+    ) -> PrefetchOperation:
        """
        Prefetch KV caches from storage backend to host memory.
        """
@@ -586,11 +595,23 @@ class HiCacheController:
                operation = self.prefetch_buffer.get(block=True, timeout=1)
                if self.is_mooncake_backend():
                    self.mooncake_page_transfer(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_transfer(operation, batch_size=128)
                else:
                    self.generic_page_transfer(operation)
            except Empty:
                continue
 
+    def prefetch_rate_limit_check(self) -> bool:
+        """
+        Rate limit the prefetching operations to avoid overwhelming the storage backend.
+        """
+        # cancel prefetch if too much memory is occupied
+        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+            return False
+        # todo: more sophisticated rate limiting based on storage backend performance
+        return True
+
    def prefetch_thread_func(self):
        """
        Manage prefetching operations from storage backend to host memory.
@@ -604,34 +625,36 @@ class HiCacheController:
            if operation is None:
                continue
 
-            last_hash = operation.last_hash
-            tokens_to_fetch = operation.token_ids
-
            storage_hit_count = 0
-            remaining_tokens = len(tokens_to_fetch)
-            hash_value = []
-            while remaining_tokens >= self.page_size:
-                last_hash = self.get_hash_str(
-                    tokens_to_fetch[
-                        storage_hit_count : storage_hit_count + self.page_size
-                    ],
-                    last_hash,
-                )
-
-                # todo, more unified interface
-                if not self.is_mooncake_backend():
-                    if not self.storage_backend.exists(last_hash):
-                        break
-                hash_value.append(last_hash)
-                storage_hit_count += self.page_size
-                remaining_tokens -= self.page_size
-
-            if self.is_mooncake_backend():
-                # deferring to batch exists for mooncake store
-                exist_result = self.storage_backend.exists(hash_value)
-                storage_hit_count = (
-                    sum(1 for v in exist_result.values() if v != 0) * self.page_size
-                )
+            if self.prefetch_rate_limit_check():
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = self.get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+
+                    # todo, more unified interface
+                    if not self.is_mooncake_backend():
+                        if not self.storage_backend.exists(last_hash):
+                            break
+                    hash_value.append(last_hash)
+                    storage_hit_count += self.page_size
+                    remaining_tokens -= self.page_size
+
+                if self.is_mooncake_backend():
+                    # deferring to batch exists for mooncake store
+                    exist_result = self.storage_backend.exists(hash_value)
+                    storage_hit_count = (
+                        sum(1 for v in exist_result.values() if v != 0)
+                        * self.page_size
+                    )
 
            if self.tp_world_size > 1:
                storage_hit_count_tensor = torch.tensor(
@@ -750,6 +773,8 @@ class HiCacheController:
                if self.is_mooncake_backend():
                    self.mooncake_page_backup(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_backup(operation, batch_size=128)
                else:
                    self.generic_page_backup(operation)
...
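
The rate limit introduced above is a plain capacity check: a prefetch operation skips its storage lookup whenever the tokens already locked by in-flight prefetches reach 80% of the host-only KV pool headroom, which leaves its storage hit count at zero so it falls below the prefetch threshold and gets revoked. The sketch below restates that accounting with assumed names (host_pool_size, device_pool_size); in the actual code the counter lives on HiCacheController and is updated by HiRadixCache when prefetches start, are revoked, or terminate, as the later hunks show.

class PrefetchRateLimiterSketch:
    """Toy restatement of the capacity rule added to HiCacheController."""

    def __init__(self, host_pool_size: int, device_pool_size: int):
        # 80% of the host-only headroom, mirroring prefetch_capacity_limit
        self.capacity_limit = int(0.8 * (host_pool_size - device_pool_size))
        self.tokens_occupied = 0  # tokens locked by in-flight prefetches

    def rate_limit_check(self) -> bool:
        # refuse to start new storage lookups once the budget is exhausted
        return self.tokens_occupied < self.capacity_limit

    def on_prefetch_launched(self, num_tokens: int) -> None:
        self.tokens_occupied += num_tokens

    def on_prefetch_finished(self, num_tokens: int) -> None:
        self.tokens_occupied -= num_tokens
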
@@ -619,6 +619,7 @@ class Scheduler(
                ),
                hicache_mem_layout=server_args.hicache_mem_layout,
                hicache_storage_backend=server_args.hicache_storage_backend,
+                hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
            )
            self.tp_worker.register_hicache_layer_transfer_counter(
                self.tree_cache.cache_controller.layer_done_counter
@@ -1572,7 +1573,10 @@ class Scheduler(
                break
 
            if self.enable_hicache_storage:
-                self.tree_cache.check_prefetch_progress(req.rid)
+                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+                if not prefetch_done:
+                    # skip staging requests that are ongoing prefetch
+                    continue
 
            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
...
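
With the change above, check_prefetch_progress now returns a boolean and the scheduler treats it as a gate: under wait_complete or timeout, a request whose prefetch is not yet allowed to terminate is skipped in this scheduling pass and reconsidered later. The loop below is a simplified, hypothetical restatement of that control flow; waiting_queue and adder are placeholders for the scheduler's real attributes.

# Simplified staging loop illustrating the gate added above; requests whose
# prefetch cannot terminate yet are deferred, not dropped.
def stage_prefill_requests(scheduler, waiting_queue, adder):
    for req in list(waiting_queue):
        if scheduler.enable_hicache_storage:
            prefetch_done = scheduler.tree_cache.check_prefetch_progress(req.rid)
            if not prefetch_done:
                # still prefetching under the wait_complete or timeout policy
                continue
        req.init_next_round_input(scheduler.tree_cache)
        adder.add_one_req(req, has_chunked_req=(scheduler.chunked_req is not None))
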
@@ -2,11 +2,12 @@ import heapq
 import logging
 import threading
 import time
+from queue import Queue
 from typing import List, Optional
 
 import torch
 
-from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
        hicache_io_backend: str,
        hicache_mem_layout: str,
        hicache_storage_backend: Optional[str] = None,
+        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
    ):
 
        if hicache_io_backend == "direct":
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
                prefetch_threshold=self.prefetch_threshold,
            )
 
+            self.prefetch_stop_policy = hicache_storage_prefetch_policy
+            # todo: customizable storage prefetch timeout
+            self.prefetch_timeout = 3  # seconds
+            logger.info(
+                f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
+            )
+
        # record the nodes with ongoing write through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load back
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
        for _ in range(queue_size.item()):
            req_id = self.cache_controller.prefetch_revoke_queue.get()
            if req_id in self.ongoing_prefetch:
-                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
+                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                last_host_node.release_host()
                del self.ongoing_prefetch[req_id]
+                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
            else:
                # the revoked operation already got terminated
                pass
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
            host_node.release_host()
            del self.ongoing_backup[ack_id]
 
-    def check_prefetch_progress(self, req_id: str):
+    def can_terminate_prefetch(self, operation: PrefetchOperation):
+        can_terminate = True
+
+        if self.prefetch_stop_policy == "best_effort":
+            return can_terminate
+
+        completed = (
+            operation.completed_tokens == len(operation.hash_value) * self.page_size
+        )
+
+        if self.prefetch_stop_policy == "wait_complete":
+            can_terminate = completed
+        elif self.prefetch_stop_policy == "timeout":
+            can_terminate = completed or (
+                time.monotonic() - operation.start_time > self.prefetch_timeout
+            )
+        else:
+            # unknown prefetch stop policy, just return True
+            return True
+
+        if self.tp_world_size > 1:
+            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            torch.distributed.all_reduce(
+                can_terminate,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            can_terminate = bool(can_terminate.item())
+
+        return can_terminate
+
+    def check_prefetch_progress(self, req_id: str) -> bool:
        if req_id not in self.ongoing_prefetch:
            # there is no ongoing prefetch for this request or it has been revoked
-            return
+            return True
 
        # todo: more policies for prefetch progress such as timeout
        # the current policy is to prefetch with best effort and terminate when queuing is over
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
            req_id
        ]
 
+        if not self.can_terminate_prefetch(operation):
+            return False
+
        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
            operation
        )
        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
 
        min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
+        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
            # synchrnoize TP workers to make the same update to hiradix cache
            completed_tokens_tensor = torch.tensor(
                min_completed_tokens, dtype=torch.int
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
        )
        last_host_node.release_host()
        del self.ongoing_prefetch[req_id]
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+        return True
 
    def match_prefix(self, key: List[int], **kwargs):
        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
            host_indices,
            operation,
        )
+        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
 
    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
        node.last_access_time = time.monotonic()
...
@@ -96,6 +96,8 @@ class Hf3fsClient:
        )
        self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
        self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+        self.shm_r.unlink()
+        self.shm_w.unlink()
 
        self.rlock = threading.RLock()
        self.wlock = threading.RLock()
@@ -176,8 +178,6 @@ class Hf3fsClient:
        del self.iov_w
        self.shm_r.close()
        self.shm_w.close()
-        self.shm_r.unlink()
-        self.shm_w.unlink()
 
    def flush(self) -> None:
        os.fsync(self.file)
...
@@ -203,6 +203,7 @@ class ServerArgs:
    hicache_io_backend: str = "kernel"
    hicache_mem_layout: str = "layer_first"
    hicache_storage_backend: Optional[str] = None
+    hicache_storage_prefetch_policy: str = "best_effort"
 
    # Double Sparsity
    enable_double_sparsity: bool = False
@@ -1626,6 +1627,13 @@ class ServerArgs:
            default=ServerArgs.hicache_storage_backend,
            help="The storage backend for hierarchical KV cache.",
        )
+        parser.add_argument(
+            "--hicache-storage-prefetch-policy",
+            type=str,
+            choices=["best_effort", "wait_complete", "timeout"],
+            default=ServerArgs.hicache_storage_prefetch_policy,
+            help="Control when prefetching from the storage backend should stop.",
+        )
 
        # Double Sparsity
        parser.add_argument(
...