change / sglang / Commits / 852a49c5

Commit 852a49c5, authored Sep 30, 2025 by maxiao

    adapt to dsv32 on dcu

Parent: 8f7453e3

Changes: 159 files in total. This page shows 20 changed files with 279 additions and 397 deletions (+279 -397).
Files shown on this page:

  python/sglang/srt/layers/quantization/w8a8_int8.py              +3    -15
  python/sglang/srt/layers/rotary_embedding.py                    +17   -54
  python/sglang/srt/layers/utils.py                               +0    -23
  python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py        +1    -1
  python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py        +1    -1
  python/sglang/srt/managers/cache_controller.py                  +126  -34
  python/sglang/srt/managers/io_struct.py                         +13   -27
  python/sglang/srt/managers/mm_utils.py                          +6    -43
  python/sglang/srt/managers/multimodal_processor.py              +2    -1
  python/sglang/srt/managers/overlap_utils.py                     +0    -53
  python/sglang/srt/managers/schedule_batch.py                    +13   -21
  python/sglang/srt/managers/schedule_policy.py                   +8    -0
  python/sglang/srt/managers/scheduler.py                         +12   -32
  python/sglang/srt/managers/scheduler_output_processor_mixin.py  +1    -7
  python/sglang/srt/managers/scheduler_profiler_mixin.py          +4    -4
  python/sglang/srt/managers/tokenizer_manager.py                 +6    -7
  python/sglang/srt/managers/tp_worker.py                         +1    -0
  python/sglang/srt/managers/tp_worker_overlap_thread.py          +33   -14
  python/sglang/srt/mem_cache/allocator_ascend.py                 +31   -20
  python/sglang/srt/mem_cache/hicache_storage.py                  +1    -40
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -393,23 +393,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 x.dtype,
                 True,  # is_vnni
             )

-        x_q, x_scale = per_token_quant_int8(x)
-        x_q_2d = x_q.view(-1, x_q.shape[-1])
-        x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
-        output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
-
-        output = int8_scaled_mm(
-            x_q_2d,
-            layer.weight,
-            x_scale_2d,
-            layer.weight_scale,
-            out_dtype=x.dtype,
-            bias=bias,
-        )
-        return output.view(output_shape)
+        x_q, x_scale = per_token_quant_int8(x)
+        return int8_scaled_mm(
+            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+        )


 class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
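The simplified path above quantizes activations per token and feeds the INT8 tensors straight to the scaled matmul. A minimal pure-PyTorch sketch of that shape contract follows; per_token_quant_int8_ref and int8_scaled_mm_ref are illustrative stand-ins written for this note, not the sgl-kernel functions the file actually calls.

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per row: scale = max(|row|) / 127, values clamped into int8 range.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
    return x_q, scale

def int8_scaled_mm_ref(x_q, w_q, x_scale, w_scale, out_dtype, bias=None):
    # Dequantize-then-matmul reference; the real kernel fuses this on the accelerator.
    out = (x_q.float() * x_scale) @ (w_q.float() * w_scale)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

x = torch.randn(4, 16)                            # [num_tokens, hidden]
w = torch.randn(16, 32)                           # [hidden, out_features]
w_q, w_scale = per_token_quant_int8_ref(w.t())    # per-output-channel quantization
x_q, x_scale = per_token_quant_int8_ref(x)
y = int8_scaled_mm_ref(x_q, w_q.t(), x_scale, w_scale.t(), out_dtype=x.dtype)
print(y.shape)  # torch.Size([4, 32]) -- already [num_tokens, out_features], no reshape needed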
@@ -648,7 +638,6 @@ class NPU_W8A8LinearMethodImpl:
         layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8LinearMethodMTImpl:

@@ -841,7 +830,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
python/sglang/srt/layers/rotary_embedding.py

@@ -12,7 +12,6 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
-    get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,

@@ -27,19 +26,13 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

 if _is_cuda:
-    from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
-else:
-    FusedSetKVBufferArg = None
+    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

 if _use_aiter:
     from aiter.rotary_embedding import get_rope as aiter_get_rope

 if is_npu():
     import torch_npu

-    NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
-    NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
-

 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]

@@ -149,13 +142,8 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
-        assert (
-            fused_set_kv_buffer_arg is None
-        ), "fused_set_kv_buffer_arg is not supported for native implementation"
-
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()

@@ -184,17 +172,12 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-npu implementation of forward()."""
-        assert (
-            fused_set_kv_buffer_arg is None
-        ), "fused_set_kv_buffer_arg is not supported for npu implementation"
-
         import os

         if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-            return self.forward_native(
-                positions, query, key, offsets, fused_set_kv_buffer_arg
-            )
+            return self.forward_native(positions, query, key, offsets)
         else:
             rotary_mode = "half"
             if self.is_neox_style:

@@ -219,12 +202,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert (
-            fused_set_kv_buffer_arg is None
-        ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
-
         positions = torch.add(positions, offsets) if offsets is not None else positions
         if _is_cpu_amx_available:
             return torch.ops.sgl_kernel.rotary_embedding_cpu(

@@ -236,9 +214,7 @@ class RotaryEmbedding(CustomOp):
                 self.is_neox_style,
             )
         else:
-            return self.forward_native(
-                positions, query, key, offsets, fused_set_kv_buffer_arg
-            )
+            return self.forward_native(positions, query, key, offsets)

     def forward_cuda(
         self,

@@ -246,7 +222,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
+        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(

@@ -1059,7 +1035,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                 f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
             )

-    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    @torch.compile(dynamic=True)
     def forward(
         self,
         positions: torch.Tensor,

@@ -1207,7 +1183,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                 time_tensor_long = time_tensor.long()
                 t_index = time_tensor_long.flatten()
-            elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
+            elif model_type == "qwen2_vl":
                 t_index = (
                     torch.arange(llm_grid_t)
                     .view(-1, 1)

@@ -1918,30 +1894,17 @@ def apply_rotary_pos_emb_npu(
     sin: torch.Tensor,
     unsqueeze_dim=1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Ascend implementation equivalent to apply_rotary_pos_emb_native.
-
-    Args:
-        q: [num_tokens, num_heads, head_size]
-        k: [num_tokens, num_kv_heads, head_size]
-        cos: [num_tokens, head_size]
-        sin: [num_tokens, head_size]
-    """
-    if (
-        cos.dim() != 2
-        or q.dim() != 3
-        or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
-        or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
-    ):
-        # Note: num_heads and head_size of q must be less than 1000 and 896, respectively
+    if q.shape[1] != 128:
         return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
-    cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
-    sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
-    q = q.unsqueeze(0)
-    k = k.unsqueeze(0)
-    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
-    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
-    q_embed = q_embed.squeeze(0)
-    k_embed = k_embed.squeeze(0)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
     return q_embed, k_embed
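For head counts other than 128, the new NPU path above falls back to apply_rotary_pos_emb_native. A self-contained sketch of that rotate-half application follows, assuming neox-style rotation and the [num_tokens, num_heads, head_size] layout described in the removed docstring; rotate_half and apply_rotary_ref are names invented for this sketch, not the file's own helpers.

import torch
from typing import Tuple

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in two halves and rotate them: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_ref(q, k, cos, sin, unsqueeze_dim=1) -> Tuple[torch.Tensor, torch.Tensor]:
    # q: [num_tokens, num_heads, head_size], cos/sin: [num_tokens, head_size]
    cos = cos.unsqueeze(unsqueeze_dim)   # broadcast over the head dimension
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed

q = torch.randn(8, 128, 64)   # the fast NPU path above only triggers for num_heads == 128
k = torch.randn(8, 8, 64)
cos = torch.randn(8, 64)
sin = torch.randn(8, 64)
q_out, k_out = apply_rotary_ref(q, k, cos, sin)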
python/sglang/srt/layers/utils.py

@@ -15,29 +15,6 @@ def get_layer_id(weight_name):
     return None


-def pad_or_narrow_weight(
-    loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
-) -> torch.Tensor:
-    # Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
-    valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
-
-    if valid_size > 0:
-        loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
-        pad_shape = list(loaded_weight.shape)
-        pad_shape[input_dim] = shard_size - valid_size
-        pad = torch.zeros(
-            pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
-        )
-        return torch.cat([loaded_slice, pad], dim=input_dim)
-
-    # All padding
-    pad_shape = list(loaded_weight.shape)
-    pad_shape[input_dim] = shard_size
-    return torch.zeros(pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device)
-
-
 class PPMissingLayer(torch.nn.Identity):
     # Adapted from
     # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
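The deleted pad_or_narrow_weight helper zero-pads a weight shard when a checkpoint dimension is not evenly divisible by the shard size. A small standalone demonstration of the same idea (written for this note, not the deleted function itself):

import torch

def pad_or_narrow(loaded: torch.Tensor, dim: int, start: int, shard: int) -> torch.Tensor:
    # Take up to `shard` slices starting at `start`; zero-pad if the source runs out.
    valid = max(loaded.shape[dim] - start, 0)
    if valid > 0:
        piece = loaded.narrow(dim, start, min(valid, shard))
    else:
        piece = loaded.narrow(dim, 0, 0)
    pad_shape = list(loaded.shape)
    pad_shape[dim] = shard - piece.shape[dim]
    pad = torch.zeros(pad_shape, dtype=loaded.dtype, device=loaded.device)
    return torch.cat([piece, pad], dim=dim)

w = torch.arange(10.0).reshape(10, 1)
print(pad_or_narrow(w, dim=0, start=8, shard=4).flatten())  # tensor([8., 9., 0., 0.])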
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py

@@ -5,7 +5,7 @@ import triton
 import triton.language as tl

 from sglang.srt.lora.utils import LoRABatchInfo
-from sglang.srt.utils import cached_triton_kernel
+from sglang.utils import cached_triton_kernel


 @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py

@@ -3,7 +3,7 @@ import triton
 import triton.language as tl

 from sglang.srt.lora.utils import LoRABatchInfo
-from sglang.srt.utils import cached_triton_kernel
+from sglang.utils import cached_triton_kernel


 @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
python/sglang/srt/managers/cache_controller.py

@@ -275,17 +275,43 @@ class HiCacheController:
             and self.storage_config.tp_rank != 0
         )

-        # Use storage backend factory for dynamic backend creation
-        from sglang.srt.mem_cache.storage import StorageBackendFactory
-
-        try:
-            self.storage_backend = StorageBackendFactory.create_backend(
-                storage_backend, self.storage_config, self.mem_pool_host
-            )
-        except ValueError as e:
-            raise ValueError(f"Failed to create storage backend: {e}") from e
-
-        self.storage_backend.register_mem_pool_host(self.mem_pool_host)
+        if storage_backend == "file":
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile
+
+            self.storage_backend = HiCacheFile(self.storage_config)
+        elif storage_backend == "nixl":
+            from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
+
+            self.storage_backend = HiCacheNixl()
+        elif storage_backend == "mooncake":
+            from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
+                MooncakeStore,
+            )
+
+            self.storage_backend = MooncakeStore(self.storage_config)
+            self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+            assert self.mem_pool_host.layout == "page_first"
+        elif storage_backend == "hf3fs":
+            from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                HiCacheHF3FS,
+            )
+
+            if self.mem_pool_host.layout == "page_first":
+                bytes_per_page = (
+                    mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
+                )
+            elif self.mem_pool_host.layout == "layer_first":
+                bytes_per_page = (
+                    mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+                )
+            dtype = mem_pool_host.dtype
+            self.storage_backend = HiCacheHF3FS.from_env_config(
+                bytes_per_page, dtype, self.storage_config
+            )
+        else:
+            raise NotImplementedError(
+                f"Unsupported storage backend: {storage_backend}"
+            )

         self.enable_storage = True
         # todo: threshold policy for prefetching
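The removed upstream code resolves the backend through a StorageBackendFactory, while the restored code dispatches with an explicit if/elif chain. A registry-style dispatch of the kind such a factory presumably implements can be sketched as below; DummyFileBackend, _BACKENDS, and create_backend are placeholders for this note, not sglang classes.

from typing import Callable, Dict

class DummyFileBackend:
    def __init__(self, config):
        self.config = config

# Each backend registers a constructor keyed by its name.
_BACKENDS: Dict[str, Callable] = {
    "file": lambda config: DummyFileBackend(config),
    # "nixl", "mooncake", "hf3fs" would be registered the same way.
}

def create_backend(name: str, config):
    try:
        return _BACKENDS[name](config)
    except KeyError as e:
        raise ValueError(f"Unsupported storage backend: {name}") from e

backend = create_backend("file", config={"tp_rank": 0})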
@@ -309,10 +335,18 @@ class HiCacheController:
             # Select the get and set functions
             self.page_get_func = self._generic_page_get
             self.page_set_func = self._generic_page_set

-            if self.storage_backend_type in ["hf3fs", "mooncake"]:
-                self.page_get_func = self._page_get_zero_copy
-                self.page_set_func = self._page_set_zero_copy
+            self.batch_exists_func = self.storage_backend.batch_exists
+            self.is_3fs_zerocopy = (
+                self.storage_backend_type == "hf3fs"
+                and self.mem_pool_host.layout == "page_first"
+            )
+            if self.storage_backend_type == "mooncake":
+                self.page_get_func = self._mooncake_page_get
+                self.page_set_func = self._mooncake_page_set
+            elif self.is_3fs_zerocopy:
+                self.page_get_func = self._3fs_zero_copy_page_get
+                self.page_set_func = self._3fs_zero_copy_page_set
+                self.batch_exists_func = self._3fs_zero_copy_batch_exists

         self.device = self.mem_pool_device.device
         self.layer_num = self.mem_pool_device.layer_num

@@ -436,6 +470,7 @@ class HiCacheController:
         host_indices = self.mem_pool_host.alloc(len(device_indices))
         if host_indices is None:
             return None
+        self.mem_pool_host.protect_write(host_indices)
         self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )

@@ -459,6 +494,7 @@ class HiCacheController:
                 self.mem_pool_host.backup_from_device_all_layer(
                     self.mem_pool_device, host_indices, device_indices, self.io_backend
                 )
+                self.mem_pool_host.complete_io(op.host_indices)
                 finish_event.record()
         # NOTE: We must save the host indices and device indices here,
         # this is because we need to guarantee that these tensors are

@@ -482,6 +518,7 @@ class HiCacheController:
         device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
         if device_indices is None:
             return None
+        self.mem_pool_host.protect_load(host_indices)
         self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )

@@ -526,6 +563,7 @@ class HiCacheController:
                     self.io_backend,
                 )
                 producer_event.complete(i)
+                self.mem_pool_host.complete_io(op.host_indices)
         # NOTE: We must save the host indices and device indices here,
         # this is because we need to guarantee that these tensors are
         # still alive when the load stream is executing.

@@ -543,16 +581,29 @@ class HiCacheController:
         )
         return producer_id

-    def evict_device(self, device_indices: torch.Tensor) -> int:
-        self.mem_pool_device_allocator.free(device_indices)
-        return len(device_indices)
+    def evict_device(
+        self, device_indices: torch.Tensor, host_indices: torch.Tensor
+    ) -> int:
+        if self.mem_pool_host.is_synced(host_indices):
+            self.mem_pool_device_allocator.free(device_indices)
+            self.mem_pool_host.update_backup(host_indices)
+            return len(device_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )

     def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
         if not backup_only:
             raise ValueError("Other eviction policies are not supported yet.")
-        self.mem_pool_host.free(host_indices)
-        return len(host_indices)
+        if self.mem_pool_host.is_backup(host_indices):
+            self.mem_pool_host.free(host_indices)
+            return len(host_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )

     def prefetch(
         self,

@@ -579,19 +630,42 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)

-    def _page_get_zero_copy(self, operation, hash_values, host_indices):
-        results = self.storage_backend.batch_get_v1(hash_values, host_indices)
-        inc = 0
-        for i in range(len(hash_values)):
-            if not results[i]:
-                logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
-                )
-                break
-            inc += self.page_size
-        operation.increment(inc)
+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
+    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
+        )
+        page_data = self.storage_backend.batch_get(hashes, dsts)
+        if page_data:
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
+        else:
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
+            )
+
+    def _mooncake_page_get(self, operation, hash_values, host_indices):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
+        )
+        get_result = self.storage_backend.batch_get(
+            key_strs,
+            target_locations=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        if get_result != len(hash_values):
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed or partially failed."
+            )
+        if get_result != 0:
+            operation.increment(get_result * self.page_size)

     # todo: deprecate
     def _generic_page_get(self, operation, hash_values, host_indices):
         dummy_page_dst = [
             self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values

@@ -681,7 +755,7 @@ class HiCacheController:
                     batch_tokens[i : i + self.page_size], last_hash
                 )
                 batch_hashes.append(last_hash)
-            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hit_page_num = self.batch_exists_func(batch_hashes)
             hash_value.extend(batch_hashes[:hit_page_num])
             storage_query_count += hit_page_num * self.page_size
             if hit_page_num < len(batch_hashes):

@@ -750,16 +824,34 @@ class HiCacheController:
         self.backup_queue.put(operation)
         return operation.id

-    # todo: deprecate
+    # non-zero copy
     def _generic_page_set(self, hash_values, host_indices) -> bool:
         data = [
-            self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
+            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
             for i in range(len(hash_values))
         ]
         return self.storage_backend.batch_set(hash_values, data)

-    def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
-        return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
+    # zero copy
+    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
+        )
+        success = self.storage_backend.batch_set(
+            key_strs,
+            target_locations=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        return success
+
+    # zero copy
+    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
+        )
+        return self.storage_backend.batch_set(hashes, dsts)

     # Backup batch by batch
     def _page_backup(self, operation):
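_generic_page_set above copies whole flat pages into the backend with batch_set, while batch_exists_func reports how many leading pages hit. A toy dict-backed illustration of that non-zero-copy contract (the class and buffer names are invented for this sketch):

import torch

class DictStore:
    def __init__(self):
        self.pages = {}

    def batch_set(self, keys, values) -> bool:
        for key, page in zip(keys, values):
            self.pages[key] = page.clone()
        return True

    def batch_exists(self, keys) -> int:
        # Mirrors HiCacheStorage.batch_exists: number of leading keys that hit.
        for i, key in enumerate(keys):
            if key not in self.pages:
                return i
        return len(keys)

page_size = 4
host_buffer = torch.arange(16.0)   # pretend host KV pool: 4 pages of 4 tokens
store = DictStore()
hashes = ["h0", "h1"]
data = [host_buffer[i * page_size : (i + 1) * page_size] for i in range(len(hashes))]
store.batch_set(hashes, data)
print(store.batch_exists(["h0", "h1", "h2"]))  # 2 -- only the first two pages are present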
python/sglang/srt/managers/io_struct.py

@@ -35,7 +35,6 @@ else:
     Image = Any

-
 # Parameters for a session
 @dataclass
 class SessionParams:
     id: Optional[str] = None

@@ -133,23 +132,18 @@ class GenerateReqInput:
     # Conversation id used for tracking requests
     conversation_id: Optional[str] = None

+    # Label for the request
+    label: Optional[str] = None
+
     # Priority for the request
     priority: Optional[int] = None

     # Extra key for classifying the request (e.g. cache_salt)
     extra_key: Optional[Union[List[str], str]] = None

-    # Whether to disallow logging for this request (e.g. due to ZDR)
-    no_logs: bool = False
-
-    # For custom metric labels
-    custom_labels: Optional[Dict[str, str]] = None
-
-    # (Deprecated, please use custom_labels) Label for the request
-    label: Optional[str] = None
-
-    # (Internal) Whether to return bytes for image generation
+    # Image gen grpc migration
     return_bytes: bool = False

+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)

@@ -548,11 +542,8 @@ class GenerateReqInput:
                 self.data_parallel_rank
                 if self.data_parallel_rank is not None
                 else None
             ),
             conversation_id=self.conversation_id,
-            priority=self.priority,
-            extra_key=self.extra_key,
-            no_logs=self.no_logs,
-            custom_labels=self.custom_labels,
             label=self.label,
+            priority=self.priority,
             return_bytes=self.return_bytes,
         )

@@ -609,23 +600,18 @@ class TokenizedGenerateReqInput:
     # For dp balance
     dp_balance_id: int = -1

+    # Label for the request
+    label: Optional[str] = None
+
     # Priority for the request
     priority: Optional[int] = None

     # Extra key for classifying the request (e.g. cache_salt)
     extra_key: Optional[str] = None

-    # Whether to disallow logging for this request (e.g. due to ZDR)
-    no_logs: bool = False
+    # Image gen grpc migration
+    return_bytes: bool = False

     # tracing context
     trace_context: Optional[Dict] = None

-    # (Deprecated, please use custom_labels) Label for the request
-    label: Optional[str] = None
-
-    # (Internal) Whether to return bytes for image generation
-    return_bytes: bool = False
-

 @dataclass
 class BatchTokenizedGenerateReqInput:
python/sglang/srt/managers/mm_utils.py

@@ -507,7 +507,6 @@ def embed_mm_inputs(
         Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
-    use_deepstack: bool = False,
 ) -> Optional[torch.Tensor]:
     """
     Embed multimodal inputs and integrate them with text token embeddings.

@@ -523,7 +522,7 @@ def embed_mm_inputs(
     Returns:
         Combined embedding tensor with multimodal content integrated
     """
-    other_info = {}
+
     if mm_inputs_list is None:
         return None

@@ -533,7 +532,7 @@ def embed_mm_inputs(
     for mm_inputs in mm_inputs_list:
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

-    embeddings, masks, deepstack_embeddings = [], [], []
+    embeddings, masks = [], []
     # 2. Get multimodal embedding separately
     # Try get mm embedding if any
     for modality in Modality.all():

@@ -579,12 +578,6 @@ def embed_mm_inputs(
                 extend_length=extend_seq_lens,
                 items_offset_list=items_offsets,
             )
-            if use_deepstack and embedding is not None:
-                embedding, deepstack_embedding = (
-                    multimodal_model.separate_deepstack_embeds(embedding)
-                )
-                deepstack_embeddings += [deepstack_embedding]
-
             embeddings += [embedding]
             masks += [mask]

@@ -598,37 +591,13 @@ def embed_mm_inputs(
         inputs_embeds = input_embedding(input_ids)

     # 4. scatter embeddings into input embedding
-    # deepstack embedding
-    if use_deepstack:
-        num_deepstack_embeddings = (
-            len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
-        )
-        deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
-            inputs_embeds.shape[-1] * num_deepstack_embeddings,
-        )
-
-        input_deepstack_embeds = torch.zeros(
-            deepstack_embedding_shape,
-            device=inputs_embeds.device,
-            dtype=inputs_embeds.dtype,
-        )
-
-        other_info["input_deepstack_embeds"] = input_deepstack_embeds
-
-    for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
+    for embedding, mask in zip(embeddings, masks):
         if embedding is None or mask is None:
             continue
         # in-place update
         indices = torch.where(mask.squeeze(dim=-1))[0]
         inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)

-        if use_deepstack:
-            input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
-                inputs_embeds.device, inputs_embeds.dtype
-            )
-
-    return inputs_embeds, other_info
+    return inputs_embeds


 def general_mm_embed_routine(
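The scatter step kept by this hunk writes each modality's embedding into the rows of inputs_embeds selected by its placeholder mask. A self-contained illustration with invented tensor sizes:

import torch

hidden = 8
inputs_embeds = torch.zeros(6, hidden)            # 6 tokens of text embeddings
mask = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool).unsqueeze(-1)
image_embedding = torch.ones(3, hidden)           # one row per multimodal placeholder

indices = torch.where(mask.squeeze(dim=-1))[0]    # rows 1, 2, 4
inputs_embeds[indices] = image_embedding.to(inputs_embeds.device, inputs_embeds.dtype)
print(inputs_embeds.sum(dim=-1))                  # nonzero only at the placeholder rows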
@@ -640,7 +609,6 @@ def general_mm_embed_routine(
         Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
-    use_deepstack: bool = False,
     **kwargs,
 ) -> torch.Tensor:
     """

@@ -652,7 +620,6 @@ def general_mm_embed_routine(
         language_model: Base language model to use
         data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
         placeholder_tokens: Token IDs for multimodal placeholders
-        use_deepstack: Whether to use deepstack embeddings
         **kwargs: Additional arguments passed to language model

     Returns:

@@ -678,20 +645,16 @@ def general_mm_embed_routine(
             for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
             if forward_batch.mm_inputs[i] is not None
         ]
-        inputs_embeds, other_info = embed_mm_inputs(
+        inputs_embeds = embed_mm_inputs(
             mm_inputs_list=mm_inputs_list,
             extend_prefix_lens=extend_prefix_lens,
             extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
-            multimodal_model=multimodal_model,
             input_embedding=embed_tokens,
+            multimodal_model=multimodal_model,
             data_embedding_func_mapping=data_embedding_funcs,
             placeholder_tokens=placeholder_tokens,
-            use_deepstack=use_deepstack,
         )
-        # add for qwen3_vl deepstack
-        if use_deepstack:
-            kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
-
         # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
         # just being defensive here
         forward_batch.mm_inputs = None
python/sglang/srt/managers/multimodal_processor.py

@@ -12,7 +12,8 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}


-def import_processors(package_name: str):
+def import_processors():
+    package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
python/sglang/srt/managers/overlap_utils.py  (deleted, 100644 → 0; the entire file below is removed)

import torch

from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend


@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
    input_ids[:] = torch.where(
        input_ids < 0,
        future_token_ids_map[torch.clamp(-input_ids, min=0)],
        input_ids,
    )


class FutureMap:
    def __init__(
        self,
        max_running_requests: int,
        device: torch.device,
    ):
        self.future_ct = 0
        # A factor of 3 is used to avoid collision in the circular buffer.
        self.future_limit = max_running_requests * 3
        # A factor of 5 is used to ensure the buffer is large enough.
        self.future_buffer_len = max_running_requests * 5
        self.device = device

        self.token_ids_buf = torch.empty(
            (self.future_buffer_len,), dtype=torch.int64, device=self.device
        )

    def update_ct(self, bs: int) -> int:
        """Update the circular buffer pointer and return the current pointer."""
        cur_future_ct = self.future_ct
        self.future_ct = (cur_future_ct + bs) % self.future_limit
        return cur_future_ct

    def resolve_future(self, model_worker_batch: ModelWorkerBatch):
        input_ids = model_worker_batch.input_ids
        _resolve_future_token_ids(input_ids, self.token_ids_buf)

    def update_next_future(self, future_ct: int, bs: int):
        return torch.arange(
            -(future_ct + 1),
            -(future_ct + 1 + bs),
            -1,
            dtype=torch.int64,
            device=self.device,
        )

    def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
        self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
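The deleted FutureMap packages the negative-placeholder scheme that tp_worker_overlap_thread.py below now carries inline: a batch can be scheduled before its previous tokens exist, so input_ids hold negative indices into a circular buffer that the worker fills in once sampling finishes. A compressed walk-through under the same convention (buffer size and values chosen for illustration):

import torch

buf = torch.zeros(16, dtype=torch.int64)          # token_ids_buf / future_token_ids_map
future_ct, bs = 0, 3

# Scheduler side: hand out placeholders -(ct+1) .. -(ct+bs)
placeholders = torch.arange(-(future_ct + 1), -(future_ct + 1 + bs), -1, dtype=torch.int64)

# Worker side: once the real next tokens are sampled, store them at slots ct+1 .. ct+bs
next_token_ids = torch.tensor([11, 22, 33], dtype=torch.int64)
buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids

# Next batch: resolve placeholders (negative ids) against the buffer
input_ids = torch.tensor([5, -1, -2, -3], dtype=torch.int64)
input_ids[:] = torch.where(input_ids < 0, buf[torch.clamp(-input_ids, min=0)], input_ids)
print(input_ids)   # tensor([ 5, 11, 22, 33])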
python/sglang/srt/managers/schedule_batch.py

@@ -67,14 +67,14 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import flatten_nested_list, support_triton

 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-    from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+    from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -90,7 +90,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
-    "enable_fp32_lm_head",
     "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",

@@ -113,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_custom_logit_processor",
     "disaggregation_mode",
     "enable_deterministic_inference",
+    "nsa_prefill",
+    "nsa_decode",
 ]

 # Put some global args for easy access

@@ -492,7 +493,7 @@ class Req:
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states

-        # extra key for classifying the request (e.g. cache_salt)
+        # extra key for classifying the request (e.g. lora_id, cache_salt)
         if lora_id is not None:
             extra_key = (extra_key or "")

@@ -608,8 +609,6 @@ class Req:
         ) = None

         self.hidden_states: List[List[float]] = []
         self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
-        self.output_topk_p = None
-        self.output_topk_index = None

         # Embedding (return values)
         self.embedding = None

@@ -954,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
-        None
-    )
+    spec_info: Optional[
+        Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+    ] = None

     # Whether to return hidden states
     return_hidden_states: bool = False

@@ -1609,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             self.spec_algorithm.is_eagle()
             or self.spec_algorithm.is_standalone()
-            or self.spec_algorithm.is_ngram()
+            or self.spec_algorithm.is_lookahead()
         ):
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.

@@ -1736,14 +1735,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.sampling_info.filter_batch(keep_indices, keep_indices_device)
         if self.spec_info:
-            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
-                has_been_filtered = False
-            else:
-                has_been_filtered = True
-            self.spec_info.filter_batch(
-                new_indices=keep_indices_device,
-                has_been_filtered=has_been_filtered,
-            )
+            self.spec_info.filter_batch(keep_indices_device)

     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because

@@ -1992,9 +1984,9 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
-        None
-    )
+    spec_info: Optional[
+        Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
+    ] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
python/sglang/srt/managers/schedule_policy.py

@@ -318,6 +318,7 @@ class PrefillAdder:
         new_token_ratio: float,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        max_prefill_bs: Optional[int],
         mixed_with_decode_tokens: int = 0,
         priority_scheduling_preemption_threshold: int = 0,
     ):

@@ -358,6 +359,10 @@ class PrefillAdder:
             priority_scheduling_preemption_threshold
         )

+        self.max_prefill_bs = (
+            max_prefill_bs if max_prefill_bs is not None else 2147483647
+        )
+
     def _get_running_request_total_token_offset(self, req: Req) -> int:
         return (
             min(

@@ -549,6 +554,9 @@ class PrefillAdder:
     def add_one_req(
         self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
     ):
+        if len(self.can_run_list) >= self.max_prefill_bs:
+            return AddReqResult.OTHER
+
         if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)
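The new max_prefill_bs argument caps how many requests one prefill batch may admit; passing None falls back to 2147483647 (INT32_MAX), which effectively disables the cap. A tiny illustration of that default behavior (the helper name is invented for this note):

INT32_MAX = 2147483647

def effective_cap(max_prefill_bs):
    return max_prefill_bs if max_prefill_bs is not None else INT32_MAX

can_run_list = ["req0", "req1", "req2"]
for cap in (None, 2):
    admit_more = len(can_run_list) < effective_cap(cap)
    print(cap, "admit another request:", admit_more)  # None -> True, 2 -> False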
python/sglang/srt/managers/scheduler.py

@@ -44,9 +44,6 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
-from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
-    DecodeKVCacheOffloadManager,
-)
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,

@@ -262,7 +259,7 @@ class Scheduler(
         self.enable_metrics_for_all_schedulers = (
             server_args.enable_metrics_for_all_schedulers
         )
-        self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
+        self.enable_kv_cache_events = server_args.kv_events_config is not None
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm

@@ -388,10 +385,10 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
-        elif self.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_worker import NGRAMWorker
+        elif self.spec_algorithm.is_lookahead():
+            from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker

-            self.draft_worker = NGRAMWorker(
+            self.draft_worker = LOOKAHEADWorker(
                 gpu_id=gpu_id,
                 tp_rank=tp_rank,
                 moe_ep_rank=moe_ep_rank,

@@ -556,11 +553,9 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
-        self.init_kv_events(server_args.kv_events_config)
         self.init_dp_balance(dp_balance_meta)
+        if self.enable_kv_cache_events:
+            self.init_kv_events(server_args.kv_events_config)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode

@@ -618,6 +613,8 @@ class Scheduler(
             ]
         )

+        self.max_prefill_bs = server_args.max_prefill_bs
+
     def init_deterministic_inference_config(self):
         """Initialize deterministic inference configuration for different attention backends."""
         if not self.server_args.enable_deterministic_inference:

@@ -758,24 +755,6 @@ class Scheduler(
             eviction_policy=server_args.radix_eviction_policy,
         )

-        if (
-            server_args.disaggregation_mode == "decode"
-            and server_args.disaggregation_decode_enable_offload_kvcache
-        ):
-            self.decode_offload_manager = DecodeKVCacheOffloadManager(
-                req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                tp_group=(
-                    self.attn_tp_cpu_group
-                    if self.server_args.enable_dp_attention
-                    else self.tp_cpu_group
-                ),
-                tree_cache=self.tree_cache,
-                server_args=self.server_args,
-            )
-        else:
-            self.decode_offload_manager = None
-
         self.decode_mem_cache_buf_multiplier = (
             1
             if self.spec_algorithm.is_none()

@@ -806,7 +785,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-                hidden_states_dtype=self.model_config.dtype,
+                dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )

@@ -826,7 +805,7 @@ class Scheduler(
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
+                    if self.draft_worker is None or self.spec_algorithm.is_lookahead()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,

@@ -855,7 +834,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-                hidden_states_dtype=self.model_config.dtype,
+                dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )

@@ -863,7 +842,7 @@ class Scheduler(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
+                    if self.draft_worker is None or self.spec_algorithm.is_lookahead()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,

@@ -1832,6 +1811,7 @@ class Scheduler(
             self.new_token_ratio,
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            self.max_prefill_bs,
             running_bs if self.is_mixed_chunk else 0,
             self.priority_scheduling_preemption_threshold,
         )
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -250,13 +250,7 @@ class SchedulerOutputProcessorMixin:
                 req.check_finished()

                 if req.finished():
-                    if self.server_args.disaggregation_decode_enable_offload_kvcache:
-                        # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
-                        if not self.decode_offload_manager.offload_kv_cache(req):
-                            self.tree_cache.cache_finished_req(req)
-                    else:
-                        self.tree_cache.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                     req.time_stats.completion_time = time.time()

                     if req.return_logprob and batch.spec_algorithm.is_none():
python/sglang/srt/managers/scheduler_profiler_mixin.py

@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
     def start_profile(
         self, stage: Optional[ForwardMode] = None
     ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.name}" if stage else ""
+        stage_str = f" for {stage.__str__()}" if stage else ""
         logger.info(
             f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
         )

@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
         if not Path(self.torch_profiler_output_dir).exists():
             Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

-        stage_suffix = f"-{stage.name}" if stage else ""
+        stage_suffix = f"-{stage.__str__()}" if stage else ""
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()

@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
         if self.profiler_decode_ct == 0:
             if self.profile_in_progress:
                 # force trace flush
-                self.stop_profile(stage=ForwardMode.EXTEND)
+                self.stop_profile(ForwardMode.EXTEND)
             self.start_profile(batch.forward_mode)
         self.profiler_decode_ct += 1
         if self.profiler_decode_ct > self.profiler_target_decode_ct:

@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
                 recv_req.profile_by_stage,
                 recv_req.profile_id,
             )
-            return self.start_profile()
+            return self.start_profile(True)
         else:
             return self.stop_profile()
python/sglang/srt/managers/tokenizer_manager.py

@@ -185,7 +185,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         )

         if self.model_config.is_multimodal:
-            import_processors("sglang.srt.multimodal.processors")
+            import_processors()
             try:
                 _processor = get_processor(
                     server_args.tokenizer_path,

@@ -320,8 +320,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 "model_name": self.server_args.served_model_name,
                 # TODO: Add lora name/path in the future,
             }
-            if server_args.tokenizer_metrics_allowed_custom_labels:
-                for label in server_args.tokenizer_metrics_allowed_custom_labels:
+            if server_args.tokenizer_metrics_allowed_customer_labels:
+                for label in server_args.tokenizer_metrics_allowed_customer_labels:
                     labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
                 server_args=server_args,

@@ -750,7 +750,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 return_hidden_states=obj.return_hidden_states,
                 data_parallel_rank=obj.data_parallel_rank,
                 priority=obj.priority,
-                extra_key=obj.extra_key,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(

@@ -1633,10 +1632,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 else 0
             )
-            custom_labels = getattr(state.obj, "custom_labels", None)
+            customer_labels = getattr(state.obj, "customer_labels", None)
             labels = (
-                {**self.metrics_collector.labels, **custom_labels}
-                if custom_labels
+                {**self.metrics_collector.labels, **customer_labels}
+                if customer_labels
                 else self.metrics_collector.labels
             )
             if (
python/sglang/srt/managers/tp_worker.py

@@ -91,6 +91,7 @@ class TpModelWorker:
                 else server_args.speculative_draft_model_revision
             ),
             is_draft_model=is_draft_worker,
+            tp_rank=tp_rank,
         )

         self.model_runner = ModelRunner(
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -36,11 +36,10 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback

 if TYPE_CHECKING:

@@ -49,6 +48,15 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
 class TpModelWorkerClient:
     """A tensor parallel model worker."""

@@ -71,7 +79,11 @@ class TpModelWorkerClient:
         self.gpu_id = gpu_id

         # Init future mappings
-        self.future_map = FutureMap(self.max_running_requests, self.device)
+        self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
+        )

         # Launch threads
         self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()

@@ -141,7 +153,7 @@ class TpModelWorkerClient:
         batch_lists: List = [None] * 2

         while True:
-            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break

@@ -157,7 +169,8 @@ class TpModelWorkerClient:
             copy_done = torch.get_device_module(self.device).Event()

             # Resolve future tokens in the input
-            self.future_map.resolve_future(model_worker_batch)
+            input_ids = model_worker_batch.input_ids
+            resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (

@@ -174,9 +187,9 @@ class TpModelWorkerClient:
             if model_worker_batch.is_prefill_only:
                 # For prefill-only requests, create dummy token IDs on CPU
                 next_token_ids = torch.zeros(bs, dtype=torch.long)

-            # store the future indices into future map
-            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
+            self.future_token_ids_map[
+                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+            ] = next_token_ids

             # Copy results to the CPU
             if model_worker_batch.return_logprob:

@@ -242,14 +255,20 @@ class TpModelWorkerClient:
         sync_event.record(self.scheduler_stream)

         # Push a new batch to the queue
-        bs = len(model_worker_batch.seq_lens)
-        cur_future_map_ct = self.future_map.update_ct(bs)
-        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))

-        # get this forward batch's future token ids
-        future_next_token_ids = self.future_map.update_next_future(
-            cur_future_map_ct, bs
-        )
+        # Allocate output future objects
+        bs = len(model_worker_batch.seq_lens)
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
+            dtype=torch.int64,
+            device=self.device,
+        )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
         return None, future_next_token_ids, False

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
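The counter arithmetic restored above wraps future_token_ids_ct at 3x max_running_requests while allocating a 5x buffer; that is what keeps the highest written slot (counter + batch size) inside the allocation. A quick check of that bound with invented sizes:

# Why the buffer is 5x while the counter wraps at 3x (values chosen for illustration):
max_running_requests = 4
limit = max_running_requests * 3        # wrap point for the counter
buffer_len = max_running_requests * 5   # allocation size

ct = limit - 1                          # worst case: counter just below the wrap point
bs = max_running_requests               # largest possible batch
highest_slot = ct + bs                  # slots ct+1 .. ct+bs are written
print(highest_slot, "<", buffer_len)    # 15 < 20 -- still inside the buffer
ct = (ct + bs) % limit                  # pointer wraps for the next batch
print(ct)                               # 3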
python/sglang/srt/mem_cache/allocator_ascend.py

@@ -79,37 +79,48 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         )

         num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if self.need_sort and num_new_pages > len(self.free_pages):
+            (seq_lens + self.page_size - 1) // self.page_size
+            - (prefix_lens + self.page_size - 1) // self.page_size
+        ).sum()
+        num_new_pages_item = num_new_pages.item()
+        if self.need_sort and num_new_pages_item > len(self.free_pages):
             self.merge_and_sort_free()

-        if num_new_pages > len(self.free_pages):
+        if num_new_pages_item > len(self.free_pages):
             return None

         out_indices = torch.empty(
-            (extend_num_tokens,), dtype=torch.int32, device=self.device
+            (extend_num_tokens,), dtype=torch.int64, device=self.device
         )

-        alloc_extend_kernel_ascend(
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-            self.device,
-        )
+        if num_new_pages_item < 200:
+            import sgl_kernel_npu
+
+            torch.ops.npu.alloc_extend(
+                prefix_lens,
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                self.page_size,
+                out_indices,
+                num_new_pages,
+            )
+        else:
+            alloc_extend_kernel_ascend(
+                prefix_lens,
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                out_indices,
+                self.page_size,
+                self.device,
+            )

         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        self.free_pages = self.free_pages[num_new_pages:]
+        self.free_pages = self.free_pages[num_new_pages_item:]
         return out_indices

     def alloc_decode(
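The allocator computes the number of new pages an extend needs as ceil(seq_len / page_size) minus ceil(prefix_len / page_size), summed over the batch, before deciding whether to sort free pages or which kernel to dispatch. A small check of that arithmetic:

import torch

page_size = 4
prefix_lens = torch.tensor([0, 5, 8])
seq_lens = torch.tensor([4, 9, 8])

num_new_pages = (
    (seq_lens + page_size - 1) // page_size
    - (prefix_lens + page_size - 1) // page_size
).sum()
print(num_new_pages.item())  # (1-0) + (3-2) + (2-2) = 2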
python/sglang/srt/mem_cache/hicache_storage.py

@@ -7,8 +7,6 @@ from typing import Any, List, Optional
 import torch

-from sglang.srt.mem_cache.memory_pool_host import HostKVCache
-
 logger = logging.getLogger(__name__)

@@ -34,46 +32,15 @@ class HiCacheStorageConfig:
     extra_config: Optional[dict] = None


-@dataclass
-class HiCacheStorageExtraInfo:
-    extra_info: Optional[dict] = None
-
-
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """

     # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the same as host memory pool

-    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
-        self.mem_pool_host = mem_pool_host
-
-    def batch_get_v1(
-        self,
-        keys: List[str],
-        host_indices: torch.Tensor,
-        extra_info: Optional[HiCacheStorageExtraInfo] = None,
-    ) -> List[bool]:
-        """
-        Retrieve values for multiple keys.
-        Returns a list of tensors or None for each key.
-        """
-        pass
-
-    def batch_set_v1(
-        self,
-        keys: List[str],
-        host_indices: torch.Tensor,
-        extra_info: Optional[HiCacheStorageExtraInfo] = None,
-    ) -> List[bool]:
-        """
-        Retrieve values for multiple keys.
-        Returns a list of tensors or None for each key.
-        """
-        pass
-
     @abstractmethod
     def get(
         self,

@@ -87,7 +54,6 @@ class HiCacheStorage(ABC):
         """
         pass

-    # TODO: Deprecate
     @abstractmethod
     def batch_get(
         self,

@@ -115,7 +81,6 @@ class HiCacheStorage(ABC):
         """
         pass

-    # TODO: Deprecate
     @abstractmethod
     def batch_set(
         self,

@@ -138,7 +103,6 @@ class HiCacheStorage(ABC):
         """
         pass

-    # TODO: Use a finer-grained return type (e.g., List[bool])
     def batch_exists(self, keys: List[str]) -> int:
         """
         Check if the keys exist in the storage.

@@ -150,9 +114,6 @@ class HiCacheStorage(ABC):
                 return i
         return len(keys)

-    def clear(self) -> None:
-        pass
-
     def get_stats(self):
         return None