Unverified Commit 6c7a152c authored by Zhiqiang Xie, committed by GitHub

Hierarchical Caching for SGLang (#2693)


Co-authored-by: Wenxuan Tan <wenxuan.tan@wisc.edu>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 4d2a88bd
## Run synthetic multi-turn benchmark
```
# SGLang server with radix cache disabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --disable-radix-cache
# SGLang server with radix cache on and first-come-first-serve policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --schedule-policy fcfs
# The default SGLang server with radix cache on and long-prefix-match policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000
# SGLang server with hierarchical radix cache enabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --enable-hierarchical-cache
```
```
python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
```
Note: The performance gain from hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens are reused, the larger the model, and the tighter the GPU memory budget, the greater the benefit you can expect from hierarchical caching.
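The benchmark client in this commit also gains `--max-parallel` (a cap on in-flight requests) and `--log-file` options; every run appends a timestamped JSON object with a `summary` of TTFT, latency, and throughput to the log. A minimal sketch for comparing runs from that log, assuming the default file name `performance_metrics.jsonl`:
```
import json

# Each line of the JSONL log is one benchmark run written by log_to_jsonl_file().
with open("performance_metrics.jsonl") as f:
    runs = [json.loads(line) for line in f if line.strip()]

for run in runs:
    s = run["summary"]
    print(
        f'{run["timestamp"]}  rate={s["request_rate"]}  '
        f'avg_ttft={s["average_ttft"]:.2f}  p90_ttft={s["p90_ttft"]:.2f}  '
        f'throughput={s["throughput"]:.2f} req/s'
    )
```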
## More benchmarks to be added
bench_multiturn.py:
@@ -5,6 +5,7 @@ import queue
 import random
 import threading
 import time
+from datetime import datetime
 from typing import Optional

 import aiohttp
@@ -26,9 +27,15 @@ def parse_args():
     parser.add_argument(
         "--num-clients",
         type=int,
-        default=200,
+        default=256,
         help="Number of concurrent clients",
     )
+    parser.add_argument(
+        "--max-parallel",
+        type=int,
+        default=128,
+        help="Maximum number of parallel requests",
+    )
     parser.add_argument(
         "--request-length",
         type=int,
@@ -73,11 +80,17 @@ def parse_args():
         help="Server port (default: 30000)",
     )
     parser.add_argument(
-        "--model",
+        "--model-path",
         type=str,
         default="meta-llama/Llama-3.1-8B-Instruct",
         help="model path compatible with Hugging Face Transformers",
     )
+    parser.add_argument(
+        "--log-file",
+        type=str,
+        default="performance_metrics.jsonl",
+        help="File to log performance metrics",
+    )

     return parser.parse_args()
@@ -158,6 +171,18 @@ def gen_payload(prompt, output_len):
     return payload


+def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
+    """Append the data with a timestamp to the specified JSONL file."""
+    timestamped_data = {"timestamp": datetime.now().isoformat(), **data}
+    try:
+        with open(file_path, "a") as file:
+            file.write(
+                json.dumps(timestamped_data) + "\n"
+            )  # Write as a single line in JSONL format
+    except IOError as e:
+        print(f"Error writing to JSONL file: {e}")
+
+
 class ReadyQueue:
     """
     Thread-safe queue that can pop requests in different orders based on given policy.
@@ -191,12 +216,15 @@ class WorkloadGenerator:
         # Construct the base URL for requests
         self.url = f"http://{args.host}:{args.port}/generate"

-        self.tokenizer = get_tokenizer(args.model)
+        self.tokenizer = get_tokenizer(args.model_path)
         self.distribution = args.distribution
         self.request_rate = args.request_rate
         self.start_time = None
         self.finished_time = None

+        self.sent_requests = 0
+        self.completed_requests = 0
+
         self.candidate_inputs = sample_random_requests(
             input_len=args.request_length,
             output_len=args.output_length,
@@ -235,6 +263,18 @@ class WorkloadGenerator:
     def request_sender(self):
         async def request_loop():
             while True:
+                if self.sent_requests - self.completed_requests < args.max_parallel:
+                    new_request = self.ready_queue.pop()
+                    if new_request:
+                        asyncio.create_task(self.handle_request(new_request))
+                        self.sent_requests += 1
+                else:
+                    await asyncio.sleep(0.05)
+                    continue
+
+                if self.pbar.n == self.pbar.total:
+                    break
+
                 # Calculate Poisson-distributed wait time
                 if self.distribution == "poisson":
                     sleep_time = random.expovariate(self.request_rate)
@@ -247,14 +287,6 @@ class WorkloadGenerator:
                     raise ValueError("Invalid distribution type")
                 await asyncio.sleep(sleep_time)  # Wait before sending the next request

-                new_request = self.ready_queue.pop()
-                # Submit async request
-                if new_request:
-                    asyncio.create_task(self.handle_request(new_request))
-                else:
-                    if self.pbar.n == self.pbar.total:
-                        break
-
         # Create and run the event loop for asynchronous requests
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
@@ -273,6 +305,7 @@ class WorkloadGenerator:
                 self.client_records[client_id]["round"] += 1
                 self.performance_metrics["ttft"].append(response.ttft)
                 self.performance_metrics["latency"].append(response.latency)
+                self.completed_requests += 1

                 if self.client_records[client_id]["round"] < args.num_rounds:
                     self.client_records[client_id][
@@ -301,34 +334,56 @@ class WorkloadGenerator:
         request_thread.join()
         response_thread.join()

         self.pbar.close()
-        print("All requests completed.")
+
+        performance_data = {
+            "summary": {
+                "total_requests": len(self.performance_metrics["ttft"]),
+                "request_rate": self.request_rate,
+                "average_ttft": sum(self.performance_metrics["ttft"])
+                / len(self.performance_metrics["ttft"]),
+                "p90_ttft": sorted(self.performance_metrics["ttft"])[
+                    int(0.9 * len(self.performance_metrics["ttft"]))
+                ],
+                "median_ttft": sorted(self.performance_metrics["ttft"])[
+                    len(self.performance_metrics["ttft"]) // 2
+                ],
+                "average_latency": sum(self.performance_metrics["latency"])
+                / len(self.performance_metrics["latency"]),
+                "p90_latency": sorted(self.performance_metrics["latency"])[
+                    int(0.9 * len(self.performance_metrics["latency"]))
+                ],
+                "median_latency": sorted(self.performance_metrics["latency"])[
+                    len(self.performance_metrics["latency"]) // 2
+                ],
+                "throughput": self.pbar.total / (self.finished_time - self.start_time),
+            },
+        }
+
+        print("All requests completed")
         print("Performance metrics summary:")
         print(
-            f" Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second"
+            f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
         )
-        print(
-            f" Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f}"
-        )
-        print(
-            f" Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}"
-        )
+        print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
+        print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
+        print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
         print(
-            f" Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}"
+            f" Average latency: {performance_data['summary']['average_latency']:.2f}"
         )
-        print(
-            f" Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}"
-        )
-        throughput = self.pbar.total / (self.finished_time - self.start_time)
-        print(f"Throughput: {throughput:.2f} requests per second")
+        print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
+        print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
+        print(
+            f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
        )
+
+        log_to_jsonl_file(performance_data, args.log_file)


 if __name__ == "__main__":
     args = parse_args()

     flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

-    for request_rate in range(1, 41, 2):
+    for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
         args.request_rate = request_rate
         requests.post(flush_cache_url)
+        time.sleep(1)
         WorkloadGenerator(args).run()
sglang/srt/managers/cache_controller.py:
@@ -5,9 +5,7 @@ Copyright 2023-2025 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
     http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -15,10 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import concurrent.futures
 import logging
+import math
 import threading
-from queue import PriorityQueue, Queue
-from typing import Optional
+from queue import Empty, Full, PriorityQueue, Queue
+from typing import List, Optional

 import torch
@@ -55,6 +55,27 @@ class CacheOperation:
         self.priority = min(self.priority, other.priority)
         self.node_ids.extend(other.node_ids)

+    def split(self, factor) -> List["CacheOperation"]:
+        # split an operation into smaller operations to reduce the size of intermediate buffers
+        if factor <= 1:
+            return [self]
+
+        chunk_size = math.ceil(len(self.host_indices) / factor)
+        split_ops = []
+        for i in range(0, len(self.host_indices), chunk_size):
+            split_ops.append(
+                CacheOperation(
+                    host_indices=self.host_indices[i : i + chunk_size],
+                    device_indices=self.device_indices[i : i + chunk_size],
+                    node_id=0,
+                )
+            )
+        # Inherit the node_ids on the final chunk
+        if split_ops:
+            split_ops[-1].node_ids = self.node_ids
+
+        return split_ops
+
     def __lt__(self, other: "CacheOperation"):
         return self.priority < other.priority
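For intuition on the chunking above: `split()` cuts one operation into `factor` roughly equal slices, and only the final slice inherits `node_ids`, so a node is acknowledged only once its last chunk has landed (the transfer threads later in the diff skip the placeholder id 0). A standalone sketch of the same slicing arithmetic with made-up sizes:
```
import math

import torch

host_indices = torch.arange(10)  # pretend one operation covers 10 tokens
factor = 3
chunk_size = math.ceil(len(host_indices) / factor)  # ceil(10 / 3) = 4
chunks = [
    host_indices[i : i + chunk_size] for i in range(0, len(host_indices), chunk_size)
]
print([len(c) for c in chunks])  # [4, 4, 2]; node_ids would ride on the final chunk
```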
@@ -64,7 +85,10 @@ class TransferBuffer:
     Overlapping buffer preparation and transfer operations to improve throughput.
     """

-    def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
+    def __init__(
+        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
+    ) -> None:
+        self.stop_event = stop_event
         self.buffers = Queue(maxsize=buffer_count)
         # todo: adjust the buffer size based on throughput profile of the system
         self.max_buffer_size = max_buffer_size
@@ -75,15 +99,29 @@ class TransferBuffer:
     def empty(self) -> bool:
         return self.buffers.empty()

-    def put(self, item, block=True) -> None:
-        self.buffers.put(item, block=block)
+    def put(self, item, block=True, timeout=1) -> None:
+        while not self.stop_event.is_set():
+            try:
+                self.buffers.put(item, block=block, timeout=timeout)
+                break
+            except Full:
+                if not block:
+                    break
+                continue
+            except Exception as e:
+                logger.error(e)

-    def get(self, block=True) -> Optional[CacheOperation]:
+    def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
         try:
-            return self.buffers.get(block=block)
+            return self.buffers.get(block=block, timeout=timeout)
+        except Empty:
+            return None
         except Exception as e:
             logger.error(e)

+    def clear(self):
+        self.buffers.queue.clear()
+

 class HiCacheController:
@@ -111,8 +149,11 @@ class HiCacheController:
         self.ack_write_queue = Queue()
         self.ack_load_queue = Queue()

-        self.write_buffer = TransferBuffer()
-        self.load_buffer = TransferBuffer()
+        self.stop_event = threading.Event()
+        self.write_buffer = TransferBuffer(self.stop_event)
+        self.load_buffer = TransferBuffer(
+            self.stop_event, buffer_count=10, max_buffer_size=100
+        )

         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()
@@ -126,6 +167,28 @@ class HiCacheController:
         self.write_thread.start()
         self.load_thread.start()

+    def reset(self):
+        self.stop_event.set()
+        self.write_thread.join()
+        self.load_thread.join()
+
+        self.write_queue.queue.clear()
+        self.load_queue.queue.clear()
+        self.write_buffer.clear()
+        self.load_buffer.clear()
+        self.ack_write_queue.queue.clear()
+        self.ack_load_queue.queue.clear()
+
+        self.write_thread = threading.Thread(
+            target=self.write_thread_func_buffer, daemon=True
+        )
+        self.load_thread = threading.Thread(
+            target=self.load_thread_func_buffer, daemon=True
+        )
+        self.stop_event.clear()
+        self.write_thread.start()
+        self.load_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -138,10 +201,10 @@ class HiCacheController:
         host_indices = self.mem_pool_host.alloc(len(device_indices))
         if host_indices is None:
             return None
+        self.mem_pool_host.protect_write(host_indices)
         self.write_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
-        self.mem_pool_host.protect_write(host_indices)
         return host_indices

     def load(
@@ -156,10 +219,10 @@ class HiCacheController:
         device_indices = self.mem_pool_device.alloc(len(host_indices))
         if device_indices is None:
             return None
+        self.mem_pool_host.protect_load(host_indices)
         self.load_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
-        self.mem_pool_host.protect_load(host_indices)
         return device_indices

     def write_thread_func_direct(self):
@@ -167,16 +230,19 @@ class HiCacheController:
         Directly write through KV caches to host memory without buffering.
         """
         with torch.cuda.stream(self.write_stream):
-            while True:
+            while not self.stop_event.is_set():
                 try:
-                    operation = self.write_queue.get(block=True)
+                    operation = self.write_queue.get(block=True, timeout=1)
                     operation.data = self.mem_pool_device.get_flat_data(
                         operation.device_indices
                     )
                     self.mem_pool_host.transfer(operation.host_indices, operation.data)
                     self.mem_pool_host.complete_io(operation.host_indices)
                     for node_id in operation.node_ids:
-                        self.ack_write_queue.put(node_id)
+                        if node_id != 0:
+                            self.ack_write_queue.put(node_id)
+                except Empty:
+                    continue
                 except Exception as e:
                     logger.error(e)

@@ -185,9 +251,10 @@ class HiCacheController:
         Directly load KV caches from host memory to device memory without buffering.
         """
         with torch.cuda.stream(self.load_stream):
-            while True:
+            while not self.stop_event.is_set():
                 try:
-                    operation = self.load_queue.get(block=True)
+                    operation = self.load_queue.get(block=True, timeout=1)
+                    # time.sleep(18e-6 * len(operation.host_indices))
                     operation.data = self.mem_pool_host.get_flat_data(
                         operation.host_indices
                     )
@@ -196,7 +263,10 @@ class HiCacheController:
                     )
                     self.mem_pool_host.complete_io(operation.host_indices)
                     for node_id in operation.node_ids:
-                        self.ack_load_queue.put(node_id)
+                        if node_id != 0:
+                            self.ack_load_queue.put(node_id)
+                except Empty:
+                    continue
                 except Exception as e:
                     logger.error(e)

@@ -204,39 +274,98 @@
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+
+        def _to_op(op_):
+            assert op_.device_indices.is_cuda, "Device indices should be on GPU"
+            op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
+                self.mem_pool_host.device
+            )
+            self.write_buffer.put(op_)
+            return op_
+
         buffer = None
-        while True:
-            try:
-                operation = self.write_queue.get(block=True)
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    no_wait
-                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                    or self.write_queue.empty()
-                    or self.write_buffer.empty()
-                ):
-                    assert (
-                        buffer.device_indices.is_cuda
-                    ), "Device indices should be on GPU"
-                    buffer.data = self.mem_pool_device.get_flat_data(
-                        buffer.device_indices
-                    ).contiguous()
-                    self.write_buffer.put(buffer, block=True)
-                    buffer = None
-            except Exception as e:
-                logger.error(e)
+        with torch.cuda.stream(self.write_stream):
+            while not self.stop_event.is_set():
+                try:
+                    operation = self.write_queue.get(block=True, timeout=1)
+                    factor = (
+                        len(operation.device_indices)
+                        // self.write_buffer.max_buffer_size
+                    )
+
+                    if factor >= 1:
+                        if buffer is not None:
+                            _to_op(buffer)
+                            buffer = None
+
+                        if factor < 2:
+                            _to_op(operation)
+                        else:
+                            split_ops = operation.split(factor)
+                            for op_ in split_ops:
+                                _to_op(op_)
+                        continue
+
+                    if buffer is None:
+                        buffer = operation
+                    else:
+                        buffer.merge(operation)
+                    if (
+                        no_wait
+                        or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                        or self.write_queue.empty()
+                        or self.write_buffer.empty()
+                    ):
+                        _to_op(buffer)
+                        buffer = None
+                except Empty:
+                    continue
+                except Exception as e:
+                    logger.error(e)

     def load_aux_func(self):
         """
         Auxiliary function to prepare the buffer for load operations.
         """
+
+        def _pin_op(op_, put=True):
+            op_.data = (
+                self.mem_pool_host.get_flat_data(op_.host_indices)
+                .contiguous()
+                .pin_memory()
+            )
+            if put:
+                self.load_buffer.put(op_)
+            return op_
+
         buffer = None
-        while True:
+        while not self.stop_event.is_set():
             try:
-                operation = self.load_queue.get(block=True)
+                operation = self.load_queue.get(block=True, timeout=1)
+                factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
+
+                if factor >= 1:
+                    if buffer is not None:
+                        _pin_op(buffer)
+                        buffer = None
+
+                    if factor < 2:
+                        _pin_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        split_args = [(op_, True) for op_ in split_ops[:-1]]
+                        split_args.append((split_ops[-1], False))
+                        # Spawn threads to pin each op concurrently
+                        with concurrent.futures.ThreadPoolExecutor() as executor:
+                            pinned_ops = list(
+                                executor.map(
+                                    lambda x: _pin_op(x[0], put=x[1]), split_args
+                                )
+                            )
+                        # preserve the order of last op to ensure correct ack
+                        self.load_buffer.put(pinned_ops[-1])
+                    continue
+
                 if buffer is None:
                     buffer = operation
                 else:
@@ -246,41 +375,43 @@ class HiCacheController:
                     or self.load_queue.empty()
                     or self.load_buffer.empty()
                 ):
-                    buffer.data = (
-                        self.mem_pool_host.get_flat_data(buffer.host_indices)
-                        .contiguous()
-                        .pin_memory()
-                    )
-                    self.load_buffer.put(buffer, block=True)
+                    _pin_op(buffer)
                     buffer = None
+            except Empty:
+                continue
             except Exception as e:
                 logger.error(e)

     def write_thread_func_buffer(self):
         aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
         aux_thread.start()
-        with torch.cuda.stream(self.write_stream):
-            while True:
-                operation = self.write_buffer.get()
-                if operation is None:
-                    continue
-                self.mem_pool_host.transfer(operation.host_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    self.ack_write_queue.put(node_id)
+
+        while not self.stop_event.is_set():
+            operation = self.write_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_write_queue.put(node_id)
+        aux_thread.join()

     def load_thread_func_buffer(self):
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
         with torch.cuda.stream(self.load_stream):
-            while True:
+            while not self.stop_event.is_set():
                 operation = self.load_buffer.get()
                 if operation is None:
                     continue
                 self.mem_pool_device.transfer(operation.device_indices, operation.data)
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
-                    self.ack_load_queue.put(node_id)
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+        aux_thread.join()

     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
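The buffered paths above all follow the same shape: a producer enqueues operations, a daemon thread drains the queue with a one-second timeout so that `stop_event` can end it cleanly, and completion is signalled by pushing node ids onto an ack queue. A stripped-down, runnable analogue of that pattern (plain `queue`/`threading`, not the SGLang classes):
```
import queue
import threading

stop_event = threading.Event()
work_queue = queue.Queue()
ack_queue = queue.Queue()


def worker():
    while not stop_event.is_set():
        try:
            node_id = work_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        # ... the real threads copy KV data between device and host here ...
        ack_queue.put(node_id)


thread = threading.Thread(target=worker, daemon=True)
thread.start()
for node_id in (1, 2, 3):
    work_queue.put(node_id)
acked = [ack_queue.get() for _ in range(3)]  # blocks until all work is acknowledged
stop_event.set()
thread.join()
print("acknowledged:", acked)
```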
sglang/srt/managers/scheduler.py:
@@ -82,6 +82,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
@@ -300,16 +301,24 @@ class Scheduler:
                 token_to_kv_pool=self.token_to_kv_pool,
             )
         else:
-            self.tree_cache = RadixCache(
-                req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool=self.token_to_kv_pool,
-                disable=server_args.disable_radix_cache,
+            self.tree_cache = (
+                HiRadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool=self.token_to_kv_pool,
+                )
+                if self.enable_hierarchical_cache
+                else RadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool=self.token_to_kv_pool,
+                    disable=server_args.disable_radix_cache,
+                )
             )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

         # Init running status
         self.waiting_queue: List[Req] = []
+        self.staging_reqs = {}
         # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
         # The current forward batch
@@ -953,6 +962,30 @@ class Scheduler:
                 break

             req.init_next_round_input(None if prefix_computed else self.tree_cache)

+            if self.enable_hierarchical_cache and req.last_node is not None:
+                if req.last_node.evicted:
+                    # loading KV cache for the request
+                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
+                        req.last_node,
+                        req.prefix_indices,
+                        adder.rem_total_tokens,
+                    )
+                    if req.last_node.loading:
+                        # to prevent frequent cache invalidation
+                        if req.rid in self.staging_reqs:
+                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
+                        self.tree_cache.inc_lock_ref(req.last_node)
+                        self.staging_reqs[req.rid] = req.last_node
+                        continue
+                elif req.last_node.loading:
+                    if not self.tree_cache.loading_complete(req.last_node):
+                        continue
+
+                if req.rid in self.staging_reqs:
+                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
+                    del self.staging_reqs[req.rid]
+
             res = adder.add_one_req(req)
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
sglang/srt/mem_cache/hiradix_cache.py (new file):
import heapq
import logging
import time
from typing import List, Optional

import torch

from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
    BaseTokenToKVPool,
    MLATokenToKVPoolHost,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match

logger = logging.getLogger(__name__)

class HiRadixCache(RadixCache):

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
    ):
        self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
        self.cache_controller = HiCacheController(
            token_to_kv_pool, self.token_to_kv_pool_host
        )

        # record the nodes with ongoing write through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load back
        self.ongoing_load_back = {}
        # todo: dynamically adjust the threshold
        self.write_through_threshold = 1
        self.load_back_threshold = 10
        super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)

    def reset(self):
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

    def write_backup(self, node: TreeNode):
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            priority=-self.get_height(node),
            node_id=node.id,
        )
        if host_indices is None:
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                priority=-self.get_height(node),
                node_id=node.id,
            )
        if host_indices is not None:
            node.host_value = host_indices
            self.ongoing_write_through[node.id] = node
            self.inc_lock_ref(node)
        else:
            return None

        return len(host_indices)

    def inc_hit_count(self, node: TreeNode):
        if self.cache_controller.write_policy != "write_through_selective":
            return
        node.hit_count += 1
        if node.host_value is None and node.hit_count > self.write_through_threshold:
            self.write_backup(node)
            node.hit_count = 0

    def writing_check(self):
        while not self.cache_controller.ack_write_queue.empty():
            try:
                ack_id = self.cache_controller.ack_write_queue.get_nowait()
                self.dec_lock_ref(self.ongoing_write_through[ack_id])
                # clear the reference
                del self.ongoing_write_through[ack_id]
            except Exception:
                break

    def loading_check(self):
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                break

    def evictable_size(self):
        self.writing_check()
        self.loading_check()
        return self.evictable_size_

    def evict(self, num_tokens: int, evict_callback=None):
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
        pending_nodes = []
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x.lock_ref > 0:
                continue

            if x.host_value is None:
                if self.cache_controller.write_policy == "write_back":
                    num_evicted += self.write_backup(x)
                elif self.cache_controller.write_policy == "write_through_selective":
                    num_evicted += self._evict_write_through_selective(x)
                else:
                    assert (
                        self.cache_controller.write_policy != "write_through"
                    ), "write_through should be inclusive"
                    raise NotImplementedError
            else:
                num_evicted += self._evict_write_through(x)

            for child in x.parent.children.values():
                if child in pending_nodes:
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
            # blocking till all write back complete
            while len(self.ongoing_write_through) > 0:
                self.writing_check()
                time.sleep(0.1)

    def _evict_write_through(self, node: TreeNode):
        # evict a node already written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

    def _evict_write_through_selective(self, node: TreeNode):
        # evict a node not initiated write to host
        self.cache_controller.mem_pool_device.free(node.value)
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue

            assert x.lock_ref == 0 and x.host_value is not None
            assert self.cache_controller.evict_host(x.host_value) > 0

            for k, v in x.parent.children.items():
                if v == x:
                    break
            del x.parent.children[k]

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        else:
            ancester_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancester_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        if len(host_indices) < self.load_back_threshold or (
            len(host_indices) > mem_quota + delta if mem_quota is not None else False
        ):
            # skip loading back if the total size is too small or exceeding the memory quota
            self.dec_lock_ref(ancester_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancester_node)
        if device_indices is None:
            # no sufficient GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
        offset = 0
        for node in nodes_to_load:
            node.value = device_indices[offset : offset + len(node.host_value)]
            offset += len(node.host_value)
            node.loading = True
        self.evictable_size_ += len(device_indices)
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def loading_complete(self, node: TreeNode):
        self.loading_check()
        return node.loading == False

    def init_load_back(
        self,
        last_node: TreeNode,
        prefix_indices: torch.Tensor,
        mem_quota: Optional[int] = None,
    ):
        assert (
            len(prefix_indices) == 0 or prefix_indices.is_cuda
        ), "indices of device KV caches should be on GPU"
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                prefix_indices = (
                    loading_values
                    if len(prefix_indices) == 0
                    else torch.cat([prefix_indices, loading_values])
                )
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )

            while last_node.evicted:
                last_node = last_node.parent

        return last_node, prefix_indices

    def _match_prefix_helper(
        self, node: TreeNode, key: List, value, last_node: TreeNode
    ):
        node.last_access_time = time.time()
        if len(key) == 0:
            return

        if key[0] in node.children.keys():
            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                self.inc_hit_count(new_node)
                if not new_node.evicted:
                    value.append(new_node.value)
                last_node[0] = new_node
            else:
                self.inc_hit_count(child)
                if not child.evicted:
                    value.append(child.value)
                last_node[0] = child
                self._match_prefix_helper(child, key[prefix_len:], value, last_node)

    def _split_node(self, key, child: TreeNode, split_len: int):
        # child node split into new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len]: child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.loading = child.loading

        # split value and host value if exists
        if child.evicted:
            new_node.value = None
        else:
            new_node.value = child.value[:split_len]
            child.value = child.value[split_len:]
        if child.host_value is not None:
            new_node.host_value = child.host_value[:split_len]
            child.host_value = child.host_value[split_len:]
        child.parent = new_node
        child.key = child.key[split_len:]
        new_node.parent.children[key[0]] = new_node

        return new_node

    def _insert_helper(self, node: TreeNode, key: List, value):
        node.last_access_time = time.time()
        if len(key) == 0:
            return 0

        if key[0] in node.children.keys():
            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)

            if prefix_len == len(child.key):
                if child.evicted:
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
                    child.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(child.host_value)
                    self.evictable_size_ += len(value[:prefix_len])
                    return self._insert_helper(
                        child, key[prefix_len:], value[prefix_len:]
                    )
                else:
                    self.inc_hit_count(child)
                    return prefix_len + self._insert_helper(
                        child, key[prefix_len:], value[prefix_len:]
                    )

            # partial match, split the node
            new_node = self._split_node(child.key, child, prefix_len)
            if new_node.evicted:
                new_node.value = value[:prefix_len]
                self.token_to_kv_pool_host.update_synced(new_node.host_value)
                self.evictable_size_ += len(new_node.value)
                return self._insert_helper(
                    new_node, key[prefix_len:], value[prefix_len:]
                )
            else:
                self.inc_hit_count(new_node)
                return prefix_len + self._insert_helper(
                    new_node, key[prefix_len:], value[prefix_len:]
                )

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[key[0]] = new_node
            self.evictable_size_ += len(value)
            if self.cache_controller.write_policy == "write_through":
                self.write_backup(new_node)
        return 0

    def _collect_leaves_device(self):
        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list
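A quick recap of the three write policies referenced above, as implemented by the methods of `HiRadixCache` (the dict below is illustrative documentation, not a configuration object used by the code):
```
WRITE_POLICIES = {
    # _insert_helper() backs up every newly inserted node right away (inclusive)
    "write_through": "back up each new node to host memory immediately",
    # inc_hit_count() backs a node up once hit_count exceeds write_through_threshold
    "write_through_selective": "back up a node only after it has been reused enough times",
    # evict() calls write_backup() at eviction time and blocks until the writes complete
    "write_back": "keep data on device only and write to host when the node is evicted",
}
print("\n".join(f"{name}: {behavior}" for name, behavior in WRITE_POLICIES.items()))
```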
sglang/srt/mem_cache/memory_pool.py:
@@ -442,7 +442,7 @@ class MLATokenToKVPoolHost:
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float = 2.0,
+        host_to_device_ratio: float = 4.0,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
@@ -502,6 +502,9 @@ class MLATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]

+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
     @debug_timing
     def transfer(self, indices, flat_data):
         # backup prepared data from device to host
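The host pool is sized relative to the device pool via `host_to_device_ratio`, which this change raises from 2.0 to 4.0. A rough back-of-the-envelope sketch of what that implies for memory budgeting (illustrative numbers; exact accounting depends on `MLATokenToKVPoolHost` internals not shown in this diff):
```
host_to_device_ratio = 4.0  # new default above
device_kv_gib = 20          # hypothetical device KV-cache size in GiB
host_kv_gib = host_to_device_ratio * device_kv_gib
print(f"~{host_kv_gib:.0f} GiB of CPU memory reserved for KV backups")
```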
@@ -1289,7 +1289,7 @@ def debug_timing(func):
         tic.record()
         result = func(*args, **kwargs)
         toc.record()
-        torch.cuda.synchronize()  # Ensure all CUDA operations are complete
+        toc.synchronize()  # Wait for the function to complete without synchronizing all ops on the GPU
         elapsed = tic.elapsed_time(toc)
         indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
         num_tokens = len(indices) if indices is not None else 0
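For context on the one-line change above: with CUDA events, only the recorded end event needs to be synchronized before reading the elapsed time, which avoids stalling unrelated work on the device. A standalone illustration of the pattern (requires a CUDA-capable GPU):
```
import torch

x = torch.randn(4096, 4096, device="cuda")
tic = torch.cuda.Event(enable_timing=True)
toc = torch.cuda.Event(enable_timing=True)

tic.record()
y = x @ x  # the work being measured, launched on the current stream
toc.record()
toc.synchronize()  # wait only until `toc` has been reached, not for the whole device

print(f"matmul took {tic.elapsed_time(toc):.2f} ms")
```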