Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e12358dc
Unverified
Commit
e12358dc
authored
Oct 20, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 20, 2024
Browse files
Simplify the usage of device (#1734)
parent
554fbf93
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
23 deletions
+29
-23
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+23
-18
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+4
-3
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+2
-2
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
e12358dc
...
@@ -425,7 +425,6 @@ class ScheduleBatch:
...
@@ -425,7 +425,6 @@ class ScheduleBatch:
req_pool_indices
:
torch
.
Tensor
=
None
req_pool_indices
:
torch
.
Tensor
=
None
seq_lens
:
torch
.
Tensor
=
None
seq_lens
:
torch
.
Tensor
=
None
out_cache_loc
:
torch
.
Tensor
=
None
out_cache_loc
:
torch
.
Tensor
=
None
output_ids
:
torch
.
Tensor
=
None
output_ids
:
torch
.
Tensor
=
None
# For processing logprobs
# For processing logprobs
...
@@ -442,27 +441,23 @@ class ScheduleBatch:
...
@@ -442,27 +441,23 @@ class ScheduleBatch:
# Stream
# Stream
has_stream
:
bool
=
False
has_stream
:
bool
=
False
# device
device
:
str
=
"cuda"
# Has regex
# Has regex
has_regex
:
bool
=
False
has_regex
:
bool
=
False
# device
device
:
str
=
"cuda"
@
classmethod
@
classmethod
def
init_new
(
cls
,
reqs
,
req_to_token_pool
,
token_to_kv_pool
,
tree_cache
):
def
init_new
(
cls
,
reqs
,
req_to_token_pool
,
token_to_kv_pool
,
tree_cache
):
return_logprob
=
any
(
req
.
return_logprob
for
req
in
reqs
)
has_stream
=
any
(
req
.
stream
for
req
in
reqs
)
has_regex
=
any
(
req
.
regex_fsm
for
req
in
reqs
)
return
cls
(
return
cls
(
reqs
=
reqs
,
reqs
=
reqs
,
req_to_token_pool
=
req_to_token_pool
,
req_to_token_pool
=
req_to_token_pool
,
token_to_kv_pool
=
token_to_kv_pool
,
token_to_kv_pool
=
token_to_kv_pool
,
tree_cache
=
tree_cache
,
tree_cache
=
tree_cache
,
return_logprob
=
return_logprob
,
return_logprob
=
any
(
req
.
return_logprob
for
req
in
reqs
),
has_stream
=
has_stream
,
has_stream
=
any
(
req
.
stream
for
req
in
reqs
),
has_regex
=
any
(
req
.
regex_fsm
for
req
in
reqs
),
device
=
req_to_token_pool
.
device
,
device
=
req_to_token_pool
.
device
,
has_regex
=
has_regex
,
)
)
def
batch_size
(
self
):
def
batch_size
(
self
):
...
@@ -754,7 +749,7 @@ class ScheduleBatch:
...
@@ -754,7 +749,7 @@ class ScheduleBatch:
return
jump_forward_reqs
return
jump_forward_reqs
def
prepare_for_decode
(
self
):
def
prepare_for_decode
(
self
,
enable_overlap
:
bool
=
False
):
self
.
forward_mode
=
ForwardMode
.
DECODE
self
.
forward_mode
=
ForwardMode
.
DECODE
self
.
input_ids
=
self
.
output_ids
self
.
input_ids
=
self
.
output_ids
...
@@ -767,10 +762,19 @@ class ScheduleBatch:
...
@@ -767,10 +762,19 @@ class ScheduleBatch:
# Alloc mem
# Alloc mem
bs
=
len
(
self
.
reqs
)
bs
=
len
(
self
.
reqs
)
self
.
out_cache_loc
=
self
.
alloc_token_slots
(
bs
)
self
.
out_cache_loc
=
self
.
alloc_token_slots
(
bs
)
self
.
req_to_token_pool
.
write
(
(
self
.
req_pool_indices
,
self
.
seq_lens
),
self
.
out_cache_loc
if
enable_overlap
:
)
# Do not use in-place operations in the overlap mode
self
.
seq_lens
.
add_
(
1
)
self
.
req_to_token_pool
.
write
(
(
self
.
req_pool_indices
,
self
.
seq_lens
),
self
.
out_cache_loc
)
self
.
seq_lens
=
self
.
seq_lens
+
1
else
:
# A faster in-place version
self
.
req_to_token_pool
.
write
(
(
self
.
req_pool_indices
,
self
.
seq_lens
),
self
.
out_cache_loc
)
self
.
seq_lens
.
add_
(
1
)
def
filter_batch
(
def
filter_batch
(
self
,
self
,
...
@@ -882,6 +886,7 @@ class ScheduleBatch:
...
@@ -882,6 +886,7 @@ class ScheduleBatch:
)
)
def
copy
(
self
):
def
copy
(
self
):
# Only contain fields that will be used by process_batch_result
return
ScheduleBatch
(
return
ScheduleBatch
(
reqs
=
self
.
reqs
,
reqs
=
self
.
reqs
,
forward_mode
=
self
.
forward_mode
,
forward_mode
=
self
.
forward_mode
,
...
@@ -940,9 +945,9 @@ class ModelWorkerBatch:
...
@@ -940,9 +945,9 @@ class ModelWorkerBatch:
return
ModelWorkerBatch
(
return
ModelWorkerBatch
(
bid
=
self
.
bid
,
bid
=
self
.
bid
,
forward_mode
=
self
.
forward_mode
,
forward_mode
=
self
.
forward_mode
,
input_ids
=
self
.
input_ids
.
clone
()
,
input_ids
=
self
.
input_ids
,
req_pool_indices
=
self
.
req_pool_indices
,
req_pool_indices
=
self
.
req_pool_indices
,
seq_lens
=
self
.
seq_lens
.
clone
()
,
seq_lens
=
self
.
seq_lens
,
out_cache_loc
=
self
.
out_cache_loc
,
out_cache_loc
=
self
.
out_cache_loc
,
req_to_token_pool_records
=
self
.
req_to_token_pool_records
,
req_to_token_pool_records
=
self
.
req_to_token_pool_records
,
return_logprob
=
self
.
return_logprob
,
return_logprob
=
self
.
return_logprob
,
...
...
python/sglang/srt/managers/scheduler.py
View file @
e12358dc
...
@@ -103,6 +103,7 @@ class Scheduler:
...
@@ -103,6 +103,7 @@ class Scheduler:
self
.
disable_regex_jump_forward
=
server_args
.
disable_regex_jump_forward
self
.
disable_regex_jump_forward
=
server_args
.
disable_regex_jump_forward
self
.
lora_paths
=
server_args
.
lora_paths
self
.
lora_paths
=
server_args
.
lora_paths
self
.
max_loras_per_batch
=
server_args
.
max_loras_per_batch
self
.
max_loras_per_batch
=
server_args
.
max_loras_per_batch
self
.
enable_overlap
=
server_args
.
enable_overlap_schedule
# Init inter-process communication
# Init inter-process communication
context
=
zmq
.
Context
(
2
)
context
=
zmq
.
Context
(
2
)
...
@@ -146,7 +147,7 @@ class Scheduler:
...
@@ -146,7 +147,7 @@ class Scheduler:
)
)
# Launch a tensor parallel worker
# Launch a tensor parallel worker
if
self
.
server_args
.
enable_overlap
_schedule
:
if
self
.
enable_overlap
:
TpWorkerClass
=
TpModelWorkerClient
TpWorkerClass
=
TpModelWorkerClient
self
.
resolve_next_token_ids
=
(
self
.
resolve_next_token_ids
=
(
lambda
bid
,
x
:
self
.
tp_worker
.
resolve_future_token_ids
(
bid
)
lambda
bid
,
x
:
self
.
tp_worker
.
resolve_future_token_ids
(
bid
)
...
@@ -670,7 +671,7 @@ class Scheduler:
...
@@ -670,7 +671,7 @@ class Scheduler:
# Mixed-style chunked prefill
# Mixed-style chunked prefill
if
self
.
is_mixed_chunk
and
self
.
running_batch
is
not
None
:
if
self
.
is_mixed_chunk
and
self
.
running_batch
is
not
None
:
self
.
running_batch
.
prepare_for_decode
()
self
.
running_batch
.
prepare_for_decode
(
self
.
enable_overlap
)
new_batch
.
mix_with_running
(
self
.
running_batch
)
new_batch
.
mix_with_running
(
self
.
running_batch
)
new_batch
.
decoding_reqs
=
self
.
running_batch
.
reqs
new_batch
.
decoding_reqs
=
self
.
running_batch
.
reqs
self
.
running_batch
=
None
self
.
running_batch
=
None
...
@@ -717,7 +718,7 @@ class Scheduler:
...
@@ -717,7 +718,7 @@ class Scheduler:
return
return
# Update batch tensors
# Update batch tensors
batch
.
prepare_for_decode
()
batch
.
prepare_for_decode
(
self
.
enable_overlap
)
def
run_batch
(
self
,
batch
:
ScheduleBatch
):
def
run_batch
(
self
,
batch
:
ScheduleBatch
):
"""Run a batch."""
"""Run a batch."""
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
e12358dc
...
@@ -51,7 +51,7 @@ class SamplingBatchInfo:
...
@@ -51,7 +51,7 @@ class SamplingBatchInfo:
disable_penalizer
:
bool
,
disable_penalizer
:
bool
,
):
):
reqs
=
batch
.
reqs
reqs
=
batch
.
reqs
device
=
batch
.
input_ids
.
device
device
=
batch
.
device
temperatures
=
(
temperatures
=
(
torch
.
tensor
(
torch
.
tensor
(
[
r
.
sampling_params
.
temperature
for
r
in
reqs
],
[
r
.
sampling_params
.
temperature
for
r
in
reqs
],
...
@@ -95,7 +95,7 @@ class SamplingBatchInfo:
...
@@ -95,7 +95,7 @@ class SamplingBatchInfo:
ret
.
penalizer_orchestrator
=
penaltylib
.
BatchedPenalizerOrchestrator
(
ret
.
penalizer_orchestrator
=
penaltylib
.
BatchedPenalizerOrchestrator
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
batch
=
batch
,
batch
=
batch
,
device
=
batch
.
input_ids
.
device
,
device
=
batch
.
device
,
Penalizers
=
{
Penalizers
=
{
penaltylib
.
BatchedFrequencyPenalizer
,
penaltylib
.
BatchedFrequencyPenalizer
,
penaltylib
.
BatchedMinNewTokensPenalizer
,
penaltylib
.
BatchedMinNewTokensPenalizer
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment