Unverified Commit e273aa6d authored by DarkSharpness's avatar DarkSharpness Committed by GitHub
Browse files

[Feature] Radix Tree in C++ (#7369)

parent 828a4fe9
...@@ -569,7 +569,23 @@ class Scheduler( ...@@ -569,7 +569,23 @@ class Scheduler(
page_size=self.page_size, page_size=self.page_size,
) )
else: else:
if self.enable_hierarchical_cache: if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
# lazy import to avoid JIT overhead
from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
self.tree_cache = RadixCacheCpp(
disable=False,
use_hicache=self.enable_hierarchical_cache,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_cpu_group,
page_size=self.page_size,
hicache_ratio=server_args.hicache_ratio,
hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy,
enable_kv_cache_events=self.enable_kv_cache_events,
)
elif self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache( self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
......
../../../../../sgl-kernel/.clang-format
\ No newline at end of file
#pragma once
#include <cstddef>
#include <cstdint>
#include <source_location>
#include <span>
#include <stdexcept>
#include <string>
#include <vector>
namespace radix_tree_v2 {
using token_t = std::int32_t;
using token_vec_t = std::vector<token_t>;
using token_slice = std::span<const token_t>;
using NodeHandle = std::size_t;
using IOTicket = std::uint32_t;
inline void _assert(
bool condition,
const char* message = "Assertion failed",
std::source_location loc = std::source_location::current()) {
if (!condition) [[unlikely]] {
std::string msg = message;
msg = msg + " at " + loc.file_name() + ":" + std::to_string(loc.line()) + " in " + loc.function_name();
throw std::runtime_error(msg);
}
}
} // namespace radix_tree_v2
from __future__ import annotations
import os
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from torch.utils.cpp_extension import load
_abs_path = os.path.dirname(os.path.abspath(__file__))
radix_tree_cpp = load(
name="radix_tree_cpp",
sources=[
f"{_abs_path}/tree_v2_binding.cpp",
f"{_abs_path}/tree_v2_debug.cpp",
f"{_abs_path}/tree_v2.cpp",
],
extra_cflags=["-O3", "-std=c++20"],
)
if TYPE_CHECKING:
class TreeNodeCpp:
"""
A placeholder for the TreeNode class. Cannot be constructed elsewhere.
"""
class IOHandle:
"""
A placeholder for the IOHandle class. Cannot be constructed elsewhere.
"""
class RadixTreeCpp:
def __init__(
self,
disabled: bool,
host_size: Optional[int],
page_size: int,
write_through_threshold: int,
):
"""
Initializes the RadixTreeCpp instance.
Args:
disabled (bool): If True, the radix tree is disabled.
host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
page_size (int): Size of the page for the radix tree.
write_through_threshold (int): Threshold for writing through from GPU to CPU.
"""
self.tree = radix_tree_cpp.RadixTree( # type: ignore
disabled, host_size, page_size, write_through_threshold
)
def match_prefix(
self, prefix: List[int]
) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
"""
Matches a prefix in the radix tree.
Args:
prefix (List[int]): The prefix to match.
Returns:
Tuple[List[torch.Tensor], TreeNodeCpp, TreeNodeCpp]:
0. A list of indices that is matched by the prefix on the GPU.
1. Sum length of the indices matched on the CPU.
2. The last node of the prefix matched on the GPU.
3. The last node of the prefix matched on the CPU.
"""
return self.tree.match_prefix(prefix)
def evict(self, num_tokens: int) -> List[torch.Tensor]:
"""
Evicts a number of tokens from the radix tree.
Args:
num_tokens (int): The number of tokens to evict.
Returns:
List[torch.Tensor]: A list of indices that were evicted.
"""
return self.tree.evict(num_tokens)
def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
"""
Locks or unlocks a reference to a tree node.
After locking, the node will not be evicted from the radix tree.
Args:
handle (TreeNodeCpp): The tree node to lock or unlock.
lock (bool): If True, locks the node; if False, unlocks it.
"""
return self.tree.lock_ref(handle, lock)
def writing_through(
self, key: List[int], indices: torch.Tensor
) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
"""
Inserts a key-value pair into the radix tree and perform write-through check.
Args:
key (List[int]): The key to insert.
indices (torch.Tensor): The value associated with the key.
Returns:
Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
0. A list of (IOHandle, device indices, host indices) tuples.
These IOhandles require write-through to the CPU in python side.
1. The number of indices that are matched on device.
"""
return self.tree.writing_through(key, indices)
def loading_onboard(
self,
host_node: TreeNodeCpp,
new_device_indices: torch.Tensor,
) -> Tuple[IOHandle, List[torch.Tensor]]:
"""
Updates the device indices of tree nodes within a range on the tree.
Args:
host_node (TreeNodeCpp): The tree node on the host, must be descendant of device_node.
new_device_indices (torch.Tensor): The new device indices to set.
The length of this tensor must be exactly host indices length.
Returns:
Tuple[IOHandle, List[torch.Tensor]]:
0. An IOHandle that requires loading to the CPU in python side.
1. A list of host indices corresponding to the new device indices.
"""
return self.tree.loading_onboard(host_node, new_device_indices)
def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
"""
Commits the write-through process for a tree node.
Args:
handle (IOHandle): The IOHandle to commit.
success (bool): If True, commits the write-through; if False, just indicates failure.
"""
return self.tree.commit_writing_through(handle, success)
def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
"""
Commits the load onboard process for tree nodes within a range on the tree.
Args:
handle (IOHandle): The IOHandle to commit.
success (bool): If True, commits the load-onboard; if False, just indicates failure.
"""
return self.tree.commit_loading_onboard(handle, success)
def evictable_size(self) -> int:
"""
Returns the size of the evictable part of the radix tree.
This is the size of the part that can be evicted from the GPU (ref_count = 0).
Returns:
int: The size of the evictable part.
"""
return self.tree.evictable_size()
def protected_size(self) -> int:
"""
Returns the size of the protected part of the radix tree.
This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
Returns:
int: The size of the protected part.
"""
return self.tree.protected_size()
def total_size(self) -> int:
"""
Returns the total size of the radix tree (including CPU nodes).
Returns:
int: The total size of the radix tree.
"""
return self.tree.total_size()
def reset(self) -> None:
"""
Resets the radix tree, clearing all nodes and indices.
"""
return self.tree.reset()
def debug_print(self) -> None:
"""
Prints the internal state of the radix tree for debugging purposes.
"""
return self.tree.debug_print()
else:
# Real implementation of the classes for runtime
RadixTreeCpp = radix_tree_cpp.RadixTree
TreeNodeCpp = object
IOHandle = object
#include "tree_v2.h"
#include <ATen/core/TensorBody.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/zeros.h>
#include <c10/util/irange.h>
#include <cstddef>
#include <memory>
#include <queue>
#include <stdexcept>
#include <utility>
#include <vector>
#include "common.h"
#include "tree_v2_impl.h"
#include "tree_v2_node.h"
namespace radix_tree_v2 {
static NodeHandle node2id(TreeNode* node) {
return node->node_id;
}
// compare function for the TreeNode pointers based on their time
// we use LRU, so we want to evict the least recently used nodes
// since std::priority_queue is a max-heap, we need to reverse the comparison
static constexpr auto cmp = [](TreeNode* lhs, TreeNode* rhs) { return lhs->time() > rhs->time(); };
RadixTree::RadixTree(bool disabled, std::optional<std::size_t> host_size, std::size_t page_size, std::size_t threshold)
: m_impl(std::make_unique<Impl>(disabled, host_size.has_value(), page_size, host_size.value_or(0), threshold)) {}
RadixTree::~RadixTree() = default;
std::tuple<std::vector<at::Tensor>, std::size_t, NodeHandle, NodeHandle>
RadixTree::match_prefix(const token_vec_t& _key) {
if (m_impl->disabled) return {};
const auto key = token_slice{_key.data(), m_impl->align(_key.size())};
const auto [host_node, _] = m_impl->tree_walk(key);
// walk up to the first non-evicted node
std::size_t host_hit_length = 0;
const auto device_node = host_node;
// collect all the device indices
std::vector<at::Tensor> indices{};
walk_to_root(device_node, [&](TreeNode* n) { indices.push_back(n->device_indices()); });
std::reverse(indices.begin(), indices.end());
return {std::move(indices), host_hit_length, node2id(device_node), node2id(host_node)};
}
std::vector<at::Tensor> RadixTree::evict(std::size_t num_tokens) {
if (m_impl->disabled || num_tokens == 0) return {};
auto heap = std::priority_queue{cmp, m_impl->collect_leaves_device()};
std::vector<at::Tensor> evicted_values;
// evict nodes until we reach the desired number of tokens
std::size_t num_evict = 0;
while (num_evict < num_tokens && !heap.empty()) {
const auto node = heap.top();
heap.pop();
// when ref_count == 0, can't be writing through
_assert(node->on_gpu() && node->ref_count == 0);
if (!node->is_io_free()) continue; // skip nodes that are undergoing IO (i.e. indices protected)
evicted_values.push_back(node->device_indices());
num_evict += node->length();
const auto parent = node->parent();
m_impl->remove_device_node(node);
if (parent->is_leaf_device() && parent->ref_count == 0)
heap.push(parent); // push parent to the heap if it is now a free leaf
}
return evicted_values;
}
std::tuple<std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>>, std::size_t>
RadixTree::writing_through(const token_vec_t& _key, at::Tensor value) {
if (m_impl->disabled) return {};
_assert(_key.size() == std::size_t(value.size(0)), "Key and value must have the same size");
// just align the key to the page size, clip the unaligned tail
const auto key = token_slice{_key.data(), m_impl->align(_key.size())};
// walk the tree to find the right place to insert
const auto [host_node, host_prefix_length] = m_impl->tree_walk(key);
// insert and create a new node if the remaining part of the key is not empty
if (host_prefix_length != key.size()) {
m_impl->create_device_node(
host_node,
{key.begin() + host_prefix_length, key.end()},
value.slice(/*dim=*/0, host_prefix_length, key.size()));
}
// add the hit count for the device node
walk_to_root(host_node, [&](TreeNode* n) { n->hit_count++; });
std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>> result;
// don't write through if hicache is disabled (no host memory), fast path
if (!m_impl->use_hicache) return {std::move(result), host_prefix_length};
throw std::runtime_error("Not implemented yet");
}
std::tuple<IOTicket, std::vector<at::Tensor>> RadixTree::loading_onboard(NodeHandle, at::Tensor) {
if (m_impl->disabled) return {};
throw std::runtime_error("Not implemented yet");
}
void RadixTree::commit_writing_through(IOTicket, bool) {
if (m_impl->disabled) return;
throw std::runtime_error("Not implemented yet");
}
void RadixTree::commit_loading_onboard(IOTicket, bool) {
if (m_impl->disabled) return;
throw std::runtime_error("Not implemented yet");
}
void RadixTree::reset() {
m_impl->reset();
}
void RadixTree::lock_ref(NodeHandle node_id, bool increment) {
if (m_impl->disabled) return;
m_impl->lock_ref(node_id, increment);
}
std::size_t RadixTree::evictable_size() const {
return m_impl->evictable_size();
}
std::size_t RadixTree::protected_size() const {
return m_impl->protected_size();
}
std::size_t RadixTree::total_size() const {
return m_impl->total_size();
}
} // namespace radix_tree_v2
#pragma once
#include <ATen/core/TensorBody.h>
#include <c10/core/Device.h>
#include <cstddef>
#include <memory>
#include <optional>
#include <tuple>
#include <vector>
#include "common.h"
namespace radix_tree_v2 {
struct RadixTree {
public:
RadixTree(bool disabled, std::optional<std::size_t> host_size, std::size_t page_size, std::size_t threshold);
~RadixTree();
// Trees should not be copied or moved, as they manage their own memory and state.
RadixTree(const RadixTree&) = delete;
RadixTree(RadixTree&&) = delete;
RadixTree& operator=(const RadixTree&) = delete;
RadixTree& operator=(RadixTree&&) = delete;
/// @return (device indices that are matched, host indices length, device node, host node)
std::tuple<std::vector<at::Tensor>, std::size_t, NodeHandle, NodeHandle> match_prefix(const token_vec_t& key);
/// @return Device indices that need to be evicted (on python side).
std::vector<at::Tensor> evict(std::size_t num_tokens);
/// @brief (Un-)Lock a node.
void lock_ref(NodeHandle node_id, bool increment /* increment or decrement */);
/// @brief Update new key-value pair and try to perform write-through.
std::tuple<std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>>, std::size_t>
writing_through(const token_vec_t& key, at::Tensor value);
/// @brief Load to device from host within a range of nodes.
std::tuple<IOTicket, std::vector<at::Tensor>> loading_onboard(NodeHandle host_id, at::Tensor indices);
/// @brief Commit a transaction of write-through.
void commit_writing_through(IOTicket ticket, bool success);
/// @brief Commit a transaction of load onboard.
void commit_loading_onboard(IOTicket ticket, bool success);
/// @brief Clear and reset the tree.
void reset();
/// @return How many size are still evictable (on device + not locked).
std::size_t evictable_size() const;
/// @return How many size are protected (locked).
std::size_t protected_size() const;
/// @return How many size are used on device.
std::size_t total_size() const;
/// @brief Print debug information of the tree.
void debug_print() const;
private:
struct Impl;
std::unique_ptr<Impl> m_impl;
};
} // namespace radix_tree_v2
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <cstddef>
#include <optional>
#include "tree_v2.h"
PYBIND11_MODULE(radix_tree_cpp, m) {
using namespace radix_tree_v2;
namespace py = pybind11;
py::class_<RadixTree>(m, "RadixTree")
.def(
py::init<bool, std::optional<std::size_t>, std::size_t, std::size_t>(),
py::arg("disabled"),
py::arg("host_size"),
py::arg("page_size"),
py::arg("write_through_threshold"))
.def("match_prefix", &RadixTree::match_prefix)
.def("evict", &RadixTree::evict)
.def("lock_ref", &RadixTree::lock_ref)
.def("evictable_size", &RadixTree::evictable_size)
.def("protected_size", &RadixTree::protected_size)
.def("total_size", &RadixTree::total_size)
.def("writing_through", &RadixTree::writing_through)
.def("loading_onboard", &RadixTree::loading_onboard)
.def("commit_writing_through", &RadixTree::commit_writing_through)
.def("commit_loading_onboard", &RadixTree::commit_loading_onboard)
.def("reset", &RadixTree::reset)
.def("debug_print", &RadixTree::debug_print);
}
#include <c10/core/DeviceType.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include "tree_v2.h"
#include "tree_v2_impl.h"
namespace radix_tree_v2 {
void RadixTree::debug_print() const {
m_impl->debug_print(std::clog);
}
static constexpr auto npos = std::size_t(-1);
void RadixTree::Impl::debug_print(std::ostream& os) const {
static constexpr auto _check = [](bool condition, auto msg, std::size_t id = npos) {
if (!condition) {
std::string suffix = id == npos ? "" : " [id = " + std::to_string(id) + "]";
throw std::runtime_error(std::string("RadixTree::debug_print failed: ") + msg + suffix);
}
};
static constexpr auto _print_node = [](TreeNode* node, std::size_t depth, std::ostream& os) {
const auto length = node->length();
os << node->node_id << " [depth = " << depth << "] [len = " << length << "]";
// placement status
if (node->on_both()) {
os << " [cpu + gpu]";
} else if (node->on_gpu()) {
os << " [gpu]";
} else if (node->on_cpu()) {
os << " [cpu]";
} else {
_check(false, "Node is not on GPU or CPU", node->node_id);
}
// IO status
if (node->is_io_free()) {
os << " [io = free]";
} else if (node->is_io_device_to_host()) {
os << " [io = gpu -> cpu]";
} else if (node->is_io_host_to_device()) {
os << " [io = cpu -> gpu]";
} else {
_check(false, "Node is in unknown IO state", node->node_id);
}
os << " [rc = " << node->ref_count << "]";
os << " [hit = " << node->hit_count << "]";
};
static constexpr auto _print_indices = [](at::Tensor indices, std::ostream& os) {
if (!indices.defined()) {
os << "[[N/A]]";
return indices;
}
indices = indices.to(c10::kCPU, c10::kLong, false, false, c10::MemoryFormat::Contiguous);
const auto length = indices.numel();
os << "[";
auto* data_ptr = indices.data_ptr<int64_t>();
for (const auto i : c10::irange(indices.size(0))) {
os << data_ptr[i];
if (i != length - 1) os << ", ";
}
os << "]";
return indices;
};
os << "Evictable size: " << evictable_size() << std::endl;
os << "Protected size: " << protected_size() << std::endl;
os << "Total size: " << const_cast<Impl*>(this)->total_size() << std::endl;
std::vector<std::tuple<TreeNode*, TreeNode*, token_slice>> stack;
auto root = const_cast<TreeNode*>(&m_root);
os << root->node_id << " [root]" << std::endl;
for (const auto& [key, child] : *root) {
stack.push_back({child.get(), root, key});
}
std::unordered_map<TreeNode*, std::size_t> depth_map;
std::string indent_buffer;
depth_map[root] = 0;
std::vector<NodeHandle> visited_id;
std::size_t evictable_size_real = 0;
while (!stack.empty()) {
const auto [node, parent, key] = stack.back();
stack.pop_back();
visited_id.push_back(node->node_id);
const auto nid = node->node_id;
_check(node != nullptr, "Node is null", nid);
_check(node->on_gpu() || node->on_cpu(), "Node is not on GPU or CPU", nid);
_check(node->parent() == parent, "Parent is not correct", nid);
_check(key.size() == page_size && node->diff_key(key, 0) == page_size, "Key is not correct", nid);
_check(depth_map.count(node) == 0, "Node is visited twice", nid);
_check(m_node_map.count(nid) == 1, "Node is not in the map", nid);
_check(m_node_map.at(nid) == node, "Node in the map is not the same as the one in the stack", nid);
_check(!node->on_gpu() || parent->is_root() || parent->on_gpu(), "Node on GPU must have a GPU/root parent", nid);
if (!node->is_io_free()) {
_check(node->ref_count > 0, "Node is in IO state but not protected", nid);
_check(node->on_both(), "Node in IO state must be on both CPU and GPU", nid);
}
if (node->on_gpu() && node->ref_count == 0) {
evictable_size_real += node->length();
}
const auto depth = (depth_map[node] = depth_map[parent] + 1);
indent_buffer.resize(depth * 2, ' ');
os << indent_buffer;
_print_node(node, depth, os);
os << std::endl;
for (const auto& [key, child] : *node) {
stack.push_back({child.get(), node, key});
}
}
_check(evictable_size_real == evictable_size(), "Evictable size is wrong");
_check(m_node_map.count(root->node_id) == 1, "Root node is not in the map");
_check(m_node_map.at(root->node_id) == root, "Root node in the map is not correct");
std::sort(visited_id.begin(), visited_id.end());
if (visited_id.size() != m_node_map.size() - 1) {
// Some error in the tree, not all nodes are visited
std::string id_list;
id_list += "(visited: ";
id_list += std::to_string(root->node_id) + " ";
for (const auto& id : visited_id) {
id_list += std::to_string(id) + " ";
}
id_list += "), (in map: ";
for (const auto& [id, _] : m_node_map) {
id_list += std::to_string(id) + " ";
}
id_list += ")";
_check(false, "Not all nodes are visited " + id_list);
}
static const auto kSGLANG_RADIX_CPP_DEBUG_LIMIT = [] {
const char* env = std::getenv("SGLANG_RADIX_CPP_DEBUG_LIMIT");
const std::size_t default_limit = 16;
if (env != nullptr) {
try {
return static_cast<std::size_t>(std::stoull(env));
} catch (const std::exception& e) {
std::cerr << "Invalid SGLANG_RADIX_CPP_DEBUG_LIMIT value: " << env //
<< ". Using default value =" << default_limit << std::endl;
}
}
return default_limit;
}();
for (const auto nid : visited_id) {
const auto node = m_node_map.at(nid);
// print key and indices
const auto& key = node->_unsafe_tokens();
if (key.size() > kSGLANG_RADIX_CPP_DEBUG_LIMIT) {
os << "Node " << nid << ": key is too long (" << key.size() << " tokens), skipping..." << std::endl;
continue;
}
os << "Node " << nid << ": key = [";
for (const auto& i : c10::irange(key.size())) {
os << key[i];
if (i != key.size() - 1) os << ", ";
}
_check(key.size() % page_size == 0, "Misaligned key", nid);
os << "] device_indices = ";
const auto device_indices = _print_indices(node->device_indices(), os);
if (device_indices.defined()) {
std::size_t length = device_indices.numel();
_check(device_indices.dim() == 1, "Device indices must be 1D tensor", nid);
_check(length == node->length(), "Wrong device indices size", nid);
}
os << " host_indices = ";
const auto host_indices = _print_indices(node->host_indices(), os);
if (host_indices.defined()) {
std::size_t length = host_indices.numel();
_check(host_indices.dim() == 1, "Host indices must be 1D tensor", nid);
_check(length == node->length(), "Wrong host indices size", nid);
}
os << std::endl;
}
}
} // namespace radix_tree_v2
#pragma once
#include <c10/util/irange.h>
#include <chrono>
#include <cstddef>
#include <iosfwd>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "common.h"
#include "tree_v2.h"
#include "tree_v2_node.h"
namespace radix_tree_v2 {
using node_iterator_t = typename TreeNode::iterator_t;
struct RadixTree::Impl {
public:
Impl(bool disabled, bool use_hicache, std::size_t page_size, std::size_t host_size, std::size_t threshold)
: m_root(/*node_id_=*/0),
m_evictable_size(0),
m_protected_size(0),
m_cached_vec(),
m_node_map(),
m_node_counter(1), // start from 1 to avoid confusion with root node
disabled(disabled),
use_hicache(use_hicache),
page_size(page_size),
threshold(threshold) {
_assert(page_size > 0, "Page size must be greater than zero");
_assert(use_hicache == (host_size > 0), "Hierarchical cache is enabled iff host size > 0");
m_root.ref_count = 1; // root node is always protected
m_cached_vec.reserve(page_size); // to avoid repeated allocations
m_node_map[m_root.node_id] = &m_root; // add root to the map
}
TreeNode* split_node(node_iterator_t iterator, std::size_t prefix_length) {
// from `parent -> old_node` to `parent-> new_node -> old_node`
// the prefix part of the old node is moved to the new node
auto old_node_ptr = std::move(iterator->second);
auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++);
auto* old_node = old_node_ptr.get();
auto* new_node = new_node_ptr.get();
auto* parent = old_node->parent();
// set up data structures
split_prefix(new_node, old_node, prefix_length);
// set up parent-child relationship
add_child(new_node, std::move(old_node_ptr));
add_child(parent, std::move(new_node_ptr), iterator);
m_node_map[new_node->node_id] = new_node; // add to the map
return new_node;
}
// node: x -> [GPU]
TreeNode* create_device_node(TreeNode* parent, token_vec_t vec, at::Tensor indices) {
auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++);
auto new_node = new_node_ptr.get();
new_node_ptr->_unsafe_tokens() = std::move(vec);
new_node_ptr->_unsafe_device_indices() = std::move(indices);
m_evictable_size += new_node_ptr->length();
add_child(parent, std::move(new_node_ptr));
m_node_map[new_node->node_id] = new_node; // add to the map
return new_node;
}
// node: [GPU] -> x
void remove_device_node(TreeNode* node) {
_assert(node->on_gpu_only() && node->ref_count == 0);
m_evictable_size -= node->length();
node->parent()->erase_child(get_key(node));
m_node_map.erase(node->node_id); // remove from the map
}
/**
* @brief Walk the tree to find the node that matches the key.
* If the key partially matches a node, it will split that node.
* @return A pair containing the last node that matches the key and
* the total prefix length matched (on gpu and cpu) so far.
*/
std::pair<TreeNode*, std::size_t> tree_walk(token_slice key) {
_assert(key.size() % page_size == 0, "Key should be page-aligned");
std::size_t total_prefix_length = 0;
TreeNode* node = &m_root;
const auto now = std::chrono::steady_clock::now();
while (key.size() > 0) {
const auto iterator = node->find_child(get_key(key));
if (iterator == node->end()) break;
// walk to the child node
node = iterator->second.get();
// at least `page_size` tokens are matched, and there may be more tokens to match
// the return value prefix_length is no less than `page_size`
const auto prefix_length = align(node->diff_key(key, page_size) + page_size);
total_prefix_length += prefix_length;
// split the node if the prefix is not the whole token vector
if (prefix_length < node->length()) {
return {split_node(iterator, prefix_length), total_prefix_length};
}
// we have matched the whole key, continue to the next node
node->access(now);
key = key.subspan(prefix_length);
}
return {node, total_prefix_length};
}
std::vector<TreeNode*> collect_leaves() const {
std::vector<TreeNode*> leaves;
std::vector<TreeNode*> stack = {};
for (const auto& [_, child] : m_root) {
stack.push_back(child.get());
}
while (!stack.empty()) {
const auto node = stack.back();
stack.pop_back();
if (node->is_leaf()) {
if (node->ref_count == 0) {
leaves.push_back(node);
}
} else {
for (const auto& [_, child] : *node) {
stack.push_back(child.get());
}
}
}
return leaves;
}
std::vector<TreeNode*> collect_leaves_device() const {
// for non-hicache, every leaf device node is a leaf node (since no backup on host)
if (!use_hicache) return collect_leaves();
std::vector<TreeNode*> leaves;
std::vector<TreeNode*> stack = {};
for (const auto& [_, child] : m_root) {
stack.push_back(child.get());
}
while (!stack.empty()) {
const auto node = stack.back();
stack.pop_back();
if (!node->on_gpu()) continue; // skip nodes that are not on GPU
if (node->is_leaf_device()) {
if (node->ref_count == 0) {
leaves.push_back(node);
}
} else {
for (const auto& [_, child] : *node) {
stack.push_back(child.get());
}
}
}
return leaves;
}
void lock_ref(TreeNode* node, bool increment) {
if (node->is_root()) return; // skip root node
_assert(node->on_gpu(), "Cannot lock reference on an evicted node");
if (increment)
walk_to_root(node, [this](TreeNode* n) {
if (n->ref_count == 0) {
m_evictable_size -= n->length();
m_protected_size += n->length();
}
n->ref_count++;
});
else
walk_to_root(node, [this](TreeNode* n) {
_assert(n->ref_count != 0, "Cannot decrement reference count = zero");
n->ref_count--;
if (n->ref_count == 0) {
m_protected_size -= n->length();
m_evictable_size += n->length();
}
});
}
void lock_ref(NodeHandle node_ptr, bool increment) {
return lock_ref(id2node(node_ptr), increment);
}
void lock(TreeNode* node) {
return lock_ref(node, /*increment=*/true);
}
void unlock(TreeNode* node) {
return lock_ref(node, /*increment=*/false);
}
std::size_t total_size() const {
std::size_t size = 0;
std::vector<const TreeNode*> stack = {&m_root};
while (!stack.empty()) {
auto* node = stack.back();
stack.pop_back();
size += node->length();
for (const auto& [_, child] : *node)
stack.push_back(child.get());
}
return size;
}
std::size_t evictable_size() const {
return m_evictable_size;
}
std::size_t protected_size() const {
return m_protected_size;
}
std::size_t align(std::size_t size) const {
return (size / page_size) * page_size; // align to page size
}
TreeNode* id2node(NodeHandle node_id) const {
const auto iterator = m_node_map.find(node_id);
_assert(iterator != m_node_map.end(), "Node not found in the map");
return iterator->second;
}
void reset() {
_assert(m_root.ref_count == 1, "Root node must be protected during reset");
m_node_counter = 1; // reset node counter
m_root.root_reset();
m_evictable_size = 0;
m_protected_size = 0;
m_node_map.clear();
m_node_map[m_root.node_id] = &m_root; // re-add root to the map
}
void debug_print(std::ostream& os) const;
private:
// some auxiliary functions
token_vec_t& get_key(token_slice tokens) {
_assert(tokens.size() >= page_size, "Key should be at least page-sized");
tokens = tokens.subspan(0, page_size);
m_cached_vec.assign(tokens.begin(), tokens.end());
return m_cached_vec;
}
// justify for _unsafe call: we need to read the key part of the tokens
token_vec_t& get_key(TreeNode* node) {
return get_key(node->_unsafe_tokens());
}
void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child) {
parent->add_child(get_key(child.get()), std::move(child));
}
void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child, node_iterator_t it) {
parent->add_child(it, std::move(child));
}
TreeNode m_root; // root node of the tree
std::size_t m_evictable_size; // number of evictable tokens on GPU (lock ref = 0)
std::size_t m_protected_size; // number of protected tokens on GPU (lock ref > 0)
token_vec_t m_cached_vec; // cached vector of tokens for the current operation
std::unordered_map<std::size_t, TreeNode*> m_node_map; // map of node keys to nodes
std::size_t m_node_counter; // counter for node IDs
public:
// some public constant configurations (without m_ prefix)
const bool disabled; // whether the cache is enabled, or just a temporary cache
const bool use_hicache; // whether to use the HiCache for this tree
const std::size_t page_size; // size of each page in the cache
const std::size_t threshold; // threshold for write_through
};
} // namespace radix_tree_v2
#pragma once
#include <ATen/core/TensorBody.h>
#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <ranges>
#include <unordered_map>
#include "common.h"
namespace radix_tree_v2 {
struct std_vector_hash {
// see https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
std::size_t operator()(const token_vec_t& vec) const {
std::size_t hash = 0;
for (const auto& token : vec) {
hash ^= token + 0x9e3779b9 + (hash << 6) + (hash >> 2);
}
return hash;
}
};
struct TreeNode {
public:
using childern_map_t = std::unordered_map<token_vec_t, std::unique_ptr<TreeNode>, std_vector_hash>;
using iterator_t = typename childern_map_t::iterator;
using const_iterator_t = typename childern_map_t::const_iterator;
using timestamp_t = std::chrono::steady_clock::time_point;
TreeNode(std::size_t node_id_)
: ref_count(0),
hit_count(0),
m_io_locked(std::nullopt),
m_io_status(IOStatus::None),
m_io_ticket(),
m_tokens(),
m_device_indices(),
m_host_indices(),
m_parent(),
m_children(),
m_last_access_time(std::chrono::steady_clock::now()),
node_id(node_id_) {}
void access(timestamp_t time = std::chrono::steady_clock::now()) {
m_last_access_time = time;
}
bool is_root() const {
return m_parent == nullptr;
}
timestamp_t time() const {
return m_last_access_time;
}
bool on_gpu() const {
return m_device_indices.defined();
}
bool on_cpu() const {
return m_host_indices.defined();
}
bool on_gpu_only() const {
return on_gpu() && !on_cpu();
}
bool on_cpu_only() const {
return !on_gpu() && on_cpu();
}
bool on_both() const {
return on_gpu() && on_cpu();
}
std::size_t length() const {
return m_tokens.size();
}
bool is_leaf() const {
return m_children.empty();
}
bool is_leaf_device() const {
for (const auto& [_, child] : m_children)
if (child->on_gpu()) return false; // at least one child is on the device
return true;
}
void add_child(const token_vec_t& v, std::unique_ptr<TreeNode>&& child) {
child->m_parent = this;
m_children[v] = std::move(child);
}
void add_child(iterator_t it, std::unique_ptr<TreeNode>&& child) {
child->m_parent = this;
it->second = std::move(child);
}
void erase_child(const token_vec_t& v) {
_assert(m_children.erase(v) > 0, "Child node not found");
}
iterator_t find_child(const token_vec_t& v) {
return m_children.find(v);
}
iterator_t begin() {
return m_children.begin();
}
iterator_t end() {
return m_children.end();
}
const_iterator_t begin() const {
return m_children.begin();
}
const_iterator_t end() const {
return m_children.end();
}
TreeNode* parent() {
return m_parent;
}
// set up all data structures except for parent-child relationship
friend void split_prefix(TreeNode* new_node, TreeNode* old_node, std::size_t prefix_length) {
auto tokens = std::move(old_node->m_tokens);
_assert(0 < prefix_length && prefix_length < tokens.size(), "Invalid prefix size for split");
// set up tokens
old_node->m_tokens = token_vec_t(tokens.begin() + prefix_length, tokens.end());
new_node->m_tokens = std::move(tokens);
new_node->m_tokens.resize(prefix_length);
// set up values
const int64_t new_size = new_node->length();
const int64_t old_size = old_node->length();
if (old_node->m_device_indices.defined()) {
auto new_indices = old_node->m_device_indices.split_with_sizes({new_size, old_size});
new_node->m_device_indices = std::move(new_indices[0]);
old_node->m_device_indices = std::move(new_indices[1]);
}
if (old_node->m_host_indices.defined()) {
auto new_indices = old_node->m_host_indices.split_with_sizes({new_size, old_size});
new_node->m_host_indices = std::move(new_indices[0]);
old_node->m_host_indices = std::move(new_indices[1]);
}
// set up ref counts and hit counts
new_node->ref_count = old_node->ref_count;
new_node->hit_count = old_node->hit_count;
// If the old node (child) was locked for IO, the new node (parent) does not need
// to be locked, since it is naturally protected by the child node's lock.
if (old_node->m_io_locked.has_value()) {
new_node->m_io_locked = false;
new_node->m_io_status = old_node->m_io_status;
new_node->m_io_ticket = old_node->m_io_ticket;
}
}
/// @return The first index in `m_tokens` that differs from `key`.
std::size_t diff_key(token_slice key, std::size_t offset) const {
const auto a = token_slice{key}.subspan(offset);
const auto b = token_slice{m_tokens}.subspan(offset);
const auto [it_a, it_b] = std::ranges::mismatch(a, b);
return it_a - a.begin(); // return the index of the first differing token
}
at::Tensor device_indices() const {
return m_device_indices;
}
at::Tensor host_indices() const {
return m_host_indices;
}
// visiting tokens are always unsafe (use `diff_key` instead)
token_vec_t& _unsafe_tokens() {
return m_tokens;
}
at::Tensor& _unsafe_device_indices() {
return m_device_indices;
}
at::Tensor& _unsafe_host_indices() {
return m_host_indices;
}
bool is_io_free() const {
return m_io_status == IOStatus::None;
}
bool is_io_device_to_host() const {
return m_io_status == IOStatus::DeviceToHost;
}
bool is_io_host_to_device() const {
return m_io_status == IOStatus::HostToDevice;
}
void root_reset() {
_assert(is_root(), "Only root node can call root_reset");
_assert(
m_io_status == IOStatus::None && m_io_locked == std::nullopt,
"IO operation in progress, cannot reset root node");
_assert(this->m_tokens.empty(), "Root node tokens should be empty on reset");
_assert(
!this->m_device_indices.defined() && !this->m_host_indices.defined(),
"Root node indices should be always be empty and never assigned");
m_children.clear();
this->access();
}
public:
std::size_t ref_count;
std::size_t hit_count;
private:
enum class IOStatus : std::uint8_t {
None,
HostToDevice,
DeviceToHost,
};
std::optional<bool> m_io_locked; // whether the node is locked in IO operation
IOStatus m_io_status;
IOTicket m_io_ticket;
token_vec_t m_tokens;
at::Tensor m_device_indices; // indices of device value
at::Tensor m_host_indices; // indices of host value
TreeNode* m_parent;
childern_map_t m_children;
timestamp_t m_last_access_time;
public:
const std::size_t node_id; // unique ID for the node
};
template <typename F>
inline TreeNode* walk_to_root(TreeNode* t, const F& f) {
while (!t->is_root()) {
f(t);
t = t->parent();
}
return t; // return the root node
}
} // namespace radix_tree_v2
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Set
import torch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
IOHandle,
RadixTreeCpp,
TreeNodeCpp,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
logger = logging.getLogger(__name__)
class RadixCacheCpp(BasePrefixCache):
def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor:
"""
Merge a list of tensors into a single tensor.
Args:
l (List[torch.Tensor]): List of tensors to merge.
Returns:
torch.Tensor: Merged tensor.
"""
if len(l) == 0:
return torch.empty(0, dtype=torch.int64, device=self.device)
elif len(l) == 1:
return l[0]
else:
return torch.cat(l)
def __init__(
self,
disable: bool,
use_hicache: bool,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPoolAllocator,
tp_cache_group: torch.distributed.ProcessGroup,
page_size: int,
hicache_ratio: float,
hicache_size: int,
hicache_write_policy: str,
enable_kv_cache_events: bool = False,
hicache_oracle: bool = False,
enable_write_cancel: bool = False,
):
self.disable = disable
self.enable_write_cancel = enable_write_cancel
assert (
enable_kv_cache_events is False
), "HiRadixCache does not support kv cache events yet"
self.kv_cache = token_to_kv_pool.get_kvcache()
# record the nodes with ongoing write through
self.ongoing_write_through: Set[IOHandle] = set()
# record the node segments with ongoing load back
self.ongoing_load_back: Set[IOHandle] = set()
# todo: dynamically adjust the threshold
self.write_through_threshold = (
1 if hicache_write_policy == "write_through" else 2
)
self.device = token_to_kv_pool.device
self.token_to_kv_pool = token_to_kv_pool
self.req_to_token_pool = req_to_token_pool
self.page_size = page_size
self.tp_group = tp_cache_group
if not use_hicache:
self.tree = RadixTreeCpp(
disabled=self.disable,
page_size=page_size,
host_size=None, # no host cache, this should be removed in the future
write_through_threshold=self.write_through_threshold,
)
self.cache_controller = None
return # early return if hicache is not used
raise NotImplementedError("Host cache is not supported yet")
def reset(self):
if self.cache_controller is not None:
# need to clear the acks before resetting the cache controller
raise NotImplementedError("Host cache is not supported yet")
self.tree.reset()
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
device_indices_vec, host_indices_length, node_gpu, node_cpu = (
self.tree.match_prefix(key)
)
return MatchResult(
device_indices=self._merge_tensor(device_indices_vec),
last_device_node=node_gpu,
last_host_node=node_cpu,
host_hit_length=host_indices_length,
)
def _insert(self, key: List[int], value: torch.Tensor) -> int:
"""
Insert a key-value pair into the radix tree.
Args:
key (List[int]): The key to insert, represented as a list of integers.
value (torch.Tensor): The value to associate with the key.
Returns:
int: Number of device indices that were already present in the tree before the insertion.
"""
ongoing_write, length = self.tree.writing_through(key, value)
if self.cache_controller is None:
assert len(ongoing_write) == 0, "Implementation error"
return length
raise NotImplementedError("Host cache is not supported yet")
def dec_lock_ref(self, node: TreeNodeCpp):
"""
Decrement the reference count of a node to root of the radix tree.
Args:
node (TreeNodeCpp): The handle of the node to decrement the reference count for.
"""
self.tree.lock_ref(node, False) # do not increment
def inc_lock_ref(self, node: TreeNodeCpp):
"""
Increment the reference count of from a node to root of the radix tree.
Args:
node (TreeNodeCpp): The handle of the node to increment the reference count for.
"""
self.tree.lock_ref(node, True)
def evict(self, num_tokens: int):
evicted_device_indices = self.tree.evict(num_tokens)
for indice in evicted_device_indices:
self.token_to_kv_pool.free(indice)
def evictable_size(self):
return self.tree.evictable_size()
def protected_size(self):
return self.tree.protected_size()
def total_size(self):
return self.tree.total_size()
def cache_finished_req(self, req: Req):
"""Cache request when it finishes."""
assert req.req_pool_idx is not None
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
overall_len = len(token_ids) # prefill + decode
kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
# it will automatically align them, but length of them should be equal
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
new_prefix_len = self._insert(token_ids, kv_indices)
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
# KVCache between old & new is newly generated, but already exists in the pool
# we need to free this newly generated kv indices
if old_prefix_len < new_prefix_len:
self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
# need to free the unaligned part, since it cannot be inserted into the radix tree
if self.page_size != 1 and ( # unaligned tail only exists when page_size > 1
(unaligned_len := overall_len % self.page_size) > 0
):
# NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
# Remove req slot release the cache lock
self.dec_lock_ref(req.last_node)
self.req_to_token_pool.free(req.req_pool_idx)
def cache_unfinished_req(self, req: Req):
"""Cache request when it is unfinished."""
assert req.req_pool_idx is not None
token_ids = req.fill_ids
prefill_len = len(token_ids) # prefill only (maybe chunked)
kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len]
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
# it will automatically align them, but length of them should be equal
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
new_prefix_len = self._insert(token_ids, kv_indices)
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
# TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
# The prefix indices need to updated to reuse the kv indices in the pool
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
new_indices = self._merge_tensor(new_indices_vec)
assert new_prefix_len <= len(new_indices)
# KVCache between old & new is newly generated, but already exists in the pool
# we need to free this newly generated kv indices and reuse the indices in the pool
if old_prefix_len < new_prefix_len:
self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
reused_indices = new_indices[old_prefix_len:new_prefix_len]
self.req_to_token_pool.req_to_token[
req.req_pool_idx, old_prefix_len:new_prefix_len
] = reused_indices
if req.last_node != new_last_node:
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
# NOTE: there might be unaligned tail, so we may need to append it
assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size
if self.page_size != 1 and len(new_indices) < prefill_len:
req.prefix_indices = torch.cat(
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node
def pretty_print(self):
return self.tree.debug_print()
import os
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestCppRadixCache(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_EXPERIMENTAL_CPP_RADIX_TREE"] = "1"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
print(metrics)
self.assertGreaterEqual(metrics["score"], 0.65)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment