Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
05c9bc89
Unverified
Commit
05c9bc89
authored
Jun 22, 2025
by
Liangsheng Yin
Committed by
GitHub
Jun 22, 2025
Browse files
[minor] simplify the `TokenToKVPoolAllocator` (#7414)
parent
b7a2df0a
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
165 additions
and
149 deletions
+165
-149
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+3
-8
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+0
-1
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+5
-4
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+4
-3
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+7
-3
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+0
-1
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+3
-2
python/sglang/srt/mem_cache/allocator.py
python/sglang/srt/mem_cache/allocator.py
+125
-34
python/sglang/srt/mem_cache/chunk_cache.py
python/sglang/srt/mem_cache/chunk_cache.py
+4
-3
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+2
-2
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+0
-79
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+4
-4
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+6
-3
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+2
-2
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
05c9bc89
...
@@ -21,13 +21,11 @@ Life cycle of a request in the decode server
...
@@ -21,13 +21,11 @@ Life cycle of a request in the decode server
from
__future__
import
annotations
from
__future__
import
annotations
import
logging
import
logging
import
os
from
collections
import
deque
from
collections
import
deque
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
...
@@ -47,12 +45,9 @@ from sglang.srt.disaggregation.utils import (
...
@@ -47,12 +45,9 @@ from sglang.srt.disaggregation.utils import (
prepare_abort
,
prepare_abort
,
)
)
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
ScheduleBatch
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
(
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
ReqToTokenPool
KVCache
,
ReqToTokenPool
,
TokenToKVPoolAllocator
,
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
from
sglang.srt.torch_memory_saver_adapter
import
TorchMemorySaverAdapter
from
sglang.srt.torch_memory_saver_adapter
import
TorchMemorySaverAdapter
from
sglang.srt.utils
import
require_mlp_sync
from
sglang.srt.utils
import
require_mlp_sync
...
@@ -141,7 +136,7 @@ class DecodePreallocQueue:
...
@@ -141,7 +136,7 @@ class DecodePreallocQueue:
def
__init__
(
def
__init__
(
self
,
self
,
req_to_token_pool
:
ReqToTokenPool
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
draft_token_to_kv_pool
:
Optional
[
KVCache
],
draft_token_to_kv_pool
:
Optional
[
KVCache
],
req_to_metadata_buffer_idx_allocator
:
ReqToMetadataIdxAllocator
,
req_to_metadata_buffer_idx_allocator
:
ReqToMetadataIdxAllocator
,
metadata_buffers
:
MetadataBuffers
,
metadata_buffers
:
MetadataBuffers
,
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
05c9bc89
...
@@ -25,7 +25,6 @@ from collections import deque
...
@@ -25,7 +25,6 @@ from collections import deque
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
numpy
as
np
import
torch
import
torch
from
sglang.srt.disaggregation.base
import
BaseKVManager
,
KVPoll
from
sglang.srt.disaggregation.base
import
BaseKVManager
,
KVPoll
...
...
python/sglang/srt/managers/cache_controller.py
View file @
05c9bc89
...
@@ -18,12 +18,13 @@ import logging
...
@@ -18,12 +18,13 @@ import logging
import
math
import
math
import
threading
import
threading
from
queue
import
Empty
,
Full
,
PriorityQueue
,
Queue
from
queue
import
Empty
,
Full
,
PriorityQueue
,
Queue
from
typing
import
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
torch
import
torch
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
if
TYPE_CHECKING
:
from
sglang.srt.mem_cache.memory_pool_host
import
HostKVCache
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool_host
import
HostKVCache
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -163,7 +164,7 @@ class HiCacheController:
...
@@ -163,7 +164,7 @@ class HiCacheController:
def
__init__
(
def
__init__
(
self
,
self
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
mem_pool_host
:
HostKVCache
,
mem_pool_host
:
HostKVCache
,
page_size
:
int
,
page_size
:
int
,
load_cache_event
:
threading
.
Event
=
None
,
load_cache_event
:
threading
.
Event
=
None
,
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
05c9bc89
...
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
...
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
)
)
from
sglang.srt.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
sglang.srt.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
sglang.srt.layers.multimodal
import
gpu_tensor_hash
from
sglang.srt.layers.multimodal
import
gpu_tensor_hash
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.metrics.collector
import
TimeStats
from
sglang.srt.metrics.collector
import
TimeStats
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
...
@@ -810,7 +811,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -810,7 +811,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Request, memory pool, and cache
# Request, memory pool, and cache
reqs
:
List
[
Req
]
reqs
:
List
[
Req
]
req_to_token_pool
:
ReqToTokenPool
=
None
req_to_token_pool
:
ReqToTokenPool
=
None
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
=
None
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
=
None
tree_cache
:
BasePrefixCache
=
None
tree_cache
:
BasePrefixCache
=
None
# Batch configs
# Batch configs
...
@@ -907,7 +908,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -907,7 +908,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
cls
,
cls
,
reqs
:
List
[
Req
],
reqs
:
List
[
Req
],
req_to_token_pool
:
ReqToTokenPool
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
tree_cache
:
BasePrefixCache
,
tree_cache
:
BasePrefixCache
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
enable_overlap
:
bool
,
enable_overlap
:
bool
,
...
...
python/sglang/srt/managers/schedule_policy.py
View file @
05c9bc89
from
__future__
import
annotations
# Copyright 2023-2024 SGLang Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -18,15 +20,17 @@ import random
...
@@ -18,15 +20,17 @@ import random
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Union
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Union
import
torch
import
torch
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
if
TYPE_CHECKING
:
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
# This can prevent the server from being too conservative.
# Note that this only clips the estimation in the scheduler but does not change the stop
# Note that this only clips the estimation in the scheduler but does not change the stop
...
@@ -265,7 +269,7 @@ class PrefillAdder:
...
@@ -265,7 +269,7 @@ class PrefillAdder:
self
,
self
,
page_size
:
int
,
page_size
:
int
,
tree_cache
:
BasePrefixCache
,
tree_cache
:
BasePrefixCache
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
running_batch
:
ScheduleBatch
,
running_batch
:
ScheduleBatch
,
new_token_ratio
:
float
,
new_token_ratio
:
float
,
rem_input_tokens
:
int
,
rem_input_tokens
:
int
,
...
...
python/sglang/srt/managers/scheduler.py
View file @
05c9bc89
...
@@ -23,7 +23,6 @@ import time
...
@@ -23,7 +23,6 @@ import time
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
concurrent
import
futures
from
concurrent
import
futures
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
pathlib
import
Path
from
pathlib
import
Path
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
...
...
python/sglang/srt/managers/tp_worker.py
View file @
05c9bc89
...
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
...
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput
,
UpdateWeightsFromTensorReqInput
,
)
)
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
,
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
,
global_server_args_dict
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
...
@@ -57,7 +58,7 @@ class TpModelWorker:
...
@@ -57,7 +58,7 @@ class TpModelWorker:
nccl_port
:
int
,
nccl_port
:
int
,
is_draft_worker
:
bool
=
False
,
is_draft_worker
:
bool
=
False
,
req_to_token_pool
:
Optional
[
ReqToTokenPool
]
=
None
,
req_to_token_pool
:
Optional
[
ReqToTokenPool
]
=
None
,
token_to_kv_pool_allocator
:
Optional
[
TokenToKVPoolAllocator
]
=
None
,
token_to_kv_pool_allocator
:
Optional
[
Base
TokenToKVPoolAllocator
]
=
None
,
):
):
# Parse args
# Parse args
self
.
tp_size
=
server_args
.
tp_size
self
.
tp_size
=
server_args
.
tp_size
...
...
python/sglang/srt/mem_cache/
paged_
allocator.py
→
python/sglang/srt/mem_cache/allocator.py
View file @
05c9bc89
from
__future__
import
annotations
"""
"""
Copyright 2025 SGLang Team
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -17,13 +19,132 @@ limitations under the License.
...
@@ -17,13 +19,132 @@ limitations under the License.
Page-aligned memory pool.
Page-aligned memory pool.
"""
"""
import
abc
from
typing
import
TYPE_CHECKING
import
torch
import
torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
sglang.srt.mem_cache.memory_pool
import
KVCache
from
sglang.srt.utils
import
get_bool_env_var
,
next_power_of_2
from
sglang.srt.utils
import
get_bool_env_var
,
next_power_of_2
if
TYPE_CHECKING
:
from
sglang.srt.mem_cache.memory_pool
import
KVCache
class BaseTokenToKVPoolAllocator(abc.ABC):
    """Abstract base for allocators that hand out indices into a KV cache.

    The base class stores the shared configuration (size, page size, dtype,
    device, backing kv cache) and the free-group bookkeeping. Subclasses must
    implement ``__init__``, ``clear``, ``alloc`` and ``free``, and are expected
    to populate ``free_pages`` (it is only initialised to ``None`` here).
    """

    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        # Backing cache object; exposed through get_kvcache().
        self._kvcache = kvcache

        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device

        # Left unset here; subclasses assign the actual free list.
        self.free_pages = None

        # While the flag is False, subclasses may queue frees in free_group
        # instead of applying them right away (see free_group_begin/end).
        self.free_group = []
        self.is_not_in_free_group = True

    # ------------------------------------------------------------------
    # Interface every concrete allocator has to provide.
    # ------------------------------------------------------------------
    @abc.abstractmethod
    def clear(self):
        """Reset the allocator state. Must be overridden."""
        raise NotImplementedError()

    @abc.abstractmethod
    def alloc(self, need_size: int):
        """Allocate ``need_size`` entries. Must be overridden."""
        raise NotImplementedError()

    @abc.abstractmethod
    def free(self, free_index: torch.Tensor):
        """Release the entries in ``free_index``. Must be overridden."""
        raise NotImplementedError()

    # ------------------------------------------------------------------
    # Shared helpers.
    # ------------------------------------------------------------------
    def debug_print(self) -> str:
        """Return a debug description; the base implementation is empty."""
        return ""

    def available_size(self):
        """Return ``len(free_pages)`` scaled by ``page_size``."""
        return len(self.free_pages) * self.page_size

    def get_kvcache(self):
        """Return the kv cache object passed to the constructor."""
        return self._kvcache

    def backup_state(self):
        """Return the current ``free_pages`` object (no copy is made)."""
        return self.free_pages

    def restore_state(self, free_pages):
        """Replace ``free_pages`` with a value obtained from ``backup_state``."""
        self.free_pages = free_pages

    def free_group_begin(self):
        """Open a free group: clear the flag and start an empty group list."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Close the free group and free everything queued in one call."""
        # The flag is restored *before* calling free() so that the combined
        # free is not itself routed back into the group.
        self.is_not_in_free_group = True
        queued = self.free_group
        if queued:
            self.free(torch.cat(queued))

    # ------------------------------------------------------------------
    # Optional operations; unsupported unless a subclass overrides them.
    # ------------------------------------------------------------------
    def get_cpu_copy(self, *args, **kwargs):
        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def load_cpu_copy(self, *args, **kwargs):
        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def alloc_extend(self, *args, **kwargs):
        raise NotImplementedError("alloc_extend is only for paged allocator")

    def alloc_decode(self, *args, **kwargs):
        raise NotImplementedError("alloc_decode is only for paged allocator")
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """An allocator managing the indices to kv cache data."""

    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
        # The base class is initialised with a page size of 1.
        super().__init__(size, 1, dtype, device, kvcache)
        self.clear()

    def clear(self):
        """Rebuild the free list as indices 1..size and reset the free group."""
        self.is_not_in_free_group = True
        self.free_group = []
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_pages = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )

    def available_size(self):
        """Return the number of free indices."""
        # To avoid minor "len(free_pages) * 1" overhead
        return len(self.free_pages)

    def alloc(self, need_size: int):
        """Take ``need_size`` indices from the front of the free list.

        Returns ``None`` when fewer than ``need_size`` indices are free.
        """
        pool = self.free_pages
        if need_size > len(pool):
            return None

        taken = pool[:need_size]
        self.free_pages = pool[need_size:]
        return taken

    def free(self, free_index: torch.Tensor):
        """Return ``free_index`` to the free list, or queue it in a free group."""
        if free_index.numel() == 0:
            return

        if not self.is_not_in_free_group:
            # Inside a free group: defer until free_group_end().
            self.free_group.append(free_index)
            return

        self.free_pages = torch.cat((self.free_pages, free_index))

    def get_cpu_copy(self, indices):
        """Delegate to the kv cache's ``get_cpu_copy``."""
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Delegate to the kv cache's ``load_cpu_copy``."""
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
@
triton
.
jit
@
triton
.
jit
def
alloc_extend_kernel
(
def
alloc_extend_kernel
(
...
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
...
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
tl
.
store
(
out_indices
+
pid
,
page
*
page_size
)
tl
.
store
(
out_indices
+
pid
,
page
*
page_size
)
class
PagedTokenToKVPoolAllocator
:
class
PagedTokenToKVPoolAllocator
(
BaseTokenToKVPoolAllocator
)
:
"""
"""
An allocator managing the indices to kv cache data.
An allocator managing the indices to kv cache data.
...
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
...
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
device
:
str
,
device
:
str
,
kvcache
:
KVCache
,
kvcache
:
KVCache
,
):
):
self
.
size
=
size
super
().
__init__
(
size
,
page_size
,
dtype
,
device
,
kvcache
)
self
.
dtype
=
dtype
self
.
device
=
device
self
.
page_size
=
page_size
self
.
num_pages
=
size
//
page_size
self
.
num_pages
=
size
//
page_size
self
.
free_pages
=
None
self
.
is_not_in_free_group
=
True
self
.
free_group
=
[]
self
.
clear
()
self
.
debug_mode
=
get_bool_env_var
(
"SGLANG_DEBUG_MEMORY_POOL"
)
self
.
debug_mode
=
get_bool_env_var
(
"SGLANG_DEBUG_MEMORY_POOL"
)
self
.
_kvcache
=
kvcache
self
.
ret_values
=
torch
.
empty
((),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
ret_values
=
torch
.
empty
((),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
clear
()
def
available_size
(
self
):
return
len
(
self
.
free_pages
)
*
self
.
page_size
def
get_kvcache
(
self
):
return
self
.
_kvcache
def
alloc
(
self
,
need_size
:
int
):
def
alloc
(
self
,
need_size
:
int
):
# page-aligned allocation, returning contiguous indices of pages
# page-aligned allocation, returning contiguous indices of pages
...
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
...
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
if
self
.
debug_mode
:
if
self
.
debug_mode
:
assert
len
(
torch
.
unique
(
self
.
free_pages
))
==
len
(
self
.
free_pages
)
assert
len
(
torch
.
unique
(
self
.
free_pages
))
==
len
(
self
.
free_pages
)
def
free_group_begin
(
self
):
self
.
is_not_in_free_group
=
False
self
.
free_group
=
[]
def
free_group_end
(
self
):
self
.
is_not_in_free_group
=
True
if
self
.
free_group
:
self
.
free
(
torch
.
cat
(
self
.
free_group
))
def
backup_state
(
self
):
return
self
.
free_pages
def
restore_state
(
self
,
free_pages
):
self
.
free_pages
=
free_pages
def
clear
(
self
):
def
clear
(
self
):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self
.
free_pages
=
torch
.
arange
(
self
.
free_pages
=
torch
.
arange
(
...
...
python/sglang/srt/mem_cache/chunk_cache.py
View file @
05c9bc89
...
@@ -2,12 +2,13 @@ from __future__ import annotations
...
@@ -2,12 +2,13 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
"""Cache for chunked prefill, used when RadixCache is disabled."""
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
List
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
import
torch
import
torch
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
,
MatchResult
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
,
MatchResult
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.schedule_batch
import
Req
...
@@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
...
@@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
def
__init__
(
def
__init__
(
self
,
self
,
req_to_token_pool
:
ReqToTokenPool
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
page_size
:
int
,
page_size
:
int
,
):
):
self
.
req_to_token_pool
=
req_to_token_pool
self
.
req_to_token_pool
=
req_to_token_pool
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
05c9bc89
...
@@ -7,12 +7,12 @@ from typing import List, Optional
...
@@ -7,12 +7,12 @@ from typing import List, Optional
import
torch
import
torch
from
sglang.srt.managers.cache_controller
import
HiCacheController
from
sglang.srt.managers.cache_controller
import
HiCacheController
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
MatchResult
from
sglang.srt.mem_cache.base_prefix_cache
import
MatchResult
from
sglang.srt.mem_cache.memory_pool
import
(
from
sglang.srt.mem_cache.memory_pool
import
(
MHATokenToKVPool
,
MHATokenToKVPool
,
MLATokenToKVPool
,
MLATokenToKVPool
,
ReqToTokenPool
,
ReqToTokenPool
,
TokenToKVPoolAllocator
,
)
)
from
sglang.srt.mem_cache.memory_pool_host
import
(
from
sglang.srt.mem_cache.memory_pool_host
import
(
MHATokenToKVPoolHost
,
MHATokenToKVPoolHost
,
...
@@ -28,7 +28,7 @@ class HiRadixCache(RadixCache):
...
@@ -28,7 +28,7 @@ class HiRadixCache(RadixCache):
def
__init__
(
def
__init__
(
self
,
self
,
req_to_token_pool
:
ReqToTokenPool
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
tp_cache_group
:
torch
.
distributed
.
ProcessGroup
,
tp_cache_group
:
torch
.
distributed
.
ProcessGroup
,
page_size
:
int
,
page_size
:
int
,
hicache_ratio
:
float
,
hicache_ratio
:
float
,
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
05c9bc89
...
@@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache.
...
@@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache.
import
abc
import
abc
import
logging
import
logging
import
os
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
...
@@ -167,84 +166,6 @@ class KVCache(abc.ABC):
...
@@ -167,84 +166,6 @@ class KVCache(abc.ABC):
raise
NotImplementedError
()
raise
NotImplementedError
()
class TokenToKVPoolAllocator:
    """An allocator managing the indices to kv cache data."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        self._kvcache = kvcache

        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = 1

        # Declared here for readability; clear() assigns the real values.
        self.free_slots = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

    def clear(self):
        """Rebuild the free list as indices 1..size and reset the free group."""
        self.is_not_in_free_group = True
        self.free_group = []
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_slots = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )

    def debug_print(self) -> str:
        """Return a debug description; always the empty string here."""
        return ""

    def available_size(self):
        """Return the number of free slots."""
        return len(self.free_slots)

    def get_kvcache(self):
        """Return the kv cache object passed to the constructor."""
        return self._kvcache

    def alloc(self, need_size: int):
        """Take ``need_size`` slots from the front of the free list.

        Returns ``None`` when fewer than ``need_size`` slots are free.
        """
        pool = self.free_slots
        if need_size > len(pool):
            return None

        taken = pool[:need_size]
        self.free_slots = pool[need_size:]
        return taken

    def free(self, free_index: torch.Tensor):
        """Return ``free_index`` to the free list, or queue it in a free group."""
        if free_index.numel() == 0:
            return

        if not self.is_not_in_free_group:
            # Inside a free group: defer until free_group_end().
            self.free_group.append(free_index)
            return

        self.free_slots = torch.cat((self.free_slots, free_index))

    def free_group_begin(self):
        """Open a free group: subsequent frees are queued, not applied."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Close the free group and free everything queued in one call."""
        # Restore the flag first so the combined free() is applied directly.
        self.is_not_in_free_group = True
        queued = self.free_group
        if queued:
            self.free(torch.cat(queued))

    def backup_state(self):
        """Return the current ``free_slots`` object (no copy is made)."""
        return self.free_slots

    def restore_state(self, free_slots):
        """Replace ``free_slots`` with a value obtained from ``backup_state``."""
        self.free_slots = free_slots

    def get_cpu_copy(self, indices):
        """Delegate to the kv cache's ``get_cpu_copy``."""
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Delegate to the kv cache's ``load_cpu_copy``."""
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
class
MHATokenToKVPool
(
KVCache
):
class
MHATokenToKVPool
(
KVCache
):
def
__init__
(
def
__init__
(
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
05c9bc89
...
@@ -23,7 +23,7 @@ import heapq
...
@@ -23,7 +23,7 @@ import heapq
import
time
import
time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
partial
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
torch
import
torch
...
@@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
...
@@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
AllBlocksCleared
,
AllBlocksCleared
,
BlockRemoved
,
BlockRemoved
,
BlockStored
,
BlockStored
,
KVCacheEvent
,
)
)
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
,
MatchResult
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
,
MatchResult
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.schedule_batch
import
Req
...
@@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache):
...
@@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache):
def
__init__
(
def
__init__
(
self
,
self
,
req_to_token_pool
:
ReqToTokenPool
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
page_size
:
int
,
page_size
:
int
,
disable
:
bool
=
False
,
disable
:
bool
=
False
,
enable_kv_cache_events
:
bool
=
False
,
enable_kv_cache_events
:
bool
=
False
,
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
05c9bc89
...
@@ -71,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -71,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS
,
GLOBAL_SERVER_ARGS_KEYS
,
global_server_args_dict
,
global_server_args_dict
,
)
)
from
sglang.srt.mem_cache.allocator
import
(
BaseTokenToKVPoolAllocator
,
PagedTokenToKVPoolAllocator
,
TokenToKVPoolAllocator
,
)
from
sglang.srt.mem_cache.memory_pool
import
(
from
sglang.srt.mem_cache.memory_pool
import
(
DoubleSparseTokenToKVPool
,
DoubleSparseTokenToKVPool
,
MHATokenToKVPool
,
MHATokenToKVPool
,
MLATokenToKVPool
,
MLATokenToKVPool
,
ReqToTokenPool
,
ReqToTokenPool
,
TokenToKVPoolAllocator
,
)
)
from
sglang.srt.mem_cache.paged_allocator
import
PagedTokenToKVPoolAllocator
from
sglang.srt.model_executor.cuda_graph_runner
import
CudaGraphRunner
from
sglang.srt.model_executor.cuda_graph_runner
import
CudaGraphRunner
from
sglang.srt.model_executor.expert_location_updater
import
ExpertLocationUpdater
from
sglang.srt.model_executor.expert_location_updater
import
ExpertLocationUpdater
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
...
@@ -152,7 +155,7 @@ class ModelRunner:
...
@@ -152,7 +155,7 @@ class ModelRunner:
server_args
:
ServerArgs
,
server_args
:
ServerArgs
,
is_draft_worker
:
bool
=
False
,
is_draft_worker
:
bool
=
False
,
req_to_token_pool
:
Optional
[
ReqToTokenPool
]
=
None
,
req_to_token_pool
:
Optional
[
ReqToTokenPool
]
=
None
,
token_to_kv_pool_allocator
:
Optional
[
TokenToKVPoolAllocator
]
=
None
,
token_to_kv_pool_allocator
:
Optional
[
Base
TokenToKVPoolAllocator
]
=
None
,
):
):
# Parse args
# Parse args
self
.
model_config
=
model_config
self
.
model_config
=
model_config
...
...
python/sglang/srt/speculative/eagle_utils.py
View file @
05c9bc89
...
@@ -21,7 +21,7 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -21,7 +21,7 @@ from sglang.srt.managers.schedule_batch import (
get_last_loc
,
get_last_loc
,
global_server_args_dict
,
global_server_args_dict
,
)
)
from
sglang.srt.mem_cache.
memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.
allocator
import
Base
TokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
next_power_of_2
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
next_power_of_2
...
@@ -315,7 +315,7 @@ class EagleVerifyInput:
...
@@ -315,7 +315,7 @@ class EagleVerifyInput:
self
,
self
,
batch
:
ScheduleBatch
,
batch
:
ScheduleBatch
,
logits_output
:
torch
.
Tensor
,
logits_output
:
torch
.
Tensor
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
page_size
:
int
,
page_size
:
int
,
vocab_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
# For grammar
vocab_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
# For grammar
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment