Commit e273aa6d (unverified)

[Feature] Radix Tree in C++ (#7369)

Authored Aug 02, 2025 by DarkSharpness; committed by GitHub on Aug 02, 2025.
Parent: 828a4fe9

Showing 12 changed files with 1466 additions and 1 deletion (+1466 −1):
python/sglang/srt/managers/scheduler.py                          +17  −1
python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format          +1  −0
python/sglang/srt/mem_cache/cpp_radix_tree/common.h              +29  −0
python/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py        +182  −0
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.cpp          +143  −0
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.h             +59  −0
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_binding.cpp   +32  −0
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_debug.cpp    +194  −0
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_impl.h       +276  −0
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_node.h       +257  −0
python/sglang/srt/mem_cache/radix_cache_cpp.py                  +229  −0
test/srt/test_cpp_radix_cache.py                                 +47  −0
python/sglang/srt/managers/scheduler.py  (+17 −1)

@@ -569,7 +569,23 @@ class Scheduler(
                 page_size=self.page_size,
             )
         else:
-            if self.enable_hierarchical_cache:
+            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
+                # lazy import to avoid JIT overhead
+                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
+
+                self.tree_cache = RadixCacheCpp(
+                    disable=False,
+                    use_hicache=self.enable_hierarchical_cache,
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_cpu_group,
+                    page_size=self.page_size,
+                    hicache_ratio=server_args.hicache_ratio,
+                    hicache_size=server_args.hicache_size,
+                    hicache_write_policy=server_args.hicache_write_policy,
+                    enable_kv_cache_events=self.enable_kv_cache_events,
+                )
+            elif self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
 ...
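The new branch is opt-in: the C++ tree is only constructed when the SGLANG_EXPERIMENTAL_CPP_RADIX_TREE environment variable equals "1" at the time the scheduler is created. A minimal sketch of opting in from Python (the launch step is illustrative, not part of this diff):

    import os

    # Must be set before the Scheduler is constructed, since the check
    # happens once at scheduler initialization.
    os.environ["SGLANG_EXPERIMENTAL_CPP_RADIX_TREE"] = "1"

    # ... then launch sglang as usual; with the flag set, the scheduler
    # builds RadixCacheCpp instead of the default HiRadixCache.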
python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format  (new symlink, mode 0 → 120000)

../../../../../sgl-kernel/.clang-format
\ No newline at end of file
python/sglang/srt/mem_cache/cpp_radix_tree/common.h  (new file, mode 100644, +29)

#pragma once
#include <cstddef>
#include <cstdint>
#include <source_location>
#include <span>
#include <stdexcept>
#include <string>
#include <vector>

namespace radix_tree_v2 {

using token_t = std::int32_t;
using token_vec_t = std::vector<token_t>;
using token_slice = std::span<const token_t>;
using NodeHandle = std::size_t;
using IOTicket = std::uint32_t;

inline void _assert(
    bool condition,
    const char* message = "Assertion failed",
    std::source_location loc = std::source_location::current()) {
  if (!condition) [[unlikely]] {
    std::string msg = message;
    msg = msg + " at " + loc.file_name() + ":" + std::to_string(loc.line()) + " in " + loc.function_name();
    throw std::runtime_error(msg);
  }
}

}  // namespace radix_tree_v2
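For readers working mostly on the Python side, here is a rough Python analogue of the _assert helper above, showing what std::source_location buys: the error message carries the caller's file, line, and function. This is an illustration only, not part of the commit:

    import inspect


    def _assert(condition: bool, message: str = "Assertion failed") -> None:
        # Mimic std::source_location::current() by inspecting the caller's frame.
        if not condition:
            caller = inspect.stack()[1]
            raise RuntimeError(
                f"{message} at {caller.filename}:{caller.lineno} in {caller.function}"
            )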
python/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py  (new file, mode 100644, +182)

from __future__ import annotations

import os
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
from torch.utils.cpp_extension import load

_abs_path = os.path.dirname(os.path.abspath(__file__))
radix_tree_cpp = load(
    name="radix_tree_cpp",
    sources=[
        f"{_abs_path}/tree_v2_binding.cpp",
        f"{_abs_path}/tree_v2_debug.cpp",
        f"{_abs_path}/tree_v2.cpp",
    ],
    extra_cflags=["-O3", "-std=c++20"],
)

if TYPE_CHECKING:

    class TreeNodeCpp:
        """A placeholder for the TreeNode class. Cannot be constructed elsewhere."""

    class IOHandle:
        """A placeholder for the IOHandle class. Cannot be constructed elsewhere."""

    class RadixTreeCpp:
        def __init__(
            self,
            disabled: bool,
            host_size: Optional[int],
            page_size: int,
            write_through_threshold: int,
        ):
            """
            Initializes the RadixTreeCpp instance.
            Args:
                disabled (bool): If True, the radix tree is disabled.
                host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
                page_size (int): Size of the page for the radix tree.
                write_through_threshold (int): Threshold for writing through from GPU to CPU.
            """
            self.tree = radix_tree_cpp.RadixTree(  # type: ignore
                disabled, host_size, page_size, write_through_threshold
            )

        def match_prefix(
            self, prefix: List[int]
        ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
            """
            Matches a prefix in the radix tree.
            Args:
                prefix (List[int]): The prefix to match.
            Returns:
                Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
                    0. A list of indices matched by the prefix on the GPU.
                    1. Total length of the indices matched on the CPU.
                    2. The last node of the prefix matched on the GPU.
                    3. The last node of the prefix matched on the CPU.
            """
            return self.tree.match_prefix(prefix)

        def evict(self, num_tokens: int) -> List[torch.Tensor]:
            """
            Evicts a number of tokens from the radix tree.
            Args:
                num_tokens (int): The number of tokens to evict.
            Returns:
                List[torch.Tensor]: A list of indices that were evicted.
            """
            return self.tree.evict(num_tokens)

        def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
            """
            Locks or unlocks a reference to a tree node.
            After locking, the node will not be evicted from the radix tree.
            Args:
                handle (TreeNodeCpp): The tree node to lock or unlock.
                lock (bool): If True, locks the node; if False, unlocks it.
            """
            return self.tree.lock_ref(handle, lock)

        def writing_through(
            self, key: List[int], indices: torch.Tensor
        ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
            """
            Inserts a key-value pair into the radix tree and performs a write-through check.
            Args:
                key (List[int]): The key to insert.
                indices (torch.Tensor): The value associated with the key.
            Returns:
                Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
                    0. A list of (IOHandle, device indices, host indices) tuples.
                       These IOHandles require a write-through to the CPU on the Python side.
                    1. The number of indices that are matched on the device.
            """
            return self.tree.writing_through(key, indices)

        def loading_onboard(
            self,
            host_node: TreeNodeCpp,
            new_device_indices: torch.Tensor,
        ) -> Tuple[IOHandle, List[torch.Tensor]]:
            """
            Updates the device indices of tree nodes within a range on the tree.
            Args:
                host_node (TreeNodeCpp): The tree node on the host; must be a descendant of the device node.
                new_device_indices (torch.Tensor): The new device indices to set.
                    The length of this tensor must exactly match the host indices length.
            Returns:
                Tuple[IOHandle, List[torch.Tensor]]:
                    0. An IOHandle whose loading must be performed on the Python side.
                    1. A list of host indices corresponding to the new device indices.
            """
            return self.tree.loading_onboard(host_node, new_device_indices)

        def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
            """
            Commits the write-through process for a tree node.
            Args:
                handle (IOHandle): The IOHandle to commit.
                success (bool): If True, commits the write-through; if False, just indicates failure.
            """
            return self.tree.commit_writing_through(handle, success)

        def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
            """
            Commits the load-onboard process for tree nodes within a range on the tree.
            Args:
                handle (IOHandle): The IOHandle to commit.
                success (bool): If True, commits the load-onboard; if False, just indicates failure.
            """
            return self.tree.commit_loading_onboard(handle, success)

        def evictable_size(self) -> int:
            """
            Returns the size of the evictable part of the radix tree.
            This is the size of the part that can be evicted from the GPU (ref_count = 0).
            Returns:
                int: The size of the evictable part.
            """
            return self.tree.evictable_size()

        def protected_size(self) -> int:
            """
            Returns the size of the protected part of the radix tree.
            This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
            Returns:
                int: The size of the protected part.
            """
            return self.tree.protected_size()

        def total_size(self) -> int:
            """
            Returns the total size of the radix tree (including CPU nodes).
            Returns:
                int: The total size of the radix tree.
            """
            return self.tree.total_size()

        def reset(self) -> None:
            """
            Resets the radix tree, clearing all nodes and indices.
            """
            return self.tree.reset()

        def debug_print(self) -> None:
            """
            Prints the internal state of the radix tree for debugging purposes.
            """
            return self.tree.debug_print()

else:
    # Real implementation of the classes for runtime
    RadixTreeCpp = radix_tree_cpp.RadixTree
    TreeNodeCpp = object
    IOHandle = object
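A minimal usage sketch of the wrapper above, with hypothetical values (the first import JIT-compiles the extension via torch.utils.cpp_extension.load; the index dtype is an assumption for illustration):

    import torch

    from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import RadixTreeCpp

    # page_size=1 keeps the example simple; host_size=None disables the CPU tree.
    tree = RadixTreeCpp(
        disabled=False, host_size=None, page_size=1, write_through_threshold=1
    )

    key = [1, 2, 3, 4]
    indices = torch.arange(4, dtype=torch.int64)  # KV-pool slots holding these tokens
    pending, matched = tree.writing_through(key, indices)  # insert; matched == 0 here
    device_vec, host_len, node_gpu, node_cpu = tree.match_prefix([1, 2, 3, 4, 5])
    # device_vec concatenates to the 4 matched indices; the trailing 5 is unmatched.
    tree.lock_ref(node_gpu, True)   # protect the matched path from eviction
    tree.lock_ref(node_gpu, False)  # release it again
    evicted = tree.evict(4)         # the 4 tokens are now evictable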
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.cpp  (new file, mode 100644, +143)

#include "tree_v2.h"

#include <ATen/core/TensorBody.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/zeros.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <memory>
#include <queue>
#include <stdexcept>
#include <utility>
#include <vector>

#include "common.h"
#include "tree_v2_impl.h"
#include "tree_v2_node.h"

namespace radix_tree_v2 {

static NodeHandle node2id(TreeNode* node) {
  return node->node_id;
}

// compare function for the TreeNode pointers based on their time
// we use LRU, so we want to evict the least recently used nodes
// since std::priority_queue is a max-heap, we need to reverse the comparison
static constexpr auto cmp = [](TreeNode* lhs, TreeNode* rhs) { return lhs->time() > rhs->time(); };

RadixTree::RadixTree(bool disabled, std::optional<std::size_t> host_size, std::size_t page_size, std::size_t threshold)
    : m_impl(std::make_unique<Impl>(disabled, host_size.has_value(), page_size, host_size.value_or(0), threshold)) {}

RadixTree::~RadixTree() = default;

std::tuple<std::vector<at::Tensor>, std::size_t, NodeHandle, NodeHandle>
RadixTree::match_prefix(const token_vec_t& _key) {
  if (m_impl->disabled) return {};
  const auto key = token_slice{_key.data(), m_impl->align(_key.size())};
  const auto [host_node, _] = m_impl->tree_walk(key);
  // walk up to the first non-evicted node
  std::size_t host_hit_length = 0;
  const auto device_node = host_node;
  // collect all the device indices
  std::vector<at::Tensor> indices{};
  walk_to_root(device_node, [&](TreeNode* n) { indices.push_back(n->device_indices()); });
  std::reverse(indices.begin(), indices.end());
  return {std::move(indices), host_hit_length, node2id(device_node), node2id(host_node)};
}

std::vector<at::Tensor> RadixTree::evict(std::size_t num_tokens) {
  if (m_impl->disabled || num_tokens == 0) return {};
  auto heap = std::priority_queue{cmp, m_impl->collect_leaves_device()};
  std::vector<at::Tensor> evicted_values;
  // evict nodes until we reach the desired number of tokens
  std::size_t num_evict = 0;
  while (num_evict < num_tokens && !heap.empty()) {
    const auto node = heap.top();
    heap.pop();
    // when ref_count == 0, can't be writing through
    _assert(node->on_gpu() && node->ref_count == 0);
    if (!node->is_io_free()) continue;  // skip nodes that are undergoing IO (i.e. indices protected)
    evicted_values.push_back(node->device_indices());
    num_evict += node->length();
    const auto parent = node->parent();
    m_impl->remove_device_node(node);
    if (parent->is_leaf_device() && parent->ref_count == 0)
      heap.push(parent);  // push parent to the heap if it is now a free leaf
  }
  return evicted_values;
}

std::tuple<std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>>, std::size_t>
RadixTree::writing_through(const token_vec_t& _key, at::Tensor value) {
  if (m_impl->disabled) return {};
  _assert(_key.size() == std::size_t(value.size(0)), "Key and value must have the same size");
  // just align the key to the page size, clip the unaligned tail
  const auto key = token_slice{_key.data(), m_impl->align(_key.size())};
  // walk the tree to find the right place to insert
  const auto [host_node, host_prefix_length] = m_impl->tree_walk(key);
  // insert and create a new node if the remaining part of the key is not empty
  if (host_prefix_length != key.size()) {
    m_impl->create_device_node(
        host_node,
        {key.begin() + host_prefix_length, key.end()},
        value.slice(/*dim=*/0, host_prefix_length, key.size()));
  }
  // add the hit count for the device node
  walk_to_root(host_node, [&](TreeNode* n) { n->hit_count++; });
  std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>> result;
  // don't write through if hicache is disabled (no host memory), fast path
  if (!m_impl->use_hicache) return {std::move(result), host_prefix_length};
  throw std::runtime_error("Not implemented yet");
}

std::tuple<IOTicket, std::vector<at::Tensor>> RadixTree::loading_onboard(NodeHandle, at::Tensor) {
  if (m_impl->disabled) return {};
  throw std::runtime_error("Not implemented yet");
}

void RadixTree::commit_writing_through(IOTicket, bool) {
  if (m_impl->disabled) return;
  throw std::runtime_error("Not implemented yet");
}

void RadixTree::commit_loading_onboard(IOTicket, bool) {
  if (m_impl->disabled) return;
  throw std::runtime_error("Not implemented yet");
}

void RadixTree::reset() {
  m_impl->reset();
}

void RadixTree::lock_ref(NodeHandle node_id, bool increment) {
  if (m_impl->disabled) return;
  m_impl->lock_ref(node_id, increment);
}

std::size_t RadixTree::evictable_size() const {
  return m_impl->evictable_size();
}

std::size_t RadixTree::protected_size() const {
  return m_impl->protected_size();
}

std::size_t RadixTree::total_size() const {
  return m_impl->total_size();
}

}  // namespace radix_tree_v2
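The eviction loop above is a classic LRU over leaves: a heap ordered by last-access time pops the stalest free leaf, frees its device indices, and re-inserts the parent if it has just become a free leaf itself. A compact Python sketch of the same policy over an illustrative node record (attributes time, length, ref_count, busy_io, indices, parent, children are assumptions of this sketch, not the C++ structures):

    import heapq


    def evict_lru(leaves, num_tokens):
        # Python's heapq is a min-heap, so no comparator reversal is needed
        # (the C++ code reverses `cmp` because std::priority_queue is a max-heap).
        heap = [(node.time, id(node), node) for node in leaves]
        heapq.heapify(heap)
        evicted, freed = [], 0
        while freed < num_tokens and heap:
            _, _, node = heapq.heappop(heap)
            if node.ref_count != 0 or node.busy_io:  # mirrors the is_io_free() check
                continue
            evicted.append(node.indices)
            freed += node.length
            parent = node.parent
            parent.children.remove(node)
            if not parent.children and parent.ref_count == 0:
                heapq.heappush(heap, (parent.time, id(parent), parent))
        return evicted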
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.h  (new file, mode 100644, +59)

#pragma once
#include <ATen/core/TensorBody.h>
#include <c10/core/Device.h>

#include <cstddef>
#include <memory>
#include <optional>
#include <tuple>
#include <vector>

#include "common.h"

namespace radix_tree_v2 {

struct RadixTree {
 public:
  RadixTree(bool disabled, std::optional<std::size_t> host_size, std::size_t page_size, std::size_t threshold);
  ~RadixTree();

  // Trees should not be copied or moved, as they manage their own memory and state.
  RadixTree(const RadixTree&) = delete;
  RadixTree(RadixTree&&) = delete;
  RadixTree& operator=(const RadixTree&) = delete;
  RadixTree& operator=(RadixTree&&) = delete;

  /// @return (device indices that are matched, host indices length, device node, host node)
  std::tuple<std::vector<at::Tensor>, std::size_t, NodeHandle, NodeHandle> match_prefix(const token_vec_t& key);

  /// @return Device indices that need to be evicted (on the Python side).
  std::vector<at::Tensor> evict(std::size_t num_tokens);

  /// @brief (Un-)Lock a node.
  void lock_ref(NodeHandle node_id, bool increment /* increment or decrement */);

  /// @brief Update a new key-value pair and try to perform write-through.
  std::tuple<std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>>, std::size_t>
  writing_through(const token_vec_t& key, at::Tensor value);

  /// @brief Load to device from host within a range of nodes.
  std::tuple<IOTicket, std::vector<at::Tensor>> loading_onboard(NodeHandle host_id, at::Tensor indices);

  /// @brief Commit a transaction of write-through.
  void commit_writing_through(IOTicket ticket, bool success);

  /// @brief Commit a transaction of load onboard.
  void commit_loading_onboard(IOTicket ticket, bool success);

  /// @brief Clear and reset the tree.
  void reset();

  /// @return How much size is still evictable (on device + not locked).
  std::size_t evictable_size() const;

  /// @return How much size is protected (locked).
  std::size_t protected_size() const;

  /// @return How much size is used on the device.
  std::size_t total_size() const;

  /// @brief Print debug information of the tree.
  void debug_print() const;

 private:
  struct Impl;
  std::unique_ptr<Impl> m_impl;
};

}  // namespace radix_tree_v2
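The header documents a two-phase IO protocol: writing_through and loading_onboard hand back IOTickets, the caller performs the actual copies, and the matching commit_* call settles each ticket with a success flag. Both commit paths still throw "Not implemented yet" in this commit, so the following Python sketch only illustrates the intended call sequence; copy_device_to_host is a hypothetical helper, not part of the diff:

    def flush_write_through(tree, key, device_indices, copy_device_to_host):
        # Phase 1: insert into the tree; it returns tickets for the segments
        # that should be backed up from device to host.
        pending, _matched = tree.writing_through(key, device_indices)
        for ticket, dev_idx, host_idx in pending:
            # Phase 2: perform the copy outside the tree (hypothetical helper),
            # then settle the ticket so the node's IO protection is released.
            ok = copy_device_to_host(dev_idx, host_idx)
            tree.commit_writing_through(ticket, ok)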
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_binding.cpp  (new file, mode 100644, +32)

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include <cstddef>
#include <optional>

#include "tree_v2.h"

PYBIND11_MODULE(radix_tree_cpp, m) {
  using namespace radix_tree_v2;
  namespace py = pybind11;
  py::class_<RadixTree>(m, "RadixTree")
      .def(
          py::init<bool, std::optional<std::size_t>, std::size_t, std::size_t>(),
          py::arg("disabled"),
          py::arg("host_size"),
          py::arg("page_size"),
          py::arg("write_through_threshold"))
      .def("match_prefix", &RadixTree::match_prefix)
      .def("evict", &RadixTree::evict)
      .def("lock_ref", &RadixTree::lock_ref)
      .def("evictable_size", &RadixTree::evictable_size)
      .def("protected_size", &RadixTree::protected_size)
      .def("total_size", &RadixTree::total_size)
      .def("writing_through", &RadixTree::writing_through)
      .def("loading_onboard", &RadixTree::loading_onboard)
      .def("commit_writing_through", &RadixTree::commit_writing_through)
      .def("commit_loading_onboard", &RadixTree::commit_loading_onboard)
      .def("reset", &RadixTree::reset)
      .def("debug_print", &RadixTree::debug_print);
}
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_debug.cpp  (new file, mode 100644, +194)

#include <c10/core/DeviceType.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>

#include "tree_v2.h"
#include "tree_v2_impl.h"

namespace radix_tree_v2 {

void RadixTree::debug_print() const {
  m_impl->debug_print(std::clog);
}

static constexpr auto npos = std::size_t(-1);

void RadixTree::Impl::debug_print(std::ostream& os) const {
  static constexpr auto _check = [](bool condition, auto msg, std::size_t id = npos) {
    if (!condition) {
      std::string suffix = id == npos ? "" : " [id = " + std::to_string(id) + "]";
      throw std::runtime_error(std::string("RadixTree::debug_print failed: ") + msg + suffix);
    }
  };

  static constexpr auto _print_node = [](TreeNode* node, std::size_t depth, std::ostream& os) {
    const auto length = node->length();
    os << node->node_id << " [depth = " << depth << "] [len = " << length << "]";
    // placement status
    if (node->on_both()) {
      os << " [cpu + gpu]";
    } else if (node->on_gpu()) {
      os << " [gpu]";
    } else if (node->on_cpu()) {
      os << " [cpu]";
    } else {
      _check(false, "Node is not on GPU or CPU", node->node_id);
    }
    // IO status
    if (node->is_io_free()) {
      os << " [io = free]";
    } else if (node->is_io_device_to_host()) {
      os << " [io = gpu -> cpu]";
    } else if (node->is_io_host_to_device()) {
      os << " [io = cpu -> gpu]";
    } else {
      _check(false, "Node is in unknown IO state", node->node_id);
    }
    os << " [rc = " << node->ref_count << "]";
    os << " [hit = " << node->hit_count << "]";
  };

  static constexpr auto _print_indices = [](at::Tensor indices, std::ostream& os) {
    if (!indices.defined()) {
      os << "[[N/A]]";
      return indices;
    }
    indices = indices.to(c10::kCPU, c10::kLong, false, false, c10::MemoryFormat::Contiguous);
    const auto length = indices.numel();
    os << "[";
    auto* data_ptr = indices.data_ptr<int64_t>();
    for (const auto i : c10::irange(indices.size(0))) {
      os << data_ptr[i];
      if (i != length - 1) os << ", ";
    }
    os << "]";
    return indices;
  };

  os << "Evictable size: " << evictable_size() << std::endl;
  os << "Protected size: " << protected_size() << std::endl;
  os << "Total size: " << const_cast<Impl*>(this)->total_size() << std::endl;

  std::vector<std::tuple<TreeNode*, TreeNode*, token_slice>> stack;
  auto root = const_cast<TreeNode*>(&m_root);
  os << root->node_id << " [root]" << std::endl;
  for (const auto& [key, child] : *root) {
    stack.push_back({child.get(), root, key});
  }

  std::unordered_map<TreeNode*, std::size_t> depth_map;
  std::string indent_buffer;
  depth_map[root] = 0;
  std::vector<NodeHandle> visited_id;
  std::size_t evictable_size_real = 0;

  while (!stack.empty()) {
    const auto [node, parent, key] = stack.back();
    stack.pop_back();
    visited_id.push_back(node->node_id);
    const auto nid = node->node_id;
    _check(node != nullptr, "Node is null", nid);
    _check(node->on_gpu() || node->on_cpu(), "Node is not on GPU or CPU", nid);
    _check(node->parent() == parent, "Parent is not correct", nid);
    _check(key.size() == page_size && node->diff_key(key, 0) == page_size, "Key is not correct", nid);
    _check(depth_map.count(node) == 0, "Node is visited twice", nid);
    _check(m_node_map.count(nid) == 1, "Node is not in the map", nid);
    _check(m_node_map.at(nid) == node, "Node in the map is not the same as the one in the stack", nid);
    _check(!node->on_gpu() || parent->is_root() || parent->on_gpu(), "Node on GPU must have a GPU/root parent", nid);
    if (!node->is_io_free()) {
      _check(node->ref_count > 0, "Node is in IO state but not protected", nid);
      _check(node->on_both(), "Node in IO state must be on both CPU and GPU", nid);
    }
    if (node->on_gpu() && node->ref_count == 0) {
      evictable_size_real += node->length();
    }
    const auto depth = (depth_map[node] = depth_map[parent] + 1);
    indent_buffer.resize(depth * 2, ' ');
    os << indent_buffer;
    _print_node(node, depth, os);
    os << std::endl;
    for (const auto& [key, child] : *node) {
      stack.push_back({child.get(), node, key});
    }
  }

  _check(evictable_size_real == evictable_size(), "Evictable size is wrong");
  _check(m_node_map.count(root->node_id) == 1, "Root node is not in the map");
  _check(m_node_map.at(root->node_id) == root, "Root node in the map is not correct");

  std::sort(visited_id.begin(), visited_id.end());
  if (visited_id.size() != m_node_map.size() - 1) {
    // Some error in the tree, not all nodes are visited
    std::string id_list;
    id_list += "(visited: ";
    id_list += std::to_string(root->node_id) + " ";
    for (const auto& id : visited_id) {
      id_list += std::to_string(id) + " ";
    }
    id_list += "), (in map: ";
    for (const auto& [id, _] : m_node_map) {
      id_list += std::to_string(id) + " ";
    }
    id_list += ")";
    _check(false, "Not all nodes are visited " + id_list);
  }

  static const auto kSGLANG_RADIX_CPP_DEBUG_LIMIT = [] {
    const char* env = std::getenv("SGLANG_RADIX_CPP_DEBUG_LIMIT");
    const std::size_t default_limit = 16;
    if (env != nullptr) {
      try {
        return static_cast<std::size_t>(std::stoull(env));
      } catch (const std::exception& e) {
        std::cerr << "Invalid SGLANG_RADIX_CPP_DEBUG_LIMIT value: " << env  //
                  << ". Using default value = " << default_limit << std::endl;
      }
    }
    return default_limit;
  }();

  for (const auto nid : visited_id) {
    const auto node = m_node_map.at(nid);
    // print key and indices
    const auto& key = node->_unsafe_tokens();
    if (key.size() > kSGLANG_RADIX_CPP_DEBUG_LIMIT) {
      os << "Node " << nid << ": key is too long (" << key.size() << " tokens), skipping..." << std::endl;
      continue;
    }
    os << "Node " << nid << ": key = [";
    for (const auto& i : c10::irange(key.size())) {
      os << key[i];
      if (i != key.size() - 1) os << ", ";
    }
    _check(key.size() % page_size == 0, "Misaligned key", nid);
    os << "] device_indices = ";
    const auto device_indices = _print_indices(node->device_indices(), os);
    if (device_indices.defined()) {
      std::size_t length = device_indices.numel();
      _check(device_indices.dim() == 1, "Device indices must be 1D tensor", nid);
      _check(length == node->length(), "Wrong device indices size", nid);
    }
    os << " host_indices = ";
    const auto host_indices = _print_indices(node->host_indices(), os);
    if (host_indices.defined()) {
      std::size_t length = host_indices.numel();
      _check(host_indices.dim() == 1, "Host indices must be 1D tensor", nid);
      _check(length == node->length(), "Wrong host indices size", nid);
    }
    os << std::endl;
  }
}

}  // namespace radix_tree_v2
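debug_print truncates per-node key dumps using the SGLANG_RADIX_CPP_DEBUG_LIMIT environment variable (default 16 tokens, parsed once and cached in a function-local static). A small sketch of raising the limit before inspecting a tree; `tree` is assumed to be a RadixTreeCpp instance:

    import os

    # Must be set before the first debug_print() call in this process,
    # since the C++ side caches the parsed value on first use.
    os.environ["SGLANG_RADIX_CPP_DEBUG_LIMIT"] = "64"
    tree.debug_print()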
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_impl.h  (new file, mode 100644, +276)

#pragma once
#include <c10/util/irange.h>

#include <chrono>
#include <cstddef>
#include <iosfwd>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "common.h"
#include "tree_v2.h"
#include "tree_v2_node.h"

namespace radix_tree_v2 {

using node_iterator_t = typename TreeNode::iterator_t;

struct RadixTree::Impl {
 public:
  Impl(bool disabled, bool use_hicache, std::size_t page_size, std::size_t host_size, std::size_t threshold)
      : m_root(/*node_id_=*/0),
        m_evictable_size(0),
        m_protected_size(0),
        m_cached_vec(),
        m_node_map(),
        m_node_counter(1),  // start from 1 to avoid confusion with root node
        disabled(disabled),
        use_hicache(use_hicache),
        page_size(page_size),
        threshold(threshold) {
    _assert(page_size > 0, "Page size must be greater than zero");
    _assert(use_hicache == (host_size > 0), "Hierarchical cache is enabled iff host size > 0");
    m_root.ref_count = 1;                  // root node is always protected
    m_cached_vec.reserve(page_size);       // to avoid repeated allocations
    m_node_map[m_root.node_id] = &m_root;  // add root to the map
  }

  TreeNode* split_node(node_iterator_t iterator, std::size_t prefix_length) {
    // from `parent -> old_node` to `parent -> new_node -> old_node`
    // the prefix part of the old node is moved to the new node
    auto old_node_ptr = std::move(iterator->second);
    auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++);
    auto* old_node = old_node_ptr.get();
    auto* new_node = new_node_ptr.get();
    auto* parent = old_node->parent();
    // set up data structures
    split_prefix(new_node, old_node, prefix_length);
    // set up parent-child relationship
    add_child(new_node, std::move(old_node_ptr));
    add_child(parent, std::move(new_node_ptr), iterator);
    m_node_map[new_node->node_id] = new_node;  // add to the map
    return new_node;
  }

  // node: x -> [GPU]
  TreeNode* create_device_node(TreeNode* parent, token_vec_t vec, at::Tensor indices) {
    auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++);
    auto new_node = new_node_ptr.get();
    new_node_ptr->_unsafe_tokens() = std::move(vec);
    new_node_ptr->_unsafe_device_indices() = std::move(indices);
    m_evictable_size += new_node_ptr->length();
    add_child(parent, std::move(new_node_ptr));
    m_node_map[new_node->node_id] = new_node;  // add to the map
    return new_node;
  }

  // node: [GPU] -> x
  void remove_device_node(TreeNode* node) {
    _assert(node->on_gpu_only() && node->ref_count == 0);
    m_evictable_size -= node->length();
    node->parent()->erase_child(get_key(node));
    m_node_map.erase(node->node_id);  // remove from the map
  }

  /**
   * @brief Walk the tree to find the node that matches the key.
   * If the key partially matches a node, it will split that node.
   * @return A pair containing the last node that matches the key and
   * the total prefix length matched (on gpu and cpu) so far.
   */
  std::pair<TreeNode*, std::size_t> tree_walk(token_slice key) {
    _assert(key.size() % page_size == 0, "Key should be page-aligned");
    std::size_t total_prefix_length = 0;
    TreeNode* node = &m_root;
    const auto now = std::chrono::steady_clock::now();
    while (key.size() > 0) {
      const auto iterator = node->find_child(get_key(key));
      if (iterator == node->end()) break;
      // walk to the child node
      node = iterator->second.get();
      // at least `page_size` tokens are matched, and there may be more tokens to match
      // the return value prefix_length is no less than `page_size`
      const auto prefix_length = align(node->diff_key(key, page_size) + page_size);
      total_prefix_length += prefix_length;
      // split the node if the prefix is not the whole token vector
      if (prefix_length < node->length()) {
        return {split_node(iterator, prefix_length), total_prefix_length};
      }
      // we have matched the whole key, continue to the next node
      node->access(now);
      key = key.subspan(prefix_length);
    }
    return {node, total_prefix_length};
  }

  std::vector<TreeNode*> collect_leaves() const {
    std::vector<TreeNode*> leaves;
    std::vector<TreeNode*> stack = {};
    for (const auto& [_, child] : m_root) {
      stack.push_back(child.get());
    }
    while (!stack.empty()) {
      const auto node = stack.back();
      stack.pop_back();
      if (node->is_leaf()) {
        if (node->ref_count == 0) {
          leaves.push_back(node);
        }
      } else {
        for (const auto& [_, child] : *node) {
          stack.push_back(child.get());
        }
      }
    }
    return leaves;
  }

  std::vector<TreeNode*> collect_leaves_device() const {
    // for non-hicache, every leaf device node is a leaf node (since no backup on host)
    if (!use_hicache) return collect_leaves();
    std::vector<TreeNode*> leaves;
    std::vector<TreeNode*> stack = {};
    for (const auto& [_, child] : m_root) {
      stack.push_back(child.get());
    }
    while (!stack.empty()) {
      const auto node = stack.back();
      stack.pop_back();
      if (!node->on_gpu()) continue;  // skip nodes that are not on GPU
      if (node->is_leaf_device()) {
        if (node->ref_count == 0) {
          leaves.push_back(node);
        }
      } else {
        for (const auto& [_, child] : *node) {
          stack.push_back(child.get());
        }
      }
    }
    return leaves;
  }

  void lock_ref(TreeNode* node, bool increment) {
    if (node->is_root()) return;  // skip root node
    _assert(node->on_gpu(), "Cannot lock reference on an evicted node");
    if (increment)
      walk_to_root(node, [this](TreeNode* n) {
        if (n->ref_count == 0) {
          m_evictable_size -= n->length();
          m_protected_size += n->length();
        }
        n->ref_count++;
      });
    else
      walk_to_root(node, [this](TreeNode* n) {
        _assert(n->ref_count != 0, "Cannot decrement reference count = zero");
        n->ref_count--;
        if (n->ref_count == 0) {
          m_protected_size -= n->length();
          m_evictable_size += n->length();
        }
      });
  }

  void lock_ref(NodeHandle node_ptr, bool increment) {
    return lock_ref(id2node(node_ptr), increment);
  }

  void lock(TreeNode* node) {
    return lock_ref(node, /*increment=*/true);
  }

  void unlock(TreeNode* node) {
    return lock_ref(node, /*increment=*/false);
  }

  std::size_t total_size() const {
    std::size_t size = 0;
    std::vector<const TreeNode*> stack = {&m_root};
    while (!stack.empty()) {
      auto* node = stack.back();
      stack.pop_back();
      size += node->length();
      for (const auto& [_, child] : *node)
        stack.push_back(child.get());
    }
    return size;
  }

  std::size_t evictable_size() const {
    return m_evictable_size;
  }

  std::size_t protected_size() const {
    return m_protected_size;
  }

  std::size_t align(std::size_t size) const {
    return (size / page_size) * page_size;  // align to page size
  }

  TreeNode* id2node(NodeHandle node_id) const {
    const auto iterator = m_node_map.find(node_id);
    _assert(iterator != m_node_map.end(), "Node not found in the map");
    return iterator->second;
  }

  void reset() {
    _assert(m_root.ref_count == 1, "Root node must be protected during reset");
    m_node_counter = 1;  // reset node counter
    m_root.root_reset();
    m_evictable_size = 0;
    m_protected_size = 0;
    m_node_map.clear();
    m_node_map[m_root.node_id] = &m_root;  // re-add root to the map
  }

  void debug_print(std::ostream& os) const;

 private:
  // some auxiliary functions
  token_vec_t& get_key(token_slice tokens) {
    _assert(tokens.size() >= page_size, "Key should be at least page-sized");
    tokens = tokens.subspan(0, page_size);
    m_cached_vec.assign(tokens.begin(), tokens.end());
    return m_cached_vec;
  }

  // justification for the _unsafe call: we need to read the key part of the tokens
  token_vec_t& get_key(TreeNode* node) {
    return get_key(node->_unsafe_tokens());
  }

  void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child) {
    parent->add_child(get_key(child.get()), std::move(child));
  }

  void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child, node_iterator_t it) {
    parent->add_child(it, std::move(child));
  }

  TreeNode m_root;               // root node of the tree
  std::size_t m_evictable_size;  // number of evictable tokens on GPU (lock ref = 0)
  std::size_t m_protected_size;  // number of protected tokens on GPU (lock ref > 0)
  token_vec_t m_cached_vec;      // cached vector of tokens for the current operation
  std::unordered_map<std::size_t, TreeNode*> m_node_map;  // map of node keys to nodes
  std::size_t m_node_counter;    // counter for node IDs

 public:
  // some public constant configurations (without m_ prefix)
  const bool disabled;          // whether the cache is enabled, or just a temporary cache
  const bool use_hicache;       // whether to use the HiCache for this tree
  const std::size_t page_size;  // size of each page in the cache
  const std::size_t threshold;  // threshold for write_through
};

}  // namespace radix_tree_v2
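Impl::tree_walk is the core lookup: each step uses the first page of the remaining key to find a child, extends the match token-by-token via diff_key, rounds the match down to a page boundary, and splits the child when the match ends mid-node. A simplified, runnable Python sketch of the same walk for page_size = 1, over a dict-based node of the form {"tokens": [...], "children": {first_token: node}} (this data model is an illustration, not the C++ structures):

    def tree_walk(node, key):
        """Walk a token-keyed radix tree (page_size == 1 sketch of Impl::tree_walk)."""
        matched = 0
        while matched < len(key):
            first = key[matched]
            child = node["children"].get(first)
            if child is None:
                break
            label = child["tokens"]
            n = 0  # diff_key: common-prefix length of label and the remaining key
            while n < len(label) and matched + n < len(key) and label[n] == key[matched + n]:
                n += 1
            matched += n
            if n < len(label):
                # Partial match: split `child` into prefix -> suffix, like split_node().
                prefix = {"tokens": label[:n], "children": {label[n]: child}}
                child["tokens"] = label[n:]
                node["children"][first] = prefix
                return prefix, matched
            node = child  # full match of this node's label; keep walking
        return node, matched

    # Example: a root with one child labeled [1, 2, 3]; matching [1, 2, 9]
    # splits that child after two tokens and returns the new prefix node.
    root = {"tokens": [], "children": {1: {"tokens": [1, 2, 3], "children": {}}}}
    node, matched = tree_walk(root, [1, 2, 9])
    assert matched == 2 and node["tokens"] == [1, 2]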
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_node.h  (new file, mode 100644, +257)

#pragma once
#include <ATen/core/TensorBody.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <ranges>
#include <unordered_map>

#include "common.h"

namespace radix_tree_v2 {

struct std_vector_hash {
  // see https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
  std::size_t operator()(const token_vec_t& vec) const {
    std::size_t hash = 0;
    for (const auto& token : vec) {
      hash ^= token + 0x9e3779b9 + (hash << 6) + (hash >> 2);
    }
    return hash;
  }
};

struct TreeNode {
 public:
  using childern_map_t = std::unordered_map<token_vec_t, std::unique_ptr<TreeNode>, std_vector_hash>;
  using iterator_t = typename childern_map_t::iterator;
  using const_iterator_t = typename childern_map_t::const_iterator;
  using timestamp_t = std::chrono::steady_clock::time_point;

  TreeNode(std::size_t node_id_)
      : ref_count(0),
        hit_count(0),
        m_io_locked(std::nullopt),
        m_io_status(IOStatus::None),
        m_io_ticket(),
        m_tokens(),
        m_device_indices(),
        m_host_indices(),
        m_parent(),
        m_children(),
        m_last_access_time(std::chrono::steady_clock::now()),
        node_id(node_id_) {}

  void access(timestamp_t time = std::chrono::steady_clock::now()) {
    m_last_access_time = time;
  }

  bool is_root() const {
    return m_parent == nullptr;
  }

  timestamp_t time() const {
    return m_last_access_time;
  }

  bool on_gpu() const {
    return m_device_indices.defined();
  }

  bool on_cpu() const {
    return m_host_indices.defined();
  }

  bool on_gpu_only() const {
    return on_gpu() && !on_cpu();
  }

  bool on_cpu_only() const {
    return !on_gpu() && on_cpu();
  }

  bool on_both() const {
    return on_gpu() && on_cpu();
  }

  std::size_t length() const {
    return m_tokens.size();
  }

  bool is_leaf() const {
    return m_children.empty();
  }

  bool is_leaf_device() const {
    for (const auto& [_, child] : m_children)
      if (child->on_gpu()) return false;  // at least one child is on the device
    return true;
  }

  void add_child(const token_vec_t& v, std::unique_ptr<TreeNode>&& child) {
    child->m_parent = this;
    m_children[v] = std::move(child);
  }

  void add_child(iterator_t it, std::unique_ptr<TreeNode>&& child) {
    child->m_parent = this;
    it->second = std::move(child);
  }

  void erase_child(const token_vec_t& v) {
    _assert(m_children.erase(v) > 0, "Child node not found");
  }

  iterator_t find_child(const token_vec_t& v) {
    return m_children.find(v);
  }

  iterator_t begin() {
    return m_children.begin();
  }

  iterator_t end() {
    return m_children.end();
  }

  const_iterator_t begin() const {
    return m_children.begin();
  }

  const_iterator_t end() const {
    return m_children.end();
  }

  TreeNode* parent() {
    return m_parent;
  }

  // set up all data structures except for parent-child relationship
  friend void split_prefix(TreeNode* new_node, TreeNode* old_node, std::size_t prefix_length) {
    auto tokens = std::move(old_node->m_tokens);
    _assert(0 < prefix_length && prefix_length < tokens.size(), "Invalid prefix size for split");
    // set up tokens
    old_node->m_tokens = token_vec_t(tokens.begin() + prefix_length, tokens.end());
    new_node->m_tokens = std::move(tokens);
    new_node->m_tokens.resize(prefix_length);
    // set up values
    const int64_t new_size = new_node->length();
    const int64_t old_size = old_node->length();
    if (old_node->m_device_indices.defined()) {
      auto new_indices = old_node->m_device_indices.split_with_sizes({new_size, old_size});
      new_node->m_device_indices = std::move(new_indices[0]);
      old_node->m_device_indices = std::move(new_indices[1]);
    }
    if (old_node->m_host_indices.defined()) {
      auto new_indices = old_node->m_host_indices.split_with_sizes({new_size, old_size});
      new_node->m_host_indices = std::move(new_indices[0]);
      old_node->m_host_indices = std::move(new_indices[1]);
    }
    // set up ref counts and hit counts
    new_node->ref_count = old_node->ref_count;
    new_node->hit_count = old_node->hit_count;
    // If the old node (child) was locked for IO, the new node (parent) does not need
    // to be locked, since it is naturally protected by the child node's lock.
    if (old_node->m_io_locked.has_value()) {
      new_node->m_io_locked = false;
      new_node->m_io_status = old_node->m_io_status;
      new_node->m_io_ticket = old_node->m_io_ticket;
    }
  }

  /// @return The first index in `m_tokens` that differs from `key`.
  std::size_t diff_key(token_slice key, std::size_t offset) const {
    const auto a = token_slice{key}.subspan(offset);
    const auto b = token_slice{m_tokens}.subspan(offset);
    const auto [it_a, it_b] = std::ranges::mismatch(a, b);
    return it_a - a.begin();  // return the index of the first differing token
  }

  at::Tensor device_indices() const {
    return m_device_indices;
  }

  at::Tensor host_indices() const {
    return m_host_indices;
  }

  // visiting tokens is always unsafe (use `diff_key` instead)
  token_vec_t& _unsafe_tokens() {
    return m_tokens;
  }

  at::Tensor& _unsafe_device_indices() {
    return m_device_indices;
  }

  at::Tensor& _unsafe_host_indices() {
    return m_host_indices;
  }

  bool is_io_free() const {
    return m_io_status == IOStatus::None;
  }

  bool is_io_device_to_host() const {
    return m_io_status == IOStatus::DeviceToHost;
  }

  bool is_io_host_to_device() const {
    return m_io_status == IOStatus::HostToDevice;
  }

  void root_reset() {
    _assert(is_root(), "Only root node can call root_reset");
    _assert(
        m_io_status == IOStatus::None && m_io_locked == std::nullopt,
        "IO operation in progress, cannot reset root node");
    _assert(this->m_tokens.empty(), "Root node tokens should be empty on reset");
    _assert(
        !this->m_device_indices.defined() && !this->m_host_indices.defined(),
        "Root node indices should always be empty and never assigned");
    m_children.clear();
    this->access();
  }

 public:
  std::size_t ref_count;
  std::size_t hit_count;

 private:
  enum class IOStatus : std::uint8_t {
    None,
    HostToDevice,
    DeviceToHost,
  };
  std::optional<bool> m_io_locked;  // whether the node is locked in IO operation
  IOStatus m_io_status;
  IOTicket m_io_ticket;
  token_vec_t m_tokens;
  at::Tensor m_device_indices;  // indices of device value
  at::Tensor m_host_indices;    // indices of host value
  TreeNode* m_parent;
  childern_map_t m_children;
  timestamp_t m_last_access_time;

 public:
  const std::size_t node_id;  // unique ID for the node
};

template <typename F>
inline TreeNode* walk_to_root(TreeNode* t, const F& f) {
  while (!t->is_root()) {
    f(t);
    t = t->parent();
  }
  return t;  // return the root node
}

}  // namespace radix_tree_v2
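Child edges are keyed by the first page of tokens, so the child map needs a hash for std::vector<int32_t>; the commit uses the boost-style hash_combine fold shown above. An equivalent Python rendering of that fold, handy for checking distributions or collisions offline (the 64-bit mask is added explicitly because Python ints are unbounded, whereas std::size_t wraps):

    MASK = (1 << 64) - 1  # emulate 64-bit std::size_t wraparound


    def token_vec_hash(vec):
        h = 0
        for token in vec:
            # hash ^= token + 0x9e3779b9 + (hash << 6) + (hash >> 2)
            h ^= (token + 0x9E3779B9 + ((h << 6) & MASK) + (h >> 2)) & MASK
        return h & MASK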
python/sglang/srt/mem_cache/radix_cache_cpp.py  (new file, mode 100644, +229)

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List, Set

import torch

from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
    IOHandle,
    RadixTreeCpp,
    TreeNodeCpp,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req

logger = logging.getLogger(__name__)


class RadixCacheCpp(BasePrefixCache):
    def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor:
        """
        Merge a list of tensors into a single tensor.
        Args:
            l (List[torch.Tensor]): List of tensors to merge.
        Returns:
            torch.Tensor: Merged tensor.
        """
        if len(l) == 0:
            return torch.empty(0, dtype=torch.int64, device=self.device)
        elif len(l) == 1:
            return l[0]
        else:
            return torch.cat(l)

    def __init__(
        self,
        disable: bool,
        use_hicache: bool,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
        hicache_size: int,
        hicache_write_policy: str,
        enable_kv_cache_events: bool = False,
        hicache_oracle: bool = False,
        enable_write_cancel: bool = False,
    ):
        self.disable = disable
        self.enable_write_cancel = enable_write_cancel

        assert (
            enable_kv_cache_events is False
        ), "HiRadixCache does not support kv cache events yet"
        self.kv_cache = token_to_kv_pool.get_kvcache()

        # record the nodes with ongoing write through
        self.ongoing_write_through: Set[IOHandle] = set()
        # record the node segments with ongoing load back
        self.ongoing_load_back: Set[IOHandle] = set()
        # todo: dynamically adjust the threshold
        self.write_through_threshold = (
            1 if hicache_write_policy == "write_through" else 2
        )
        self.device = token_to_kv_pool.device
        self.token_to_kv_pool = token_to_kv_pool
        self.req_to_token_pool = req_to_token_pool
        self.page_size = page_size
        self.tp_group = tp_cache_group

        if not use_hicache:
            self.tree = RadixTreeCpp(
                disabled=self.disable,
                page_size=page_size,
                host_size=None,  # no host cache, this should be removed in the future
                write_through_threshold=self.write_through_threshold,
            )
            self.cache_controller = None
            return  # early return if hicache is not used

        raise NotImplementedError("Host cache is not supported yet")

    def reset(self):
        if self.cache_controller is not None:
            # need to clear the acks before resetting the cache controller
            raise NotImplementedError("Host cache is not supported yet")
        self.tree.reset()

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        device_indices_vec, host_indices_length, node_gpu, node_cpu = (
            self.tree.match_prefix(key)
        )
        return MatchResult(
            device_indices=self._merge_tensor(device_indices_vec),
            last_device_node=node_gpu,
            last_host_node=node_cpu,
            host_hit_length=host_indices_length,
        )

    def _insert(self, key: List[int], value: torch.Tensor) -> int:
        """
        Insert a key-value pair into the radix tree.
        Args:
            key (List[int]): The key to insert, represented as a list of integers.
            value (torch.Tensor): The value to associate with the key.
        Returns:
            int: Number of device indices that were already present in the tree before the insertion.
        """
        ongoing_write, length = self.tree.writing_through(key, value)
        if self.cache_controller is None:
            assert len(ongoing_write) == 0, "Implementation error"
            return length
        raise NotImplementedError("Host cache is not supported yet")

    def dec_lock_ref(self, node: TreeNodeCpp):
        """
        Decrement the reference count from a node up to the root of the radix tree.
        Args:
            node (TreeNodeCpp): The handle of the node to decrement the reference count for.
        """
        self.tree.lock_ref(node, False)  # do not increment

    def inc_lock_ref(self, node: TreeNodeCpp):
        """
        Increment the reference count from a node up to the root of the radix tree.
        Args:
            node (TreeNodeCpp): The handle of the node to increment the reference count for.
        """
        self.tree.lock_ref(node, True)

    def evict(self, num_tokens: int):
        evicted_device_indices = self.tree.evict(num_tokens)
        for indice in evicted_device_indices:
            self.token_to_kv_pool.free(indice)

    def evictable_size(self):
        return self.tree.evictable_size()

    def protected_size(self):
        return self.tree.protected_size()

    def total_size(self):
        return self.tree.total_size()

    def cache_finished_req(self, req: Req):
        """Cache request when it finishes."""
        assert req.req_pool_idx is not None
        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        overall_len = len(token_ids)  # prefill + decode
        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]

        # NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned;
        # it will automatically align them, but their lengths should be equal
        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
        new_prefix_len = self._insert(token_ids, kv_indices)

        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"

        # KVCache between old & new is newly generated, but already exists in the pool
        # we need to free these newly generated kv indices
        if old_prefix_len < new_prefix_len:
            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])

        # need to free the unaligned part, since it cannot be inserted into the radix tree
        if self.page_size != 1 and (  # unaligned tail only exists when page_size > 1
            (unaligned_len := overall_len % self.page_size) > 0
        ):
            # NOTE: sglang's PagedAllocator supports unaligned free (which will automatically align it)
            self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])

        # Remove the req slot and release the cache lock
        self.dec_lock_ref(req.last_node)
        self.req_to_token_pool.free(req.req_pool_idx)

    def cache_unfinished_req(self, req: Req):
        """Cache request when it is unfinished."""
        assert req.req_pool_idx is not None
        token_ids = req.fill_ids
        prefill_len = len(token_ids)  # prefill only (maybe chunked)
        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len]

        # NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned;
        # it will automatically align them, but their lengths should be equal
        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
        new_prefix_len = self._insert(token_ids, kv_indices)

        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"

        # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
        # The prefix indices need to be updated to reuse the kv indices in the pool
        new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
        new_indices = self._merge_tensor(new_indices_vec)
        assert new_prefix_len <= len(new_indices)

        # KVCache between old & new is newly generated, but already exists in the pool
        # we need to free these newly generated kv indices and reuse the indices in the pool
        if old_prefix_len < new_prefix_len:
            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
            reused_indices = new_indices[old_prefix_len:new_prefix_len]
            self.req_to_token_pool.req_to_token[
                req.req_pool_idx, old_prefix_len:new_prefix_len
            ] = reused_indices

        if req.last_node != new_last_node:
            self.dec_lock_ref(req.last_node)
            self.inc_lock_ref(new_last_node)

        # NOTE: there might be an unaligned tail, so we may need to append it
        assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size
        if self.page_size != 1 and len(new_indices) < prefill_len:
            req.prefix_indices = torch.cat(
                [new_indices, kv_indices[len(new_indices) :]]
            )
        else:
            req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
        return self.tree.debug_print()
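The prefix bookkeeping in cache_finished_req above is easiest to check with concrete numbers. A worked example, assuming page_size = 4 (all values are illustrative):

    page_size = 4
    overall_len = 11   # len(origin_input_ids + output_ids) - 1
    prev_prefix = 6    # len(req.prefix_indices) from the earlier match

    old_prefix_len = prev_prefix // page_size * page_size  # -> 4 (round down to a page)
    new_prefix_len = 8  # what _insert() reports as already stored (page-aligned)

    # kv_indices[4:8] duplicates pages already in the tree -> freed back to the pool.
    # The unaligned tail (11 % 4 == 3 tokens) can never enter the tree -> also freed.
    unaligned_len = overall_len % page_size  # -> 3, i.e. kv_indices[8:] is freed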
test/srt/test_cpp_radix_cache.py  (new file, mode 100644, +47)

import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestCppRadixCache(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        os.environ["SGLANG_EXPERIMENTAL_CPP_RADIX_TREE"] = "1"
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        print(metrics)
        self.assertGreaterEqual(metrics["score"], 0.65)


if __name__ == "__main__":
    unittest.main()