sglang · Commit e12358dc (unverified)

Simplify the usage of device (#1734)

Authored Oct 20, 2024 by Lianmin Zheng; committed via GitHub on Oct 20, 2024
Parent commit: 554fbf93
3 changed files with 29 additions and 23 deletions:

  python/sglang/srt/managers/schedule_batch.py        +23 −18
  python/sglang/srt/managers/scheduler.py              +4  −3
  python/sglang/srt/sampling/sampling_batch_info.py    +2  −2
python/sglang/srt/managers/schedule_batch.py

@@ -425,7 +425,6 @@ class ScheduleBatch:
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
     output_ids: torch.Tensor = None
-

     # For processing logprobs
@@ -442,27 +441,23 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False

-    # device
-    device: str = "cuda"
-
     # Has regex
     has_regex: bool = False

+    # device
+    device: str = "cuda"
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-        has_stream = any(req.stream for req in reqs)
-        has_regex = any(req.regex_fsm for req in reqs)
-
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_logprob=return_logprob,
-            has_stream=has_stream,
-            device=req_to_token_pool.device,
-            has_regex=has_regex,
+            return_logprob=any(req.return_logprob for req in reqs),
+            has_stream=any(req.stream for req in reqs),
+            has_regex=any(req.regex_fsm for req in reqs),
+            device=req_to_token_pool.device,
         )

     def batch_size(self):
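The hunk above moves device bookkeeping into the batch itself: init_new copies the pool's device onto the ScheduleBatch, so downstream code can read batch.device without touching any tensor. A minimal, self-contained sketch of that plumbing, assuming hypothetical stand-ins FakePool and FakeBatch (these are not sglang classes):

    # Illustrative sketch only: the batch caches `device` from the pool at
    # creation time, so consumers never need a live tensor to learn the device.
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class FakePool:  # hypothetical stand-in for ReqToTokenPool
        device: str = "cuda"

    @dataclass
    class FakeBatch:  # hypothetical stand-in for ScheduleBatch
        reqs: List[object] = field(default_factory=list)
        device: str = "cuda"

        @classmethod
        def init_new(cls, reqs, pool):
            # Mirrors the pattern above: copy the pool's device onto the batch.
            return cls(reqs=reqs, device=pool.device)

    batch = FakeBatch.init_new([], FakePool(device="cuda:1"))
    assert batch.device == "cuda:1"  # downstream code reads batch.device directly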
@@ -754,7 +749,7 @@ class ScheduleBatch:
         return jump_forward_reqs

-    def prepare_for_decode(self):
+    def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE
         self.input_ids = self.output_ids
@@ -767,10 +762,19 @@ class ScheduleBatch:
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.write(
-            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
-        )
-        self.seq_lens.add_(1)
+        if enable_overlap:
+            # Do not use in-place operations in the overlap mode
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens = self.seq_lens + 1
+        else:
+            # A faster in-place version
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens.add_(1)

     def filter_batch(
         self,
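The branch above trades speed for safety: with the overlap scheduler, a consumer may still hold a reference to the previous seq_lens tensor, so `seq_lens + 1` builds a fresh tensor instead of mutating shared storage with `add_(1)`. A standalone sketch of the difference in plain PyTorch (illustrative only, not sglang code):

    import torch

    seq_lens = torch.tensor([5, 7, 9])
    snapshot = seq_lens  # e.g., a reference still held by an in-flight forward pass

    # Out-of-place (overlap mode): rebinding the name leaves the snapshot intact.
    seq_lens = seq_lens + 1
    assert snapshot.tolist() == [5, 7, 9]

    # In-place (normal mode): faster, but mutates the shared storage,
    # so every holder of the old reference observes the change.
    seq_lens = snapshot
    seq_lens.add_(1)
    assert snapshot.tolist() == [6, 8, 10]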
@@ -882,6 +886,7 @@ class ScheduleBatch:
         )

     def copy(self):
+        # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
             forward_mode=self.forward_mode,
@@ -940,9 +945,9 @@ class ModelWorkerBatch:
         return ModelWorkerBatch(
             bid=self.bid,
             forward_mode=self.forward_mode,
-            input_ids=self.input_ids.clone(),
+            input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens.clone(),
+            seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
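Dropping the .clone() calls looks consistent with the overlap change above: once prepare_for_decode rebinds seq_lens out of place, the worker-batch copy can share tensors instead of defensively cloning them. A small sketch of that reasoning, with a hypothetical Snapshot class standing in for ModelWorkerBatch:

    import torch
    from dataclasses import dataclass

    @dataclass
    class Snapshot:  # hypothetical stand-in for ModelWorkerBatch
        seq_lens: torch.Tensor

    def take_snapshot(seq_lens: torch.Tensor) -> Snapshot:
        # No .clone() needed if the producer promises out-of-place updates:
        # once a new tensor replaces the old one, the old one is never mutated.
        return Snapshot(seq_lens=seq_lens)

    seq_lens = torch.tensor([5, 7, 9])
    snap = take_snapshot(seq_lens)
    seq_lens = seq_lens + 1                      # producer rebinds instead of add_(1)
    assert snap.seq_lens.tolist() == [5, 7, 9]   # snapshot stays consistent for free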
python/sglang/srt/managers/scheduler.py

@@ -103,6 +103,7 @@ class Scheduler:
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
+        self.enable_overlap = server_args.enable_overlap_schedule

         # Init inter-process communication
         context = zmq.Context(2)
@@ -146,7 +147,7 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        if self.server_args.enable_overlap_schedule:
+        if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
@@ -670,7 +671,7 @@ class Scheduler:
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.prepare_for_decode()
+            self.running_batch.prepare_for_decode(self.enable_overlap)
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -717,7 +718,7 @@ class Scheduler:
             return

         # Update batch tensors
-        batch.prepare_for_decode()
+        batch.prepare_for_decode(self.enable_overlap)

     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
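The scheduler hunks all follow one pattern: read enable_overlap_schedule once in the constructor, then pass the cached flag to every prepare_for_decode call site. A condensed sketch of that wiring, using hypothetical MiniArgs/MiniBatch/MiniScheduler stand-ins (not the real sglang classes):

    from dataclasses import dataclass

    @dataclass
    class MiniArgs:  # stand-in for ServerArgs, reduced to the relevant field
        enable_overlap_schedule: bool = False

    class MiniBatch:  # stand-in for ScheduleBatch
        def prepare_for_decode(self, enable_overlap: bool = False):
            # Mirrors the branch added above: pick the update strategy per call.
            self.mode = "out-of-place" if enable_overlap else "in-place"

    class MiniScheduler:  # stand-in for Scheduler
        def __init__(self, server_args: MiniArgs):
            # Cache the flag once instead of reading server_args at each call site.
            self.enable_overlap = server_args.enable_overlap_schedule

        def update_running_batch(self, batch: MiniBatch):
            batch.prepare_for_decode(self.enable_overlap)

    batch = MiniBatch()
    MiniScheduler(MiniArgs(enable_overlap_schedule=True)).update_running_batch(batch)
    assert batch.mode == "out-of-place"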
python/sglang/srt/sampling/sampling_batch_info.py

@@ -51,7 +51,7 @@ class SamplingBatchInfo:
         disable_penalizer: bool,
     ):
         reqs = batch.reqs
-        device = batch.input_ids.device
+        device = batch.device
         temperatures = (
             torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
@@ -95,7 +95,7 @@ class SamplingBatchInfo:
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device=batch.input_ids.device,
+            device=batch.device,
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
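Both hunks swap a tensor-derived lookup (batch.input_ids.device) for the device field that ScheduleBatch now carries. One plausible benefit: the device is known even before input_ids has been materialized. A tiny sketch of that, with a hypothetical MiniBatch stand-in:

    import torch
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MiniBatch:  # hypothetical stand-in for ScheduleBatch
        device: str = "cpu"
        input_ids: Optional[torch.Tensor] = None  # may not exist yet

    def make_temperatures(batch: MiniBatch, temps) -> torch.Tensor:
        # Reading batch.device works even while batch.input_ids is still None;
        # batch.input_ids.device would raise AttributeError at this point.
        return torch.tensor(temps, dtype=torch.float32, device=batch.device)

    print(make_temperatures(MiniBatch(device="cpu"), [0.7, 1.0]))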