change / sglang / Commits / 852a49c5

Commit 852a49c5, authored Sep 30, 2025 by maxiao

    adapt to dsv32 on dcu

Parent: 8f7453e3
Changes: 159

Showing 20 changed files with 279 additions and 397 deletions (+279, -397)
python/sglang/srt/layers/quantization/w8a8_int8.py                 +3    -15
python/sglang/srt/layers/rotary_embedding.py                       +17   -54
python/sglang/srt/layers/utils.py                                  +0    -23
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py           +1    -1
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py           +1    -1
python/sglang/srt/managers/cache_controller.py                     +126  -34
python/sglang/srt/managers/io_struct.py                            +13   -27
python/sglang/srt/managers/mm_utils.py                             +6    -43
python/sglang/srt/managers/multimodal_processor.py                 +2    -1
python/sglang/srt/managers/overlap_utils.py                        +0    -53
python/sglang/srt/managers/schedule_batch.py                       +13   -21
python/sglang/srt/managers/schedule_policy.py                      +8    -0
python/sglang/srt/managers/scheduler.py                            +12   -32
python/sglang/srt/managers/scheduler_output_processor_mixin.py     +1    -7
python/sglang/srt/managers/scheduler_profiler_mixin.py             +4    -4
python/sglang/srt/managers/tokenizer_manager.py                    +6    -7
python/sglang/srt/managers/tp_worker.py                            +1    -0
python/sglang/srt/managers/tp_worker_overlap_thread.py             +33   -14
python/sglang/srt/mem_cache/allocator_ascend.py                    +31   -20
python/sglang/srt/mem_cache/hicache_storage.py                     +1    -40
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -393,23 +393,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 x.dtype,
                 True,  # is_vnni
             )
-        x_q, x_scale = per_token_quant_int8(x)
-        x_q_2d = x_q.view(-1, x_q.shape[-1])
-        x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
-        output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
-        output = int8_scaled_mm(
-            x_q_2d,
-            layer.weight,
-            x_scale_2d,
-            layer.weight_scale,
-            out_dtype=x.dtype,
-            bias=bias,
-        )
-        return output.view(output_shape)
+        x_q, x_scale = per_token_quant_int8(x)
+
+        return int8_scaled_mm(
+            x_q,
+            layer.weight,
+            x_scale,
+            layer.weight_scale,
+            out_dtype=x.dtype,
+            bias=bias,
+        )

 class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.

@@ -648,7 +638,6 @@ class NPU_W8A8LinearMethodImpl:
         layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)

 class NPU_W8A8LinearMethodMTImpl:

@@ -841,7 +830,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)

 class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
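For context, both versions of apply() follow the same pattern: dynamic per-token INT8 quantization of the activations followed by a scaled INT8 matmul. The following is a minimal, self-contained PyTorch sketch of that idea; quantize_per_token and int8_matmul_dequant are illustrative stand-ins, not the sgl_kernel per_token_quant_int8 / int8_scaled_mm implementations.

import torch

def quantize_per_token(x: torch.Tensor):
    # One scale per row: scale = max(|row|) / 127, q = round(x / scale).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale

def int8_matmul_dequant(x_q, x_scale, w_q, w_scale, out_dtype, bias=None):
    # Accumulate the integer product, then rescale back to floating point.
    acc = x_q.to(torch.float32) @ w_q.to(torch.float32).T
    out = acc * x_scale * w_scale.view(1, -1)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

x = torch.randn(4, 64)
w = torch.randn(32, 64)
w_q, w_scale = quantize_per_token(w)  # per-output-channel scales for the weight
x_q, x_scale = quantize_per_token(x)  # per-token scales for the activations
y = int8_matmul_dequant(x_q, x_scale, w_q, w_scale.view(-1), out_dtype=x.dtype)
print(y.shape)  # torch.Size([4, 32])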
python/sglang/srt/layers/rotary_embedding.py

@@ -12,7 +12,6 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
-    get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,

@@ -27,19 +26,13 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

 if _is_cuda:
-    from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
-else:
-    FusedSetKVBufferArg = None
+    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

 if _use_aiter:
     from aiter.rotary_embedding import get_rope as aiter_get_rope

 if is_npu():
     import torch_npu
-    NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
-    NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896

 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]

@@ -149,13 +142,8 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
-        assert (
-            fused_set_kv_buffer_arg is None
-        ), "fused_set_kv_buffer_arg is not supported for native implementation"
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()

@@ -184,17 +172,12 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-npu implementation of forward()."""
-        assert (
-            fused_set_kv_buffer_arg is None
-        ), "fused_set_kv_buffer_arg is not supported for npu implementation"
+        import os
         if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-            return self.forward_native(
-                positions, query, key, offsets, fused_set_kv_buffer_arg
-            )
+            return self.forward_native(positions, query, key, offsets)
         else:
             rotary_mode = "half"
             if self.is_neox_style:

@@ -219,12 +202,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert (
-            fused_set_kv_buffer_arg is None
-        ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
         positions = torch.add(positions, offsets) if offsets is not None else positions
         if _is_cpu_amx_available:
             return torch.ops.sgl_kernel.rotary_embedding_cpu(

@@ -236,9 +214,7 @@ class RotaryEmbedding(CustomOp):
                 self.is_neox_style,
             )
         else:
-            return self.forward_native(
-                positions, query, key, offsets, fused_set_kv_buffer_arg
-            )
+            return self.forward_native(positions, query, key, offsets)

     def forward_cuda(
         self,

@@ -246,7 +222,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
+        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(

@@ -1059,7 +1035,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                 f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
             )

-    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    @torch.compile(dynamic=True)
     def forward(
         self,
         positions: torch.Tensor,

@@ -1207,7 +1183,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                 time_tensor_long = time_tensor.long()
                 t_index = time_tensor_long.flatten()
-            elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
+            elif model_type == "qwen2_vl":
                 t_index = (
                     torch.arange(llm_grid_t)
                     .view(-1, 1)

@@ -1918,30 +1894,17 @@ def apply_rotary_pos_emb_npu(
     sin: torch.Tensor,
     unsqueeze_dim=1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Ascend implementation equivalent to apply_rotary_pos_emb_native.
-
-    Args:
-        q: [num_tokens, num_heads, head_size]
-        k: [num_tokens, num_kv_heads, head_size]
-        cos: [num_tokens, head_size]
-        sin: [num_tokens, head_size]
-    """
-    if (
-        cos.dim() != 2
-        or q.dim() != 3
-        or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
-        or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
-    ):
-        # Note: num_heads and head_size of q must be less than 1000 and 896, respectively
-        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
-    cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
-    sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
-    q = q.unsqueeze(0)
-    k = k.unsqueeze(0)
-    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
-    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
-    q_embed = q_embed.squeeze(0)
-    k_embed = k_embed.squeeze(0)
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
     return q_embed, k_embed
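As a reference for what the NPU-specific paths fall back to when their shape constraints are not met, here is a minimal neox-style rotary application in plain PyTorch. It follows the tensor layouts named in the removed docstring (q/k: [num_tokens, num_heads, head_size]; cos/sin: [num_tokens, head_size], already expanded to the full head size); the helper names are illustrative, not the module's API.

import torch

def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_native(q, k, cos, sin, unsqueeze_dim=1):
    # Broadcast cos/sin over the head dimension, then rotate q and k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_neox(q) * sin
    k_embed = k * cos + rotate_neox(k) * sin
    return q_embed, k_embed

q = torch.randn(8, 16, 128)
k = torch.randn(8, 2, 128)
cos = torch.randn(8, 128)
sin = torch.randn(8, 128)
q_out, k_out = apply_rotary_native(q, k, cos, sin)
print(q_out.shape, k_out.shape)  # torch.Size([8, 16, 128]) torch.Size([8, 2, 128])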
python/sglang/srt/layers/utils.py

@@ -15,29 +15,6 @@ def get_layer_id(weight_name):
     return None

-def pad_or_narrow_weight(
-    loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
-) -> torch.Tensor:
-    # Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
-    valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
-
-    if valid_size > 0:
-        loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
-        pad_shape = list(loaded_weight.shape)
-        pad_shape[input_dim] = shard_size - valid_size
-        pad = torch.zeros(
-            pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
-        )
-        return torch.cat([loaded_slice, pad], dim=input_dim)
-
-    # All padding
-    pad_shape = list(loaded_weight.shape)
-    pad_shape[input_dim] = shard_size
-    return torch.zeros(pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device)

 class PPMissingLayer(torch.nn.Identity):
     # Adapted from
     # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py

@@ -5,7 +5,7 @@ import triton
 import triton.language as tl

 from sglang.srt.lora.utils import LoRABatchInfo
-from sglang.srt.utils import cached_triton_kernel
+from sglang.utils import cached_triton_kernel

 @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py

@@ -3,7 +3,7 @@ import triton
 import triton.language as tl

 from sglang.srt.lora.utils import LoRABatchInfo
-from sglang.srt.utils import cached_triton_kernel
+from sglang.utils import cached_triton_kernel

 @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
python/sglang/srt/managers/cache_controller.py

@@ -275,17 +275,43 @@ class HiCacheController:
             and self.storage_config.tp_rank != 0
         )

-        # Use storage backend factory for dynamic backend creation
-        from sglang.srt.mem_cache.storage import StorageBackendFactory
-
-        try:
-            self.storage_backend = StorageBackendFactory.create_backend(
-                storage_backend, self.storage_config, self.mem_pool_host
-            )
-        except ValueError as e:
-            raise ValueError(f"Failed to create storage backend: {e}") from e
-
-        self.storage_backend.register_mem_pool_host(self.mem_pool_host)
+        if storage_backend == "file":
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile
+
+            self.storage_backend = HiCacheFile(self.storage_config)
+        elif storage_backend == "nixl":
+            from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
+
+            self.storage_backend = HiCacheNixl()
+        elif storage_backend == "mooncake":
+            from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
+                MooncakeStore,
+            )
+
+            self.storage_backend = MooncakeStore(self.storage_config)
+            self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+            assert self.mem_pool_host.layout == "page_first"
+        elif storage_backend == "hf3fs":
+            from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                HiCacheHF3FS,
+            )
+
+            if self.mem_pool_host.layout == "page_first":
+                bytes_per_page = (
+                    mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
+                )
+            elif self.mem_pool_host.layout == "layer_first":
+                bytes_per_page = (
+                    mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+                )
+            dtype = mem_pool_host.dtype
+            self.storage_backend = HiCacheHF3FS.from_env_config(
+                bytes_per_page, dtype, self.storage_config
+            )
+        else:
+            raise NotImplementedError(
+                f"Unsupported storage backend: {storage_backend}"
+            )

         self.enable_storage = True
         # todo: threshold policy for prefetching

@@ -309,10 +335,18 @@ class HiCacheController:
         # Select the get and set functions
         self.page_get_func = self._generic_page_get
         self.page_set_func = self._generic_page_set
-        if self.storage_backend_type in ["hf3fs", "mooncake"]:
-            self.page_get_func = self._page_get_zero_copy
-            self.page_set_func = self._page_set_zero_copy
+        self.batch_exists_func = self.storage_backend.batch_exists
+        self.is_3fs_zerocopy = (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        )
+        if self.storage_backend_type == "mooncake":
+            self.page_get_func = self._mooncake_page_get
+            self.page_set_func = self._mooncake_page_set
+        elif self.is_3fs_zerocopy:
+            self.page_get_func = self._3fs_zero_copy_page_get
+            self.page_set_func = self._3fs_zero_copy_page_set
+            self.batch_exists_func = self._3fs_zero_copy_batch_exists

         self.device = self.mem_pool_device.device
         self.layer_num = self.mem_pool_device.layer_num

@@ -436,6 +470,7 @@ class HiCacheController:
         host_indices = self.mem_pool_host.alloc(len(device_indices))
         if host_indices is None:
             return None
+        self.mem_pool_host.protect_write(host_indices)
         self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )

@@ -459,6 +494,7 @@ class HiCacheController:
                 self.mem_pool_host.backup_from_device_all_layer(
                     self.mem_pool_device, host_indices, device_indices, self.io_backend
                 )
+                self.mem_pool_host.complete_io(op.host_indices)
                 finish_event.record()
                 # NOTE: We must save the host indices and device indices here,
                 # this is because we need to guarantee that these tensors are

@@ -482,6 +518,7 @@ class HiCacheController:
         device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
         if device_indices is None:
             return None
+        self.mem_pool_host.protect_load(host_indices)
         self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )

@@ -526,6 +563,7 @@ class HiCacheController:
                     self.io_backend,
                 )
                 producer_event.complete(i)
+                self.mem_pool_host.complete_io(op.host_indices)
                 # NOTE: We must save the host indices and device indices here,
                 # this is because we need to guarantee that these tensors are
                 # still alive when the load stream is executing.

@@ -543,16 +581,29 @@ class HiCacheController:
         )
         return producer_id

-    def evict_device(self, device_indices: torch.Tensor) -> int:
-        self.mem_pool_device_allocator.free(device_indices)
-        return len(device_indices)
+    def evict_device(
+        self, device_indices: torch.Tensor, host_indices: torch.Tensor
+    ) -> int:
+        if self.mem_pool_host.is_synced(host_indices):
+            self.mem_pool_device_allocator.free(device_indices)
+            self.mem_pool_host.update_backup(host_indices)
+            return len(device_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )

     def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
         if not backup_only:
             raise ValueError("Other eviction policies are not supported yet.")
-        self.mem_pool_host.free(host_indices)
-        return len(host_indices)
+        if self.mem_pool_host.is_backup(host_indices):
+            self.mem_pool_host.free(host_indices)
+            return len(host_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )

     def prefetch(
         self,

@@ -579,19 +630,42 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)

-    def _page_get_zero_copy(self, operation, hash_values, host_indices):
-        results = self.storage_backend.batch_get_v1(hash_values, host_indices)
-        inc = 0
-        for i in range(len(hash_values)):
-            if not results[i]:
-                logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
-                )
-                break
-            inc += self.page_size
-        operation.increment(inc)
-
-    # todo: deprecate
+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
+    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
+        )
+        page_data = self.storage_backend.batch_get(hashes, dsts)
+        if page_data:
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
+        else:
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
+            )
+
+    def _mooncake_page_get(self, operation, hash_values, host_indices):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
+        )
+        get_result = self.storage_backend.batch_get(
+            key_strs,
+            target_locations=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        if get_result != len(hash_values):
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed or partially failed."
+            )
+        if get_result != 0:
+            operation.increment(get_result * self.page_size)
+
     def _generic_page_get(self, operation, hash_values, host_indices):
         dummy_page_dst = [
             self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values

@@ -681,7 +755,7 @@ class HiCacheController:
                     batch_tokens[i : i + self.page_size], last_hash
                 )
                 batch_hashes.append(last_hash)
-            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hit_page_num = self.batch_exists_func(batch_hashes)
             hash_value.extend(batch_hashes[:hit_page_num])
             storage_query_count += hit_page_num * self.page_size
             if hit_page_num < len(batch_hashes):

@@ -750,16 +824,34 @@ class HiCacheController:
             self.backup_queue.put(operation)
         return operation.id

-    # todo: deprecate
+    # non-zero copy
     def _generic_page_set(self, hash_values, host_indices) -> bool:
         data = [
-            self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
+            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
             for i in range(len(hash_values))
         ]
         return self.storage_backend.batch_set(hash_values, data)

-    def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
-        return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
+    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
+        )
+        success = self.storage_backend.batch_set(
+            key_strs,
+            target_locations=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        return success
+
+    # zero copy
+    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(hash_values, host_indices)
+        return self.storage_backend.batch_set(hashes, dsts)

     # Backup batch by batch
     def _page_backup(self, operation):
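The second hunk above swaps in backend-specific page get/set functions at construction time rather than calling one generic path everywhere. A small standalone sketch of that selection pattern (the class and method names below are placeholders, not the real HiCacheController API):

class PageIO:
    def __init__(self, backend_type: str):
        # Start from generic defaults, then override per backend, mirroring
        # the page_get_func / page_set_func selection in the diff.
        self.page_get_func = self._generic_page_get
        self.page_set_func = self._generic_page_set
        if backend_type == "mooncake":
            self.page_get_func = self._mooncake_page_get
        elif backend_type == "hf3fs":
            self.page_get_func = self._zero_copy_page_get

    def _generic_page_get(self, key):
        return f"generic:{key}"

    def _mooncake_page_get(self, key):
        return f"mooncake:{key}"

    def _zero_copy_page_get(self, key):
        return f"zero-copy:{key}"

    def _generic_page_set(self, key, value):
        return True

print(PageIO("mooncake").page_get_func("k0"))  # mooncake:k0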
python/sglang/srt/managers/io_struct.py

@@ -35,7 +35,6 @@ else:
     Image = Any

-# Parameters for a session
 @dataclass
 class SessionParams:
     id: Optional[str] = None

@@ -133,23 +132,18 @@ class GenerateReqInput:
     # Conversation id used for tracking requests
     conversation_id: Optional[str] = None

+    # Label for the request
+    label: Optional[str] = None
     # Priority for the request
     priority: Optional[int] = None

-    # Extra key for classifying the request (e.g. cache_salt)
-    extra_key: Optional[Union[List[str], str]] = None
-    # Whether to disallow logging for this request (e.g. due to ZDR)
-    no_logs: bool = False
-    # For custom metric labels
-    custom_labels: Optional[Dict[str, str]] = None
-    # (Deprecated, please use custom_labels) Label for the request
-    label: Optional[str] = None
-    # (Internal) Whether to return bytes for image generation
+    # Image gen grpc migration
     return_bytes: bool = False

+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)

@@ -548,11 +542,8 @@ class GenerateReqInput:
                 self.data_parallel_rank if self.data_parallel_rank is not None else None
             ),
             conversation_id=self.conversation_id,
-            priority=self.priority,
-            extra_key=self.extra_key,
-            no_logs=self.no_logs,
-            custom_labels=self.custom_labels,
-            label=self.label,
+            label=self.label,
+            priority=self.priority,
             return_bytes=self.return_bytes,
         )

@@ -609,23 +600,18 @@ class TokenizedGenerateReqInput:
     # For dp balance
     dp_balance_id: int = -1

+    # Label for the request
+    label: Optional[str] = None
     # Priority for the request
     priority: Optional[int] = None
-    # Extra key for classifying the request (e.g. cache_salt)
-    extra_key: Optional[str] = None
-    # Whether to disallow logging for this request (e.g. due to ZDR)
-    no_logs: bool = False
+    # Image gen grpc migration
+    return_bytes: bool = False

     # tracing context
     trace_context: Optional[Dict] = None

-    # (Deprecated, please use custom_labels) Label for the request
-    label: Optional[str] = None
-    # (Internal) Whether to return bytes for image generation
-    return_bytes: bool = False

 @dataclass
 class BatchTokenizedGenerateReqInput:
python/sglang/srt/managers/mm_utils.py

@@ -507,7 +507,6 @@ def embed_mm_inputs(
             Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
         ]
     ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
-    use_deepstack: bool = False,
 ) -> Optional[torch.Tensor]:
     """
     Embed multimodal inputs and integrate them with text token embeddings.

@@ -523,7 +522,7 @@ def embed_mm_inputs(
     Returns:
         Combined embedding tensor with multimodal content integrated
     """
-    other_info = {}
+
     if mm_inputs_list is None:
         return None

@@ -533,7 +532,7 @@ def embed_mm_inputs(
     for mm_inputs in mm_inputs_list:
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

-    embeddings, masks, deepstack_embeddings = [], [], []
+    embeddings, masks = [], []
     # 2. Get multimodal embedding separately
     # Try get mm embedding if any
     for modality in Modality.all():

@@ -579,12 +578,6 @@ def embed_mm_inputs(
                 extend_length=extend_seq_lens,
                 items_offset_list=items_offsets,
             )
-            if use_deepstack and embedding is not None:
-                embedding, deepstack_embedding = (
-                    multimodal_model.separate_deepstack_embeds(embedding)
-                )
-                deepstack_embeddings += [deepstack_embedding]
-
             embeddings += [embedding]
             masks += [mask]

@@ -598,37 +591,13 @@ def embed_mm_inputs(
         inputs_embeds = input_embedding(input_ids)

     # 4. scatter embeddings into input embedding
-    # deepstack embedding
-    if use_deepstack:
-        num_deepstack_embeddings = (
-            len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
-        )
-        deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
-            inputs_embeds.shape[-1] * num_deepstack_embeddings,
-        )
-        input_deepstack_embeds = torch.zeros(
-            deepstack_embedding_shape,
-            device=inputs_embeds.device,
-            dtype=inputs_embeds.dtype,
-        )
-        other_info["input_deepstack_embeds"] = input_deepstack_embeds
-
-    for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
+    for embedding, mask in zip(embeddings, masks):
         if embedding is None or mask is None:
             continue
         # in-place update
         indices = torch.where(mask.squeeze(dim=-1))[0]
         inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
-        if use_deepstack:
-            input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
-                inputs_embeds.device, inputs_embeds.dtype
-            )
-
-    return inputs_embeds, other_info
+    return inputs_embeds

 def general_mm_embed_routine(

@@ -640,7 +609,6 @@ def general_mm_embed_routine(
             Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
         ]
     ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
-    use_deepstack: bool = False,
     **kwargs,
 ) -> torch.Tensor:
     """

@@ -652,7 +620,6 @@ def general_mm_embed_routine(
         language_model: Base language model to use
        data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
        placeholder_tokens: Token IDs for multimodal placeholders
-        use_deepstack: Whether to use deepstack embeddings
        **kwargs: Additional arguments passed to language model

     Returns:

@@ -678,20 +645,16 @@ def general_mm_embed_routine(
             for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
             if forward_batch.mm_inputs[i] is not None
         ]
-        inputs_embeds, other_info = embed_mm_inputs(
+        inputs_embeds = embed_mm_inputs(
             mm_inputs_list=mm_inputs_list,
             extend_prefix_lens=extend_prefix_lens,
             extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
-            multimodal_model=multimodal_model,
             input_embedding=embed_tokens,
+            multimodal_model=multimodal_model,
             data_embedding_func_mapping=data_embedding_funcs,
             placeholder_tokens=placeholder_tokens,
-            use_deepstack=use_deepstack,
         )
-        # add for qwen3_vl deepstack
-        if use_deepstack:
-            kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
-
         # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
         # just being defensive here
         forward_batch.mm_inputs = None
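The scatter step kept by both versions of embed_mm_inputs writes each modality's embeddings into the text-embedding tensor at the placeholder positions selected by a boolean mask. A standalone sketch of that in-place update (shapes and names here are illustrative only):

import torch

hidden = 16
inputs_embeds = torch.zeros(10, hidden)  # embeddings for 10 input tokens
mask = torch.tensor([0, 0, 1, 1, 1, 0, 0, 1, 0, 0], dtype=torch.bool)
mm_embedding = torch.randn(int(mask.sum()), hidden)  # one row per placeholder token

# In-place update at the placeholder positions, as in the diff.
indices = torch.where(mask)[0]
inputs_embeds[indices] = mm_embedding.to(inputs_embeds.device, inputs_embeds.dtype)
print(indices.tolist())  # [2, 3, 4, 7]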
python/sglang/srt/managers/multimodal_processor.py

@@ -12,7 +12,8 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}

-def import_processors(package_name: str):
+def import_processors():
+    package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
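For reference, import_processors walks a package with pkgutil.iter_modules and imports every non-package module it finds. A minimal, self-contained version of that discovery loop, run here against the standard-library json package purely as a stand-in target:

import importlib
import pkgutil

def import_submodules(package_name: str):
    # Import every plain module (not sub-package) found under package_name.
    package = importlib.import_module(package_name)
    imported = []
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            imported.append(importlib.import_module(name))
    return imported

print([m.__name__ for m in import_submodules("json")])
# e.g. ['json.decoder', 'json.encoder', 'json.scanner', 'json.tool']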
python/sglang/srt/managers/overlap_utils.py  (deleted, 100644 → 0)

-import torch
-
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
-from sglang.srt.utils import get_compiler_backend
-
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def _resolve_future_token_ids(input_ids, future_token_ids_map):
-    input_ids[:] = torch.where(
-        input_ids < 0,
-        future_token_ids_map[torch.clamp(-input_ids, min=0)],
-        input_ids,
-    )
-
-
-class FutureMap:
-    def __init__(
-        self,
-        max_running_requests: int,
-        device: torch.device,
-    ):
-        self.future_ct = 0
-        # A factor of 3 is used to avoid collision in the circular buffer.
-        self.future_limit = max_running_requests * 3
-        # A factor of 5 is used to ensure the buffer is large enough.
-        self.future_buffer_len = max_running_requests * 5
-        self.device = device
-
-        self.token_ids_buf = torch.empty(
-            (self.future_buffer_len,), dtype=torch.int64, device=self.device
-        )
-
-    def update_ct(self, bs: int) -> int:
-        """Update the circular buffer pointer and return the current pointer."""
-        cur_future_ct = self.future_ct
-        self.future_ct = (cur_future_ct + bs) % self.future_limit
-        return cur_future_ct
-
-    def resolve_future(self, model_worker_batch: ModelWorkerBatch):
-        input_ids = model_worker_batch.input_ids
-        _resolve_future_token_ids(input_ids, self.token_ids_buf)
-
-    def update_next_future(self, future_ct: int, bs: int):
-        return torch.arange(
-            -(future_ct + 1),
-            -(future_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
-        )
-
-    def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
-        self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
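The deleted FutureMap resolves placeholder token IDs for overlapped scheduling: the scheduler hands out a negative index for each not-yet-sampled token, and the worker later replaces every negative entry with the real token ID stored in a circular buffer. A small runnable illustration of that resolution step (buffer sizes and values below are arbitrary):

import torch

future_buffer_len = 16
token_ids_buf = torch.zeros(future_buffer_len, dtype=torch.int64)

# The scheduler hands out slots -(ct+1) .. -(ct+bs) for a batch of size bs.
ct, bs = 3, 4
placeholders = torch.arange(-(ct + 1), -(ct + 1 + bs), -1, dtype=torch.int64)

# Later, sampling fills the corresponding buffer slots with real token IDs.
sampled = torch.tensor([101, 102, 103, 104], dtype=torch.int64)
token_ids_buf[ct + 1 : ct + bs + 1] = sampled

# Resolution: negative input IDs are looked up in the buffer, others pass through.
input_ids = placeholders.clone()
input_ids[:] = torch.where(
    input_ids < 0,
    token_ids_buf[torch.clamp(-input_ids, min=0)],
    input_ids,
)
print(input_ids.tolist())  # [101, 102, 103, 104]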
python/sglang/srt/managers/schedule_batch.py

@@ -67,14 +67,14 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import flatten_nested_list, support_triton

 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-    from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+    from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -90,7 +90,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
-    "enable_fp32_lm_head",
     "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",

@@ -113,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_custom_logit_processor",
     "disaggregation_mode",
     "enable_deterministic_inference",
+    "nsa_prefill",
+    "nsa_decode",
 ]

 # Put some global args for easy access

@@ -492,7 +493,7 @@ class Req:
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states

-        # extra key for classifying the request (e.g. cache_salt)
+        # extra key for classifying the request (e.g. lora_id, cache_salt)
         if lora_id is not None:
             extra_key = (
                 extra_key or ""

@@ -608,8 +609,6 @@ class Req:
         ) = None
         self.hidden_states: List[List[float]] = []
         self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
-        self.output_topk_p = None
-        self.output_topk_index = None

         # Embedding (return values)
         self.embedding = None

@@ -954,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
-        None
-    )
+    spec_info: Optional[
+        Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+    ] = None

     # Whether to return hidden states
     return_hidden_states: bool = False

@@ -1609,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             self.spec_algorithm.is_eagle()
             or self.spec_algorithm.is_standalone()
-            or self.spec_algorithm.is_ngram()
+            or self.spec_algorithm.is_lookahead()
         ):
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.

@@ -1736,14 +1735,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.sampling_info.filter_batch(keep_indices, keep_indices_device)
         if self.spec_info:
-            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
-                has_been_filtered = False
-            else:
-                has_been_filtered = True
-            self.spec_info.filter_batch(
-                new_indices=keep_indices_device,
-                has_been_filtered=has_been_filtered,
-            )
+            self.spec_info.filter_batch(keep_indices_device)

     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because

@@ -1992,9 +1984,9 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
-        None
-    )
+    spec_info: Optional[
+        Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
+    ] = None

     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
python/sglang/srt/managers/schedule_policy.py

@@ -318,6 +318,7 @@ class PrefillAdder:
         new_token_ratio: float,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        max_prefill_bs: Optional[int],
         mixed_with_decode_tokens: int = 0,
         priority_scheduling_preemption_threshold: int = 0,
     ):

@@ -358,6 +359,10 @@ class PrefillAdder:
             priority_scheduling_preemption_threshold
         )

+        self.max_prefill_bs = (
+            max_prefill_bs if max_prefill_bs is not None else 2147483647
+        )
+
     def _get_running_request_total_token_offset(self, req: Req) -> int:
         return (
             min(

@@ -549,6 +554,9 @@ class PrefillAdder:
     def add_one_req(
         self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
     ):
+        if len(self.can_run_list) >= self.max_prefill_bs:
+            return AddReqResult.OTHER
+
         if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)
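The new max_prefill_bs argument simply caps how many requests the prefill adder will accept into one batch. A sketch of that guard in isolation, with the surrounding adder reduced to a list (names below are illustrative, not the PrefillAdder API):

from typing import Optional

class TinyPrefillAdder:
    def __init__(self, max_prefill_bs: Optional[int] = None):
        # None means "no cap", mirrored by the 2**31 - 1 sentinel in the diff.
        self.max_prefill_bs = max_prefill_bs if max_prefill_bs is not None else 2147483647
        self.can_run_list = []

    def add_one_req(self, req) -> bool:
        if len(self.can_run_list) >= self.max_prefill_bs:
            return False  # batch is full, defer this request
        self.can_run_list.append(req)
        return True

adder = TinyPrefillAdder(max_prefill_bs=2)
print([adder.add_one_req(r) for r in ("r0", "r1", "r2")])  # [True, True, False]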
python/sglang/srt/managers/scheduler.py

@@ -44,9 +44,6 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
-from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
-    DecodeKVCacheOffloadManager,
-)
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,

@@ -262,7 +259,7 @@ class Scheduler(
         self.enable_metrics_for_all_schedulers = (
             server_args.enable_metrics_for_all_schedulers
         )
-        self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
+        self.enable_kv_cache_events = server_args.kv_events_config is not None
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm

@@ -388,10 +385,10 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
-        elif self.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_worker import NGRAMWorker
+        elif self.spec_algorithm.is_lookahead():
+            from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker

-            self.draft_worker = NGRAMWorker(
+            self.draft_worker = LOOKAHEADWorker(
                 gpu_id=gpu_id,
                 tp_rank=tp_rank,
                 moe_ep_rank=moe_ep_rank,

@@ -556,11 +553,9 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
+        self.init_kv_events(server_args.kv_events_config)
         self.init_dp_balance(dp_balance_meta)
-        if self.enable_kv_cache_events:
-            self.init_kv_events(server_args.kv_events_config)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode

@@ -618,6 +613,8 @@ class Scheduler(
             ]
         )

+        self.max_prefill_bs = server_args.max_prefill_bs
+
     def init_deterministic_inference_config(self):
         """Initialize deterministic inference configuration for different attention backends."""
         if not self.server_args.enable_deterministic_inference:

@@ -758,24 +755,6 @@ class Scheduler(
             eviction_policy=server_args.radix_eviction_policy,
         )

-        if (
-            server_args.disaggregation_mode == "decode"
-            and server_args.disaggregation_decode_enable_offload_kvcache
-        ):
-            self.decode_offload_manager = DecodeKVCacheOffloadManager(
-                req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                tp_group=(
-                    self.attn_tp_cpu_group
-                    if self.server_args.enable_dp_attention
-                    else self.tp_cpu_group
-                ),
-                tree_cache=self.tree_cache,
-                server_args=self.server_args,
-            )
-        else:
-            self.decode_offload_manager = None
-
         self.decode_mem_cache_buf_multiplier = (
             1
             if self.spec_algorithm.is_none()

@@ -806,7 +785,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-                hidden_states_dtype=self.model_config.dtype,
+                dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )

@@ -826,7 +805,7 @@ class Scheduler(
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
+                    if self.draft_worker is None or self.spec_algorithm.is_lookahead()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,

@@ -855,7 +834,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-                hidden_states_dtype=self.model_config.dtype,
+                dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )

@@ -863,7 +842,7 @@ class Scheduler(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
+                    if self.draft_worker is None or self.spec_algorithm.is_lookahead()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,

@@ -1832,6 +1811,7 @@ class Scheduler(
             self.new_token_ratio,
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            self.max_prefill_bs,
             running_bs if self.is_mixed_chunk else 0,
             self.priority_scheduling_preemption_threshold,
         )
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -250,13 +250,7 @@ class SchedulerOutputProcessorMixin:
                 req.check_finished()

                 if req.finished():
-                    if self.server_args.disaggregation_decode_enable_offload_kvcache:
-                        # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
-                        if not self.decode_offload_manager.offload_kv_cache(req):
-                            self.tree_cache.cache_finished_req(req)
-                    else:
-                        self.tree_cache.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                     req.time_stats.completion_time = time.time()

                     if req.return_logprob and batch.spec_algorithm.is_none():
python/sglang/srt/managers/scheduler_profiler_mixin.py
View file @ 852a49c5
...
@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
-       stage_str = f" for {stage.name}" if stage else ""
+       stage_str = f" for {stage.__str__()}" if stage else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
        )
...
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
        if not Path(self.torch_profiler_output_dir).exists():
            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
-       stage_suffix = f"-{stage.name}" if stage else ""
+       stage_suffix = f"-{stage.__str__()}" if stage else ""
        logger.info("Stop profiling" + stage_suffix + "...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
...
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
        if self.profiler_decode_ct == 0:
            if self.profile_in_progress:
                # force trace flush
-               self.stop_profile(stage=ForwardMode.EXTEND)
+               self.stop_profile(ForwardMode.EXTEND)
                self.start_profile(batch.forward_mode)
        self.profiler_decode_ct += 1
        if self.profiler_decode_ct > self.profiler_target_decode_ct:
...
@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
                recv_req.profile_by_stage,
                recv_req.profile_id,
            )
-           return self.start_profile()
+           return self.start_profile(True)
        else:
            return self.stop_profile()
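Reviewer note: switching from stage.name to stage.__str__() only changes how the forward mode is rendered in log messages and trace-file suffixes. A quick illustration with a stand-in enum (ForwardMode here is a local stand-in, not the sglang class):

    from enum import Enum

    class ForwardMode(Enum):  # stand-in for sglang's ForwardMode
        EXTEND = 1
        DECODE = 2

    stage = ForwardMode.EXTEND
    print(f" for {stage.name}")       # " for EXTEND"
    print(f" for {stage.__str__()}")  # " for ForwardMode.EXTEND"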
python/sglang/srt/managers/tokenizer_manager.py
View file @ 852a49c5
...
@@ -185,7 +185,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
        )

        if self.model_config.is_multimodal:
-           import_processors("sglang.srt.multimodal.processors")
+           import_processors()
            try:
                _processor = get_processor(
                    server_args.tokenizer_path,
...
@@ -320,8 +320,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                "model_name": self.server_args.served_model_name,
                # TODO: Add lora name/path in the future,
            }
-           if server_args.tokenizer_metrics_allowed_custom_labels:
-               for label in server_args.tokenizer_metrics_allowed_custom_labels:
+           if server_args.tokenizer_metrics_allowed_customer_labels:
+               for label in server_args.tokenizer_metrics_allowed_customer_labels:
                    labels[label] = ""
            self.metrics_collector = TokenizerMetricsCollector(
                server_args=server_args,
...
@@ -750,7 +750,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                return_hidden_states=obj.return_hidden_states,
                data_parallel_rank=obj.data_parallel_rank,
                priority=obj.priority,
-               extra_key=obj.extra_key,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
...
@@ -1633,10 +1632,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                else 0
            )
-           custom_labels = getattr(state.obj, "custom_labels", None)
+           customer_labels = getattr(state.obj, "customer_labels", None)
            labels = (
-               {**self.metrics_collector.labels, **custom_labels}
-               if custom_labels
+               {**self.metrics_collector.labels, **customer_labels}
+               if customer_labels
                else self.metrics_collector.labels
            )
            if (
...
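Reviewer note: despite the rename, the metrics path still merges per-request labels over the collector's base labels with a plain dict-unpacking merge, falling back to the base labels when none are supplied. The behaviour is the standard Python pattern below (base_labels and request_labels are illustrative names):

    base_labels = {"model_name": "deepseek-v3.2"}
    request_labels = {"tenant": "team-a"}

    labels = {**base_labels, **request_labels} if request_labels else base_labels
    print(labels)  # {'model_name': 'deepseek-v3.2', 'tenant': 'team-a'}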
python/sglang/srt/managers/tp_worker.py
View file @ 852a49c5
...
@@ -91,6 +91,7 @@ class TpModelWorker:
                else server_args.speculative_draft_model_revision
            ),
            is_draft_model=is_draft_worker,
+           tp_rank=tp_rank,
        )
        self.model_runner = ModelRunner(
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @ 852a49c5
...
@@ -36,11 +36,10 @@ from sglang.srt.managers.io_struct import (
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
-from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
from sglang.utils import get_exception_traceback

if TYPE_CHECKING:
...
@@ -49,6 +48,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)

+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )

class TpModelWorkerClient:
    """A tensor parallel model worker."""
...
@@ -71,7 +79,11 @@ class TpModelWorkerClient:
        self.gpu_id = gpu_id

        # Init future mappings
-       self.future_map = FutureMap(self.max_running_requests, self.device)
+       self.future_token_ids_ct = 0
+       self.future_token_ids_limit = self.max_running_requests * 3
+       self.future_token_ids_map = torch.empty(
+           (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
+       )

        # Launch threads
        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
...
@@ -141,7 +153,7 @@ class TpModelWorkerClient:
        batch_lists: List = [None] * 2

        while True:
-           model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
+           model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
            if not model_worker_batch:
                break
...
@@ -157,7 +169,8 @@ class TpModelWorkerClient:
            copy_done = torch.get_device_module(self.device).Event()

            # Resolve future tokens in the input
-           self.future_map.resolve_future(model_worker_batch)
+           input_ids = model_worker_batch.input_ids
+           resolve_future_token_ids(input_ids, self.future_token_ids_map)

            # Run forward
            logits_output, next_token_ids, can_run_cuda_graph = (
...
@@ -174,9 +187,9 @@ class TpModelWorkerClient:
            if model_worker_batch.is_prefill_only:
                # For prefill-only requests, create dummy token IDs on CPU
                next_token_ids = torch.zeros(bs, dtype=torch.long)
-               # store the future indices into future map
-               self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
+               self.future_token_ids_map[
+                   future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+               ] = next_token_ids

            # Copy results to the CPU
            if model_worker_batch.return_logprob:
...
@@ -242,14 +255,20 @@ class TpModelWorkerClient:
        sync_event.record(self.scheduler_stream)

        # Push a new batch to the queue
-       bs = len(model_worker_batch.seq_lens)
-       cur_future_map_ct = self.future_map.update_ct(bs)
-       self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
-
-       # get this forward batch's future token ids
-       future_next_token_ids = self.future_map.update_next_future(
-           cur_future_map_ct, bs
-       )
+       self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
+
+       # Allocate output future objects
+       bs = len(model_worker_batch.seq_lens)
+       future_next_token_ids = torch.arange(
+           -(self.future_token_ids_ct + 1),
+           -(self.future_token_ids_ct + 1 + bs),
+           -1,
+           dtype=torch.int64,
+           device=self.device,
+       )
+       self.future_token_ids_ct = (
+           self.future_token_ids_ct + bs
+       ) % self.future_token_ids_limit
        return None, future_next_token_ids, False

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
...
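Reviewer note: the explicit bookkeeping on the new side of this hunk is the overlap scheduler's placeholder trick: when a batch is queued for the worker thread, the caller is handed back negative "future" token IDs, and those placeholders are patched with real sampled tokens once the in-flight forward writes them into future_token_ids_map. A minimal standalone sketch of the same idea, using plain tensors instead of the sglang classes:

    import torch

    # Circular buffer holding tokens produced by in-flight forwards.
    future_map = torch.zeros(16, dtype=torch.int64)
    ct = 0  # next free slot (index 0 is reserved so -0 is never used)

    def allocate_futures(bs):
        # Hand out negative placeholders -(ct+1) .. -(ct+bs).
        global ct
        ids = torch.arange(-(ct + 1), -(ct + 1 + bs), -1, dtype=torch.int64)
        ct += bs
        return ids

    def publish(future_ids, real_tokens):
        # A finished forward writes its sampled tokens into the map.
        future_map[-future_ids] = real_tokens

    def resolve(input_ids):
        # Replace any negative placeholder with the token it refers to.
        return torch.where(input_ids < 0, future_map[torch.clamp(-input_ids, min=0)], input_ids)

    futures = allocate_futures(2)               # tensor([-1, -2])
    publish(futures, torch.tensor([11, 42]))
    print(resolve(torch.tensor([-1, 7, -2])))   # tensor([11,  7, 42])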
python/sglang/srt/mem_cache/allocator_ascend.py
View file @ 852a49c5
...
@@ -79,37 +79,48 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
        )
-       num_new_pages = (
-           (
-               (seq_lens + self.page_size - 1) // self.page_size
-               - (prefix_lens + self.page_size - 1) // self.page_size
-           )
-           .sum()
-           .item()
-       )
-       if self.need_sort and num_new_pages > len(self.free_pages):
-           self.merge_and_sort_free()
-       if num_new_pages > len(self.free_pages):
-           return None
-
-       out_indices = torch.empty(
-           (extend_num_tokens,), dtype=torch.int32, device=self.device
-       )
-       alloc_extend_kernel_ascend(
-           prefix_lens,
-           seq_lens,
-           last_loc,
-           self.free_pages,
-           out_indices,
-           self.page_size,
-           self.device,
-       )
+       num_new_pages = (
+           (seq_lens + self.page_size - 1) // self.page_size
+           - (prefix_lens + self.page_size - 1) // self.page_size
+       ).sum()
+       num_new_pages_item = num_new_pages.item()
+       if self.need_sort and num_new_pages_item > len(self.free_pages):
+           self.merge_and_sort_free()
+       if num_new_pages_item > len(self.free_pages):
+           return None
+
+       out_indices = torch.empty(
+           (extend_num_tokens,), dtype=torch.int64, device=self.device
+       )
+
+       if num_new_pages_item < 200:
+           import sgl_kernel_npu
+
+           torch.ops.npu.alloc_extend(
+               prefix_lens,
+               seq_lens,
+               last_loc,
+               self.free_pages,
+               self.page_size,
+               out_indices,
+               num_new_pages,
+           )
+       else:
+           alloc_extend_kernel_ascend(
+               prefix_lens,
+               seq_lens,
+               last_loc,
+               self.free_pages,
+               out_indices,
+               self.page_size,
+               self.device,
+           )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

-       self.free_pages = self.free_pages[num_new_pages:]
+       self.free_pages = self.free_pages[num_new_pages_item:]
        return out_indices

    def alloc_decode(
...
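Reviewer note: the page count in this hunk is ceil-division applied to both the target and the prefix lengths, so only the pages not already covered by the cached prefix are requested. For example with page_size=128, a request growing from a 200-token prefix to 300 tokens needs ceil(300/128) - ceil(200/128) = 3 - 2 = 1 new page. A small self-contained check of that arithmetic:

    import torch

    page_size = 128
    seq_lens = torch.tensor([300, 513])     # target lengths after the extend
    prefix_lens = torch.tensor([200, 512])  # tokens already cached

    num_new_pages = (
        (seq_lens + page_size - 1) // page_size
        - (prefix_lens + page_size - 1) // page_size
    ).sum()
    print(num_new_pages.item())  # 2: one new page per request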
python/sglang/srt/mem_cache/hicache_storage.py
View file @ 852a49c5
...
@@ -7,8 +7,6 @@ from typing import Any, List, Optional
import torch

-from sglang.srt.mem_cache.memory_pool_host import HostKVCache
-
logger = logging.getLogger(__name__)
...
@@ -34,46 +32,15 @@ class HiCacheStorageConfig:
    extra_config: Optional[dict] = None

-@dataclass
-class HiCacheStorageExtraInfo:
-    extra_info: Optional[dict] = None

class HiCacheStorage(ABC):
    """
    HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
    It abstracts the underlying storage mechanism, allowing different implementations to be used.
    """

+   # todo, potentially pass model and TP configs into storage backend
    # todo, the page size of storage backend does not have to be the same as the same as host memory pool

-   def register_mem_pool_host(self, mem_pool_host: HostKVCache):
-       self.mem_pool_host = mem_pool_host
-
-   def batch_get_v1(
-       self,
-       keys: List[str],
-       host_indices: torch.Tensor,
-       extra_info: Optional[HiCacheStorageExtraInfo] = None,
-   ) -> List[bool]:
-       """
-       Retrieve values for multiple keys.
-       Returns a list of tensors or None for each key.
-       """
-       pass
-
-   def batch_set_v1(
-       self,
-       keys: List[str],
-       host_indices: torch.Tensor,
-       extra_info: Optional[HiCacheStorageExtraInfo] = None,
-   ) -> List[bool]:
-       """
-       Retrieve values for multiple keys.
-       Returns a list of tensors or None for each key.
-       """
-       pass
-
    @abstractmethod
    def get(
        self,
...
@@ -87,7 +54,6 @@ class HiCacheStorage(ABC):
        """
        pass

-   # TODO: Deprecate
    @abstractmethod
    def batch_get(
        self,
...
@@ -115,7 +81,6 @@ class HiCacheStorage(ABC):
        """
        pass

-   # TODO: Deprecate
    @abstractmethod
    def batch_set(
        self,
...
@@ -138,7 +103,6 @@ class HiCacheStorage(ABC):
        """
        pass

-   # TODO: Use a finer-grained return type (e.g., List[bool])
    def batch_exists(self, keys: List[str]) -> int:
        """
        Check if the keys exist in the storage.
...
@@ -150,9 +114,6 @@ class HiCacheStorage(ABC):
                return i
        return len(keys)

-   def clear(self) -> None:
-       pass
-
    def get_stats(self):
        return None
...
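Reviewer note: the kept lines show batch_exists's prefix-oriented contract: it returns how many leading keys are present rather than a per-key mask. A toy in-memory backend sketch illustrating those semantics (this is an illustration, not one of sglang's real storage backends; the dict-based store and the method bodies are assumptions):

    from typing import List, Optional

    import torch

    class InMemoryKVStorage:
        """Toy stand-in for a HiCacheStorage backend, keyed by page hash."""

        def __init__(self):
            self._store = {}

        def get(self, key: str, target_location: Optional[torch.Tensor] = None):
            value = self._store.get(key)
            if value is not None and target_location is not None:
                target_location.copy_(value)  # emulate copying into a host buffer
            return value

        def set(self, key: str, value: torch.Tensor) -> bool:
            self._store[key] = value.clone()
            return True

        def batch_exists(self, keys: List[str]) -> int:
            # Count keys that exist from the front, mirroring the prefix-style
            # contract of the abstract class above.
            for i, key in enumerate(keys):
                if key not in self._store:
                    return i
            return len(keys)

    storage = InMemoryKVStorage()
    storage.set("page-0", torch.ones(4))
    storage.set("page-1", torch.zeros(4))
    print(storage.batch_exists(["page-0", "page-1", "page-2"]))  # 2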