Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4a292f67
Unverified
Commit
4a292f67
authored
Oct 14, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 14, 2024
Browse files
[Minor] Add some utility functions (#1671)
parent
cd0be748
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
42 additions
and
2 deletions
+42
-2
python/sglang/bench_serving.py
python/sglang/bench_serving.py
+2
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+27
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+2
-2
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+11
-0
No files found.
python/sglang/bench_serving.py
View file @
4a292f67
...
@@ -587,6 +587,8 @@ async def benchmark(
...
@@ -587,6 +587,8 @@ async def benchmark(
else
:
else
:
print
(
"Initial test run completed. Starting main benchmark run..."
)
print
(
"Initial test run completed. Starting main benchmark run..."
)
time
.
sleep
(
1.5
)
pbar
=
None
if
disable_tqdm
else
tqdm
(
total
=
len
(
input_requests
))
pbar
=
None
if
disable_tqdm
else
tqdm
(
total
=
len
(
input_requests
))
benchmark_start_time
=
time
.
perf_counter
()
benchmark_start_time
=
time
.
perf_counter
()
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
4a292f67
...
@@ -392,6 +392,9 @@ class Req:
...
@@ -392,6 +392,9 @@ class Req:
return
f
"rid(n=
{
self
.
rid
}
, "
f
"input_ids=
{
self
.
origin_input_ids
}
, "
return
f
"rid(n=
{
self
.
rid
}
, "
f
"input_ids=
{
self
.
origin_input_ids
}
, "
bid
=
0
@
dataclass
@
dataclass
class
ScheduleBatch
:
class
ScheduleBatch
:
"""Store all information of a batch."""
"""Store all information of a batch."""
...
@@ -828,7 +831,11 @@ class ScheduleBatch:
...
@@ -828,7 +831,11 @@ class ScheduleBatch:
else
:
else
:
self
.
sampling_info
.
regex_fsms
=
None
self
.
sampling_info
.
regex_fsms
=
None
global
bid
bid
+=
1
return
ModelWorkerBatch
(
return
ModelWorkerBatch
(
bid
=
bid
,
forward_mode
=
self
.
forward_mode
,
forward_mode
=
self
.
forward_mode
,
input_ids
=
self
.
input_ids
,
input_ids
=
self
.
input_ids
,
req_pool_indices
=
self
.
req_pool_indices
,
req_pool_indices
=
self
.
req_pool_indices
,
...
@@ -865,6 +872,8 @@ class ScheduleBatch:
...
@@ -865,6 +872,8 @@ class ScheduleBatch:
@
dataclass
@
dataclass
class
ModelWorkerBatch
:
class
ModelWorkerBatch
:
# The batch id
bid
:
int
# The forward mode
# The forward mode
forward_mode
:
ForwardMode
forward_mode
:
ForwardMode
# The input ids
# The input ids
...
@@ -893,3 +902,21 @@ class ModelWorkerBatch:
...
@@ -893,3 +902,21 @@ class ModelWorkerBatch:
# Sampling info
# Sampling info
sampling_info
:
SamplingBatchInfo
sampling_info
:
SamplingBatchInfo
def copy(self):
    """Return a new ModelWorkerBatch built from this one.

    `input_ids` is duplicated with `.clone()` and `sampling_info` with its
    own `.copy()`; every other field is handed to the new object as the
    same reference this instance holds.
    """
    # Fields forwarded unchanged (by reference) to the new batch.
    passthrough = (
        "bid",
        "forward_mode",
        "req_pool_indices",
        "seq_lens",
        "out_cache_loc",
        "return_logprob",
        "top_logprobs_nums",
        "extend_seq_lens",
        "extend_prefix_lens",
        "extend_logprob_start_lens",
        "image_inputs",
        "lora_paths",
    )
    ctor_args = {field: getattr(self, field) for field in passthrough}

    # The two fields that get their own copy rather than a shared reference.
    ctor_args["input_ids"] = self.input_ids.clone()
    ctor_args["sampling_info"] = self.sampling_info.copy()

    return ModelWorkerBatch(**ctor_args)
python/sglang/srt/managers/scheduler.py
View file @
4a292f67
...
@@ -710,7 +710,7 @@ class Scheduler:
...
@@ -710,7 +710,7 @@ class Scheduler:
next_token_ids
next_token_ids
)
)
if
logits_output
:
if
batch
.
return_logprob
:
# Move logprobs to cpu
# Move logprobs to cpu
if
logits_output
.
next_token_logprobs
is
not
None
:
if
logits_output
.
next_token_logprobs
is
not
None
:
logits_output
.
next_token_logprobs
=
(
logits_output
.
next_token_logprobs
=
(
...
@@ -786,7 +786,7 @@ class Scheduler:
...
@@ -786,7 +786,7 @@ class Scheduler:
self
.
num_generated_tokens
+=
len
(
batch
.
reqs
)
self
.
num_generated_tokens
+=
len
(
batch
.
reqs
)
# Move logprobs to cpu
# Move logprobs to cpu
if
logits_output
.
next_token_logprobs
is
not
None
:
if
batch
.
return_logprob
:
next_token_logprobs
=
logits_output
.
next_token_logprobs
[
next_token_logprobs
=
logits_output
.
next_token_logprobs
[
torch
.
arange
(
len
(
next_token_ids
),
device
=
next_token_ids
.
device
),
torch
.
arange
(
len
(
next_token_ids
),
device
=
next_token_ids
.
device
),
next_token_ids
,
next_token_ids
,
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
4a292f67
...
@@ -202,3 +202,14 @@ class SamplingBatchInfo:
...
@@ -202,3 +202,14 @@ class SamplingBatchInfo:
self
.
logit_bias
=
SamplingBatchInfo
.
merge_bias_tensor
(
self
.
logit_bias
=
SamplingBatchInfo
.
merge_bias_tensor
(
self
.
logit_bias
,
other
.
logit_bias
,
len
(
self
),
len
(
other
),
self
.
device
self
.
logit_bias
,
other
.
logit_bias
,
len
(
self
),
len
(
other
),
self
.
device
)
)
def copy(self):
    """Return a new SamplingBatchInfo carrying this one's core fields.

    Exactly seven attributes are forwarded to the constructor, each by
    reference (nothing is cloned here): temperatures, top_ps, top_ks,
    min_ps, need_min_p_sampling, vocab_size and device.

    NOTE(review): other attributes on the instance (e.g. `logit_bias`) are
    not forwarded, so the new object gets whatever the constructor
    provides for them -- confirm this is what callers expect.
    """
    forwarded = (
        "temperatures",
        "top_ps",
        "top_ks",
        "min_ps",
        "need_min_p_sampling",
        "vocab_size",
        "device",
    )
    ctor_args = {field: getattr(self, field) for field in forwarded}
    return SamplingBatchInfo(**ctor_args)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment