Commit e12358dc (unverified)

Simplify the usage of device (#1734)

Authored Oct 20, 2024 by Lianmin Zheng; committed by GitHub on Oct 20, 2024.
Parent: 554fbf93
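Summary (inferred from the diff below, since the commit carries no description here): ScheduleBatch now takes its device from req_to_token_pool.device at construction time instead of relying on the hard-coded "cuda" default downstream; SamplingBatchInfo reads batch.device instead of batch.input_ids.device; and prepare_for_decode gains an enable_overlap flag so the overlap scheduler can use out-of-place tensor updates, which in turn lets ModelWorkerBatch.copy drop two defensive .clone() calls.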
Showing 3 changed files with 29 additions and 23 deletions (+29 -23):

  python/sglang/srt/managers/schedule_batch.py        +23 -18
  python/sglang/srt/managers/scheduler.py              +4  -3
  python/sglang/srt/sampling/sampling_batch_info.py    +2  -2
python/sglang/srt/managers/schedule_batch.py — view file @ e12358dc
@@ -425,7 +425,6 @@ class ScheduleBatch:
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-
     output_ids: torch.Tensor = None
 
     # For processing logprobs
@@ -442,27 +441,23 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
-    # device
-    device: str = "cuda"
-
     # Has regex
     has_regex: bool = False
 
+    # device
+    device: str = "cuda"
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-        has_stream = any(req.stream for req in reqs)
-        has_regex = any(req.regex_fsm for req in reqs)
-
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_logprob=return_logprob,
-            has_stream=has_stream,
-            has_regex=has_regex,
+            return_logprob=any(req.return_logprob for req in reqs),
+            has_stream=any(req.stream for req in reqs),
+            has_regex=any(req.regex_fsm for req in reqs),
+            device=req_to_token_pool.device,
         )
 
     def batch_size(self):
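To make the pattern concrete: a minimal, self-contained sketch of deriving a batch's device from the pool that owns the backing tensors, as init_new now does with req_to_token_pool.device. Pool and Batch below are hypothetical stand-ins, not sglang classes.

    import dataclasses

    import torch

    @dataclasses.dataclass
    class Pool:
        # The pool owns the backing storage, so it is the authority on placement.
        buf: torch.Tensor

        @property
        def device(self) -> torch.device:
            return self.buf.device

    @dataclasses.dataclass
    class Batch:
        pool: Pool
        device: torch.device = None

        @classmethod
        def init_new(cls, pool: Pool) -> "Batch":
            # Read the device once at construction; downstream code then
            # consults batch.device instead of inspecting per-batch tensors.
            return cls(pool=pool, device=pool.device)

    batch = Batch.init_new(Pool(buf=torch.zeros(4)))
    print(batch.device)  # cpu here; cuda:0 if the pool's buffer lives on GPU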
@@ -754,7 +749,7 @@ class ScheduleBatch:
 
         return jump_forward_reqs
 
-    def prepare_for_decode(self):
+    def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE
         self.input_ids = self.output_ids
@@ -767,10 +762,19 @@ class ScheduleBatch:
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
 
-        self.req_to_token_pool.write(
-            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
-        )
-        self.seq_lens.add_(1)
+        if enable_overlap:
+            # Do not use in-place operations in the overlap mode
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens = self.seq_lens + 1
+        else:
+            # A faster in-place version
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens.add_(1)
 
     def filter_batch(
         self,
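The reason for the branch, illustrated with plain tensors: under the overlap schedule, a previously dispatched batch may still hold a reference to the same seq_lens storage, so mutating it in place would be visible to that in-flight batch, whereas self.seq_lens + 1 rebinds the name to fresh storage. A minimal demonstration of the aliasing difference:

    import torch

    # In-place: any other holder of the same storage sees the mutation.
    seq_lens = torch.tensor([5, 7])
    in_flight = seq_lens            # e.g. a batch already handed to a worker
    seq_lens.add_(1)
    print(in_flight)                # tensor([6, 8]) -- the in-flight view changed

    # Out-of-place: rebinding leaves the old storage untouched.
    seq_lens = torch.tensor([5, 7])
    in_flight = seq_lens
    seq_lens = seq_lens + 1         # new tensor; the old one is unmodified
    print(in_flight)                # tensor([5, 7])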
@@ -882,6 +886,7 @@ class ScheduleBatch:
         )
 
     def copy(self):
+        # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
             forward_mode=self.forward_mode,
@@ -940,9 +945,9 @@ class ModelWorkerBatch:
         return ModelWorkerBatch(
             bid=self.bid,
             forward_mode=self.forward_mode,
-            input_ids=self.input_ids.clone(),
+            input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens.clone(),
+            seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
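This hunk pairs with the prepare_for_decode change above: once the overlap path stops mutating input_ids and seq_lens in place, the defensive .clone() here is no longer needed, so copying a batch stops allocating. A sketch of the cost being removed (size hypothetical):

    import torch

    seq_lens = torch.arange(4096)

    # Before: copy defensively, because the original might later be
    # mutated in place by the producer.
    handed_off = seq_lens.clone()   # allocates and copies 4096 elements

    # After: aliasing is safe, because the producer now rebinds
    # (seq_lens = seq_lens + 1) instead of mutating, so the tensor the
    # consumer holds can never change underneath it.
    handed_off = seq_lens           # no allocation, no copy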
python/sglang/srt/managers/scheduler.py — view file @ e12358dc
@@ -103,6 +103,7 @@ class Scheduler:
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
+        self.enable_overlap = server_args.enable_overlap_schedule
 
         # Init inter-process communication
         context = zmq.Context(2)
@@ -146,7 +147,7 @@ class Scheduler:
         )
 
         # Launch a tensor parallel worker
-        if self.server_args.enable_overlap_schedule:
+        if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
@@ -670,7 +671,7 @@ class Scheduler:
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.prepare_for_decode()
+            self.running_batch.prepare_for_decode(self.enable_overlap)
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -717,7 +718,7 @@ class Scheduler:
             return
 
         # Update batch tensors
-        batch.prepare_for_decode()
+        batch.prepare_for_decode(self.enable_overlap)
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
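Taken together, the scheduler changes are a plumbing refactor: read server_args.enable_overlap_schedule once in __init__, then forward the cached flag at every prepare_for_decode call site. A minimal sketch of the shape (class and argument names hypothetical):

    class Batch:
        def prepare_for_decode(self, enable_overlap: bool = False) -> None:
            # Stand-in for the real method: pick the update strategy per flag.
            print("out-of-place update" if enable_overlap else "in-place update")

    class Scheduler:
        def __init__(self, enable_overlap_schedule: bool):
            # Cached once, so call sites never re-read the server config.
            self.enable_overlap = enable_overlap_schedule

        def update_running_batch(self, batch: Batch) -> None:
            batch.prepare_for_decode(self.enable_overlap)

    Scheduler(enable_overlap_schedule=True).update_running_batch(Batch())
    # -> out-of-place update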
python/sglang/srt/sampling/sampling_batch_info.py — view file @ e12358dc
@@ -51,7 +51,7 @@ class SamplingBatchInfo:
         disable_penalizer: bool,
     ):
         reqs = batch.reqs
-        device = batch.input_ids.device
+        device = batch.device
         temperatures = (
             torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
@@ -95,7 +95,7 @@ class SamplingBatchInfo:
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device=batch.input_ids.device,
+            device=batch.device,
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
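One plausible motivation for preferring batch.device over batch.input_ids.device: input_ids is an optional field (it defaults to None on ScheduleBatch) and gets rebound during decoding, so an explicit device field is the more reliable single source of truth. A small illustration with hypothetical objects:

    import torch

    class Batch:
        def __init__(self, device: torch.device):
            self.device = device      # explicit single source of truth
            self.input_ids = None     # may not be populated yet

    batch = Batch(torch.device("cpu"))

    # Deriving the device from a tensor breaks whenever that tensor is absent:
    #   batch.input_ids.device  ->  AttributeError on None

    # The explicit field always works:
    temperatures = torch.tensor([0.7, 1.0], device=batch.device)
    print(temperatures.device)  # cpu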