Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
05c9bc89
Unverified
Commit
05c9bc89
authored
Jun 22, 2025
by
Liangsheng Yin
Committed by
GitHub
Jun 22, 2025
Browse files
[minor] simplify the `TokenToKVPoolAllocator` (#7414)
parent
b7a2df0a
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
165 additions
and
149 deletions
+165
-149
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+3
-8
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+0
-1
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+5
-4
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+4
-3
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+7
-3
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+0
-1
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+3
-2
python/sglang/srt/mem_cache/allocator.py
python/sglang/srt/mem_cache/allocator.py
+125
-34
python/sglang/srt/mem_cache/chunk_cache.py
python/sglang/srt/mem_cache/chunk_cache.py
+4
-3
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+2
-2
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+0
-79
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+4
-4
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+6
-3
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+2
-2
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
05c9bc89
...
...
@@ -21,13 +21,11 @@ Life cycle of a request in the decode server
from
__future__
import
annotations
import
logging
import
os
from
collections
import
deque
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
from
torch.distributed
import
ProcessGroup
...
...
@@ -47,12 +45,9 @@ from sglang.srt.disaggregation.utils import (
prepare_abort
,
)
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
ScheduleBatch
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
(
KVCache
,
ReqToTokenPool
,
TokenToKVPoolAllocator
,
)
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
ReqToTokenPool
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
from
sglang.srt.torch_memory_saver_adapter
import
TorchMemorySaverAdapter
from
sglang.srt.utils
import
require_mlp_sync
...
...
@@ -141,7 +136,7 @@ class DecodePreallocQueue:
def
__init__
(
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
draft_token_to_kv_pool
:
Optional
[
KVCache
],
req_to_metadata_buffer_idx_allocator
:
ReqToMetadataIdxAllocator
,
metadata_buffers
:
MetadataBuffers
,
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
05c9bc89
...
...
@@ -25,7 +25,6 @@ from collections import deque
from
http
import
HTTPStatus
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
numpy
as
np
import
torch
from
sglang.srt.disaggregation.base
import
BaseKVManager
,
KVPoll
...
...
python/sglang/srt/managers/cache_controller.py
View file @
05c9bc89
...
...
@@ -18,12 +18,13 @@ import logging
import
math
import
threading
from
queue
import
Empty
,
Full
,
PriorityQueue
,
Queue
from
typing
import
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
torch
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool_host
import
HostKVCache
if
TYPE_CHECKING
:
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool_host
import
HostKVCache
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -163,7 +164,7 @@ class HiCacheController:
def
__init__
(
self
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
mem_pool_host
:
HostKVCache
,
page_size
:
int
,
load_cache_event
:
threading
.
Event
=
None
,
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
05c9bc89
...
...
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
)
from
sglang.srt.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
sglang.srt.layers.multimodal
import
gpu_tensor_hash
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.metrics.collector
import
TimeStats
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
...
...
@@ -810,7 +811,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Request, memory pool, and cache
reqs
:
List
[
Req
]
req_to_token_pool
:
ReqToTokenPool
=
None
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
=
None
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
=
None
tree_cache
:
BasePrefixCache
=
None
# Batch configs
...
...
@@ -907,7 +908,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
cls
,
reqs
:
List
[
Req
],
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
tree_cache
:
BasePrefixCache
,
model_config
:
ModelConfig
,
enable_overlap
:
bool
,
...
...
python/sglang/srt/managers/schedule_policy.py
View file @
05c9bc89
from
__future__
import
annotations
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -18,15 +20,17 @@ import random
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
enum
import
Enum
,
auto
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Union
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Union
import
torch
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
if
TYPE_CHECKING
:
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
# Note that this only clips the estimation in the scheduler but does not change the stop
...
...
@@ -265,7 +269,7 @@ class PrefillAdder:
self
,
page_size
:
int
,
tree_cache
:
BasePrefixCache
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
running_batch
:
ScheduleBatch
,
new_token_ratio
:
float
,
rem_input_tokens
:
int
,
...
...
python/sglang/srt/managers/scheduler.py
View file @
05c9bc89
...
...
@@ -23,7 +23,6 @@ import time
from
collections
import
defaultdict
,
deque
from
concurrent
import
futures
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
pathlib
import
Path
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
...
...
python/sglang/srt/managers/tp_worker.py
View file @
05c9bc89
...
...
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput
,
)
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
,
global_server_args_dict
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.server_args
import
ServerArgs
...
...
@@ -57,7 +58,7 @@ class TpModelWorker:
nccl_port
:
int
,
is_draft_worker
:
bool
=
False
,
req_to_token_pool
:
Optional
[
ReqToTokenPool
]
=
None
,
token_to_kv_pool_allocator
:
Optional
[
TokenToKVPoolAllocator
]
=
None
,
token_to_kv_pool_allocator
:
Optional
[
Base
TokenToKVPoolAllocator
]
=
None
,
):
# Parse args
self
.
tp_size
=
server_args
.
tp_size
...
...
python/sglang/srt/mem_cache/
paged_
allocator.py
→
python/sglang/srt/mem_cache/allocator.py
View file @
05c9bc89
from
__future__
import
annotations
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -17,13 +19,132 @@ limitations under the License.
Page-aligned memory pool.
"""
import
abc
from
typing
import
TYPE_CHECKING
import
torch
import
triton
import
triton.language
as
tl
from
sglang.srt.mem_cache.memory_pool
import
KVCache
from
sglang.srt.utils
import
get_bool_env_var
,
next_power_of_2
if
TYPE_CHECKING
:
from
sglang.srt.mem_cache.memory_pool
import
KVCache
class BaseTokenToKVPoolAllocator(abc.ABC):
    """Abstract base for allocators that manage indices into a KV cache.

    The base class stores the shared configuration (``size``, ``page_size``,
    ``dtype``, ``device``, the wrapped ``kvcache``) and the free-list
    bookkeeping (``free_pages``, ``free_group``, ``is_not_in_free_group``).
    Subclasses must implement ``clear``, ``alloc`` and ``free``; they are also
    responsible for populating ``free_pages``, which starts out as ``None``.
    """

    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        """Record the configuration and reset the free-list bookkeeping.

        Declared abstract so the base class cannot be instantiated directly;
        subclasses call it through ``super().__init__(...)``.
        """
        # Static configuration.
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device
        self._kvcache = kvcache

        # Free-list state; `free_pages` is filled in by the subclass.
        self.free_pages = None
        self.is_not_in_free_group = True
        self.free_group = []

    def debug_print(self) -> str:
        """Return a debug description; the base implementation is empty."""
        return ""

    def available_size(self):
        """Return the number of free pages multiplied by the page size."""
        free_page_count = len(self.free_pages)
        return free_page_count * self.page_size

    def get_kvcache(self):
        """Return the KV cache object passed to the constructor."""
        return self._kvcache

    def restore_state(self, free_pages):
        """Replace the free list with a value obtained from `backup_state`."""
        self.free_pages = free_pages

    def backup_state(self):
        """Return the current free list (the object itself, not a copy)."""
        return self.free_pages

    def free_group_begin(self):
        """Enter grouped-free mode and start with an empty pending list."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Leave grouped-free mode and free everything that was queued.

        The queued tensors are concatenated and released with one `free` call;
        nothing is freed when the queue is empty.
        """
        # Flip the flag first so the `free` call below is not re-queued.
        self.is_not_in_free_group = True
        deferred = self.free_group
        if not deferred:
            return
        self.free(torch.cat(deferred))

    def get_cpu_copy(self, *args, **kwargs):
        """Not implemented in the base class."""
        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def load_cpu_copy(self, *args, **kwargs):
        """Not implemented in the base class."""
        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def alloc_extend(self, *args, **kwargs):
        """Raise `NotImplementedError`; only the paged allocator provides it."""
        raise NotImplementedError("alloc_extend is only for paged allocator")

    def alloc_decode(self, *args, **kwargs):
        """Raise `NotImplementedError`; only the paged allocator provides it."""
        raise NotImplementedError("alloc_decode is only for paged allocator")

    @abc.abstractmethod
    def clear(self):
        """Reset the allocator; must be provided by subclasses."""
        raise NotImplementedError()

    @abc.abstractmethod
    def alloc(self, need_size: int):
        """Allocate `need_size` entries; must be provided by subclasses."""
        raise NotImplementedError()

    @abc.abstractmethod
    def free(self, free_index: torch.Tensor):
        """Release the given indices; must be provided by subclasses."""
        raise NotImplementedError()
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """An allocator managing the indices to kv cache data.

    This variant fixes the page size at 1, so every entry of ``free_pages``
    is a single free token slot.
    """

    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
        """Initialise the base state with ``page_size=1`` and build the free list."""
        super().__init__(size, 1, dtype, device, kvcache)
        self.clear()

    def clear(self):
        """Reset the free list to slots ``1..size`` and leave grouped-free mode."""
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_pages = torch.arange(
            1,
            self.size + 1,
            dtype=torch.int64,
            device=self.device,
        )
        self.is_not_in_free_group = True
        self.free_group = []

    def available_size(self):
        """Return the number of free slots."""
        # To avoid minor "len(free_pages) * 1" overhead
        return len(self.free_pages)

    def alloc(self, need_size: int):
        """Take ``need_size`` slots from the front of the free list.

        Returns the selected index tensor, or ``None`` when fewer than
        ``need_size`` slots are free (the free list is left untouched).
        """
        pool = self.free_pages
        if need_size > len(pool):
            return None

        taken = pool[:need_size]
        self.free_pages = pool[need_size:]
        return taken

    def free(self, free_index: torch.Tensor):
        """Return ``free_index`` to the free list, or queue it in grouped mode.

        Empty tensors are ignored.
        """
        if free_index.numel() == 0:
            return

        if not self.is_not_in_free_group:
            # Inside a free group: defer until `free_group_end`.
            self.free_group.append(free_index)
            return

        self.free_pages = torch.cat((self.free_pages, free_index))

    def get_cpu_copy(self, indices):
        """Delegate to the wrapped KV cache's ``get_cpu_copy``."""
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Delegate to the wrapped KV cache's ``load_cpu_copy``."""
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
@
triton
.
jit
def
alloc_extend_kernel
(
...
...
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
tl
.
store
(
out_indices
+
pid
,
page
*
page_size
)
class
PagedTokenToKVPoolAllocator
:
class
PagedTokenToKVPoolAllocator
(
BaseTokenToKVPoolAllocator
)
:
"""
An allocator managing the indices to kv cache data.
...
...
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
device
:
str
,
kvcache
:
KVCache
,
):
self
.
size
=
size
self
.
dtype
=
dtype
self
.
device
=
device
self
.
page_size
=
page_size
super
().
__init__
(
size
,
page_size
,
dtype
,
device
,
kvcache
)
self
.
num_pages
=
size
//
page_size
self
.
free_pages
=
None
self
.
is_not_in_free_group
=
True
self
.
free_group
=
[]
self
.
clear
()
self
.
debug_mode
=
get_bool_env_var
(
"SGLANG_DEBUG_MEMORY_POOL"
)
self
.
_kvcache
=
kvcache
self
.
ret_values
=
torch
.
empty
((),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
def
available_size
(
self
):
return
len
(
self
.
free_pages
)
*
self
.
page_size
def
get_kvcache
(
self
):
return
self
.
_kvcache
self
.
clear
()
def
alloc
(
self
,
need_size
:
int
):
# page-aligned allocation, returning contiguous indices of pages
...
...
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
if
self
.
debug_mode
:
assert
len
(
torch
.
unique
(
self
.
free_pages
))
==
len
(
self
.
free_pages
)
def
free_group_begin
(
self
):
self
.
is_not_in_free_group
=
False
self
.
free_group
=
[]
def
free_group_end
(
self
):
self
.
is_not_in_free_group
=
True
if
self
.
free_group
:
self
.
free
(
torch
.
cat
(
self
.
free_group
))
def
backup_state
(
self
):
return
self
.
free_pages
def
restore_state
(
self
,
free_pages
):
self
.
free_pages
=
free_pages
def
clear
(
self
):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self
.
free_pages
=
torch
.
arange
(
...
...
python/sglang/srt/mem_cache/chunk_cache.py
View file @
05c9bc89
...
...
@@ -2,12 +2,13 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
List
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
import
torch
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
,
MatchResult
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
...
...
@@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
def
__init__
(
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
page_size
:
int
,
):
self
.
req_to_token_pool
=
req_to_token_pool
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
05c9bc89
...
...
@@ -7,12 +7,12 @@ from typing import List, Optional
import
torch
from
sglang.srt.managers.cache_controller
import
HiCacheController
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
MatchResult
from
sglang.srt.mem_cache.memory_pool
import
(
MHATokenToKVPool
,
MLATokenToKVPool
,
ReqToTokenPool
,
TokenToKVPoolAllocator
,
)
from
sglang.srt.mem_cache.memory_pool_host
import
(
MHATokenToKVPoolHost
,
...
...
@@ -28,7 +28,7 @@ class HiRadixCache(RadixCache):
def
__init__
(
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
tp_cache_group
:
torch
.
distributed
.
ProcessGroup
,
page_size
:
int
,
hicache_ratio
:
float
,
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
05c9bc89
...
...
@@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache.
import
abc
import
logging
import
os
from
contextlib
import
nullcontext
from
typing
import
List
,
Optional
,
Tuple
,
Union
...
...
@@ -167,84 +166,6 @@ class KVCache(abc.ABC):
raise
NotImplementedError
()
class TokenToKVPoolAllocator:
    """An allocator managing the indices to kv cache data.

    Free token slots are kept in ``free_slots``, an int64 index tensor on
    ``device``. The page size is fixed at 1.
    """

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        """Store the configuration, build the free list, then attach the cache."""
        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = 1

        self.free_slots = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

        self._kvcache = kvcache

    def available_size(self):
        """Return the number of free slots."""
        return len(self.free_slots)

    def debug_print(self) -> str:
        """Return a debug description; this implementation is empty."""
        return ""

    def get_kvcache(self):
        """Return the KV cache object passed to the constructor."""
        return self._kvcache

    def alloc(self, need_size: int):
        """Take ``need_size`` slots from the front of the free list.

        Returns the selected index tensor, or ``None`` when fewer than
        ``need_size`` slots are free (the free list is left untouched).
        """
        pool = self.free_slots
        if need_size > len(pool):
            return None

        taken = pool[:need_size]
        self.free_slots = pool[need_size:]
        return taken

    def free(self, free_index: torch.Tensor):
        """Return ``free_index`` to the free list, or queue it in grouped mode.

        Empty tensors are ignored.
        """
        if free_index.numel() == 0:
            return

        if not self.is_not_in_free_group:
            # Inside a free group: defer until `free_group_end`.
            self.free_group.append(free_index)
            return

        self.free_slots = torch.cat((self.free_slots, free_index))

    def free_group_begin(self):
        """Enter grouped-free mode and start with an empty pending list."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Leave grouped-free mode and free everything queued in one call."""
        # Flip the flag first so the `free` call below is not re-queued.
        self.is_not_in_free_group = True
        deferred = self.free_group
        if not deferred:
            return
        self.free(torch.cat(deferred))

    def backup_state(self):
        """Return the current free list (the object itself, not a copy)."""
        return self.free_slots

    def restore_state(self, free_slots):
        """Replace the free list with a value obtained from `backup_state`."""
        self.free_slots = free_slots

    def clear(self):
        """Reset the free list to slots ``1..size`` and leave grouped-free mode."""
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_slots = torch.arange(
            1,
            self.size + 1,
            dtype=torch.int64,
            device=self.device,
        )
        self.is_not_in_free_group = True
        self.free_group = []

    def get_cpu_copy(self, indices):
        """Delegate to the wrapped KV cache's ``get_cpu_copy``."""
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Delegate to the wrapped KV cache's ``load_cpu_copy``."""
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
class
MHATokenToKVPool
(
KVCache
):
def
__init__
(
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
05c9bc89
...
...
@@ -23,7 +23,7 @@ import heapq
import
time
from
collections
import
defaultdict
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
torch
...
...
@@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
AllBlocksCleared
,
BlockRemoved
,
BlockStored
,
KVCacheEvent
,
)
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
,
MatchResult
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
...
...
@@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache):
def
__init__
(
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
page_size
:
int
,
disable
:
bool
=
False
,
enable_kv_cache_events
:
bool
=
False
,
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
05c9bc89
...
...
@@ -71,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS
,
global_server_args_dict
,
)
from
sglang.srt.mem_cache.allocator
import
(
BaseTokenToKVPoolAllocator
,
PagedTokenToKVPoolAllocator
,
TokenToKVPoolAllocator
,
)
from
sglang.srt.mem_cache.memory_pool
import
(
DoubleSparseTokenToKVPool
,
MHATokenToKVPool
,
MLATokenToKVPool
,
ReqToTokenPool
,
TokenToKVPoolAllocator
,
)
from
sglang.srt.mem_cache.paged_allocator
import
PagedTokenToKVPoolAllocator
from
sglang.srt.model_executor.cuda_graph_runner
import
CudaGraphRunner
from
sglang.srt.model_executor.expert_location_updater
import
ExpertLocationUpdater
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
...
...
@@ -152,7 +155,7 @@ class ModelRunner:
server_args
:
ServerArgs
,
is_draft_worker
:
bool
=
False
,
req_to_token_pool
:
Optional
[
ReqToTokenPool
]
=
None
,
token_to_kv_pool_allocator
:
Optional
[
TokenToKVPoolAllocator
]
=
None
,
token_to_kv_pool_allocator
:
Optional
[
Base
TokenToKVPoolAllocator
]
=
None
,
):
# Parse args
self
.
model_config
=
model_config
...
...
python/sglang/srt/speculative/eagle_utils.py
View file @
05c9bc89
...
...
@@ -21,7 +21,7 @@ from sglang.srt.managers.schedule_batch import (
get_last_loc
,
global_server_args_dict
,
)
from
sglang.srt.mem_cache.
memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.
allocator
import
Base
TokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
next_power_of_2
...
...
@@ -315,7 +315,7 @@ class EagleVerifyInput:
self
,
batch
:
ScheduleBatch
,
logits_output
:
torch
.
Tensor
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
Base
TokenToKVPoolAllocator
,
page_size
:
int
,
vocab_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
# For grammar
)
->
torch
.
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment