change / sglang · Commits · f5a2faf2

Introduce `FutureMap` (#10715)

Unverified commit f5a2faf2, authored Sep 23, 2025 by Liangsheng Yin; committed by GitHub, Sep 22, 2025. Parent: 1c82d9db.
Showing 2 changed files with 67 additions and 33 deletions (+67 −33):

    python/sglang/srt/managers/overlap_utils.py             +53 −0
    python/sglang/srt/managers/tp_worker_overlap_thread.py  +14 −33
python/sglang/srt/managers/overlap_utils.py (new file, mode 0 → 100644) @ f5a2faf2
import torch

from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend


@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
    input_ids[:] = torch.where(
        input_ids < 0,
        future_token_ids_map[torch.clamp(-input_ids, min=0)],
        input_ids,
    )


class FutureMap:
    def __init__(
        self,
        max_running_requests: int,
        device: torch.device,
    ):
        self.future_ct = 0
        # A factor of 3 is used to avoid collision in the circular buffer.
        self.future_limit = max_running_requests * 3
        # A factor of 5 is used to ensure the buffer is large enough.
        self.future_buffer_len = max_running_requests * 5
        self.device = device

        self.token_ids_buf = torch.empty(
            (self.future_buffer_len,), dtype=torch.int64, device=self.device
        )

    def update_ct(self, bs: int) -> int:
        """Update the circular buffer pointer and return the current pointer."""
        cur_future_ct = self.future_ct
        self.future_ct = (cur_future_ct + bs) % self.future_limit
        return cur_future_ct

    def resolve_future(self, model_worker_batch: ModelWorkerBatch):
        input_ids = model_worker_batch.input_ids
        _resolve_future_token_ids(input_ids, self.token_ids_buf)

    def update_next_future(self, future_ct: int, bs: int):
        return torch.arange(
            -(future_ct + 1),
            -(future_ct + 1 + bs),
            -1,
            dtype=torch.int64,
            device=self.device,
        )

    def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
        self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
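
Taken together, the lifecycle is: the scheduler reserves bs slots with update_ct, hands out negative placeholder IDs via update_next_future, the worker later writes the sampled token IDs into those slots with store_to_map, and any subsequent batch still carrying the placeholders swaps them for real IDs through _resolve_future_token_ids. The sizing comments also become concrete here: future_ct wraps modulo 3 * max_running_requests, while the buffer holds 5 * max_running_requests entries, so the unwrapped write slice future_ct + 1 : future_ct + bs + 1 tops out below 4 * max_running_requests and never runs past the buffer end. The snippet below is a minimal CPU sketch, not part of the commit; it assumes this module is importable and calls the private compiled helper directly for illustration.

    import torch

    from sglang.srt.managers.overlap_utils import FutureMap, _resolve_future_token_ids

    # Illustrative only: CPU device here; the real worker allocates on the GPU.
    fm = FutureMap(max_running_requests=4, device=torch.device("cpu"))

    bs = 2
    ct = fm.update_ct(bs)                         # reserve 2 slots; ct == 0
    placeholders = fm.update_next_future(ct, bs)  # tensor([-1, -2])

    # Later (normally from the worker thread), store the sampled token IDs.
    fm.store_to_map(ct, bs, torch.tensor([101, 102]))

    # A following batch that embedded the placeholders resolves them in place.
    input_ids = torch.tensor([7, -1, -2, 9])
    _resolve_future_token_ids(input_ids, fm.token_ids_buf)
    print(input_ids)  # tensor([  7, 101, 102,   9])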
python/sglang/srt/managers/tp_worker_overlap_thread.py @ f5a2faf2
@@ -36,10 +36,11 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode, get_compiler_backend
+from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback
 
 if TYPE_CHECKING:
@@ -48,15 +49,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def resolve_future_token_ids(input_ids, future_token_ids_map):
-    input_ids[:] = torch.where(
-        input_ids < 0,
-        future_token_ids_map[torch.clamp(-input_ids, min=0)],
-        input_ids,
-    )
-
-
 class TpModelWorkerClient:
     """A tensor parallel model worker."""
 
@@ -79,11 +71,7 @@ class TpModelWorkerClient:
         self.gpu_id = gpu_id
 
         # Init future mappings
-        self.future_token_ids_ct = 0
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
-        )
+        self.future_map = FutureMap(self.max_running_requests, self.device)
 
         # Launch threads
         self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
@@ -153,7 +141,7 @@ class TpModelWorkerClient:
         batch_lists: List = [None] * 2
 
         while True:
-            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
+            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break
@@ -169,8 +157,7 @@ class TpModelWorkerClient:
         copy_done = torch.get_device_module(self.device).Event()
 
         # Resolve future tokens in the input
-        input_ids = model_worker_batch.input_ids
-        resolve_future_token_ids(input_ids, self.future_token_ids_map)
+        self.future_map.resolve_future(model_worker_batch)
 
         # Run forward
         logits_output, next_token_ids, can_run_cuda_graph = (
@@ -187,9 +174,9 @@ class TpModelWorkerClient:
         if model_worker_batch.is_prefill_only:
             # For prefill-only requests, create dummy token IDs on CPU
             next_token_ids = torch.zeros(bs, dtype=torch.long)
 
-        self.future_token_ids_map[
-            future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
-        ] = next_token_ids
+        # store the future indices into future map
+        self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
 
         # Copy results to the CPU
         if model_worker_batch.return_logprob:
@@ -255,20 +242,14 @@ class TpModelWorkerClient:
         sync_event.record(self.scheduler_stream)
 
         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
-
-        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + 1),
-            -(self.future_token_ids_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
-        )
-        self.future_token_ids_ct = (
-            self.future_token_ids_ct + bs
-        ) % self.future_token_ids_limit
+        cur_future_map_ct = self.future_map.update_ct(bs)
+        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
+        # get this forward batch's future token ids
+        future_next_token_ids = self.future_map.update_next_future(
+            cur_future_map_ct, bs
+        )
 
         return None, future_next_token_ids, False
 
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
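
One detail worth noting in the last hunk: update_ct now runs before input_queue.put, so the pre-wrap counter travels with the batch and the scheduler thread and the worker thread agree on exactly which buffer slots this batch's results occupy. Below is a toy sketch of that producer/consumer handshake using plain threads and a CPU buffer; all names here (TinyFutureMap, worker, work) are hypothetical stand-ins rather than sglang's API, and an eager resolve replaces the torch.compile'd helper.

    import queue
    import threading

    import torch


    class TinyFutureMap:
        """A stripped-down, eager analogue of FutureMap (illustrative only)."""

        def __init__(self, max_running_requests: int):
            self.future_ct = 0
            self.future_limit = max_running_requests * 3
            self.token_ids_buf = torch.empty(
                max_running_requests * 5, dtype=torch.int64
            )

        def update_ct(self, bs: int) -> int:
            cur = self.future_ct
            self.future_ct = (cur + bs) % self.future_limit
            return cur

        def update_next_future(self, ct: int, bs: int) -> torch.Tensor:
            return torch.arange(-(ct + 1), -(ct + 1 + bs), -1, dtype=torch.int64)

        def store_to_map(self, ct: int, bs: int, ids: torch.Tensor):
            self.token_ids_buf[ct + 1 : ct + bs + 1] = ids

        def resolve(self, input_ids: torch.Tensor):
            input_ids[:] = torch.where(
                input_ids < 0,
                self.token_ids_buf[torch.clamp(-input_ids, min=0)],
                input_ids,
            )


    fmap = TinyFutureMap(max_running_requests=4)
    work: "queue.Queue[tuple[int, int, torch.Tensor]]" = queue.Queue()


    def worker():
        # Stand-in for the forward thread: receives the reserved counter with
        # the batch and stores the "sampled" token ids into those slots.
        ct, bs, fake_next_ids = work.get()
        fmap.store_to_map(ct, bs, fake_next_ids)


    t = threading.Thread(target=worker)
    t.start()

    # Scheduler side: reserve slots *before* enqueueing, mirroring how
    # update_ct is called before input_queue.put in the patched code.
    bs = 2
    ct = fmap.update_ct(bs)
    placeholders = fmap.update_next_future(ct, bs)  # tensor([-1, -2])
    work.put((ct, bs, torch.tensor([11, 22])))
    t.join()  # the real code orders this with stream events, not join()

    batch_input = torch.cat([torch.tensor([5]), placeholders])
    fmap.resolve(batch_input)
    print(batch_input)  # tensor([ 5, 11, 22])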