sglang / Commits / 2854a5ea

Commit 2854a5ea (unverified)
Authored Sep 23, 2024 by Lianmin Zheng; committed by GitHub on Sep 23, 2024

Fix the overhead due to penalizer in bench_latency (#1496)

Parent: 42a2d82b
Showing 6 changed files with 9 additions and 16 deletions (+9, -16)
python/sglang/bench_latency.py                           +2  -2
python/sglang/srt/managers/schedule_batch.py             +3  -5
python/sglang/srt/managers/tp_worker.py                  +1  -1
python/sglang/srt/model_executor/forward_batch_info.py   +2  -4
python/sglang/srt/model_executor/model_runner.py         +0  -3
scripts/playground/reference_hf.py                       +1  -1
python/sglang/bench_latency.py

@@ -260,7 +260,7 @@ def correctness_test(
 
     # Decode
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
-    for _ in range(bench_args.output_len[0]):
+    for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         for i in range(len(reqs)):
             output_ids[i].append(next_token_ids[i])

@@ -311,7 +311,7 @@ def latency_test_run_once(
 
     # Decode
     decode_latencies = []
-    for i in range(output_len):
+    for i in range(output_len - 1):
         torch.cuda.synchronize()
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
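Why the `- 1`: the extend (prefill) step already samples the first output token for every request, so looping `output_len` more times generated one token too many and timed one extra decode step. A minimal accounting sketch of the fixed behavior (illustrative only, not the benchmark code itself):

    # The extend step yields token 1; the decode loop yields the rest.
    output_len = 16
    tokens_generated = 1                 # sampled by the extend (prefill) step
    for _ in range(output_len - 1):      # decode steps after the fix
        tokens_generated += 1
    assert tokens_generated == output_len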
python/sglang/srt/managers/schedule_batch.py

@@ -429,7 +429,7 @@ class ScheduleBatch:
 
     def prepare_for_extend(self, vocab_size: int):
         self.forward_mode = ForwardMode.EXTEND
-        bs = self.batch_size()
+        bs = len(self.reqs)
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)

@@ -509,7 +509,7 @@ class ScheduleBatch:
 
         self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
 
     def check_decode_mem(self):
-        bs = self.batch_size()
+        bs = len(self.reqs)
         if self.token_to_kv_pool.available_size() >= bs:
             return True

@@ -680,14 +680,12 @@ class ScheduleBatch:
 
                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
                 for r in self.reqs
             ]
-        else:
-            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
 
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
 
         # Alloc mem
-        bs = self.batch_size()
+        bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
         self.req_to_token_pool.req_to_token[
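Two independent changes here: `len(self.reqs)` replaces the `batch_size()` helper call on hot paths, and the decode path no longer routes sampled tokens through `penalizer_orchestrator.cumulate_input_tokens`, which is the per-step overhead the commit title refers to. A rough, self-contained sketch of why per-step penalizer bookkeeping is costly in a latency benchmark (all names below are hypothetical stand-ins, not sglang APIs):

    import time

    class Orchestrator:  # stand-in for a penalizer orchestrator
        def __init__(self, bs):
            self.seen = [0] * bs

        def cumulate_input_tokens(self, token_ids):
            for i, _ in enumerate(token_ids):  # O(batch) Python work per step
                self.seen[i] += 1

    bs, steps = 256, 16
    orch = Orchestrator(bs)
    tokens = list(range(bs))
    tic = time.time()
    for _ in range(steps):                     # one call per decode step
        orch.cumulate_input_tokens(tokens)
    print(f"bookkeeping: {(time.time() - tic) * 1e3:.3f} ms for {steps} steps")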
python/sglang/srt/managers/tp_worker.py

@@ -215,6 +215,7 @@ class ModelTpServer:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.do_not_get_new_batch = False
 
+    @torch.inference_mode()
     def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests

@@ -246,7 +247,6 @@ class ModelTpServer:
             self.out_pyobjs = []
         return ret
 
-    @torch.inference_mode()
     def forward_step(self):
         if self.do_not_get_new_batch and self.current_inflight_req is None:
             new_batch = None
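Moving `@torch.inference_mode()` from `forward_step` up to `exposed_step` enters the no-grad context once per server step instead of once per inner forward call; callees inherit the mode. A minimal sketch of the pattern (class and method names are illustrative):

    import torch

    class Worker:
        @torch.inference_mode()              # entered once per step
        def exposed_step(self, recv_reqs):
            for _ in recv_reqs:
                self.forward_step()

        def forward_step(self):              # inherits the caller's context
            assert torch.is_inference_mode_enabled()
            return torch.ones(1) * 2

    Worker().exposed_step([None, None])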
python/sglang/srt/model_executor/forward_batch_info.py

@@ -97,14 +97,12 @@ class InputMetadata:
         self.modalities = [r.modalities for r in reqs]
 
     def compute_positions(self, batch: ScheduleBatch):
-        position_ids_offsets = batch.position_ids_offsets
-
         if self.forward_mode.is_decode():
             if True:
                 self.positions = self.seq_lens - 1
             else:
                 # Deprecated
-                self.positions = (self.seq_lens - 1) + position_ids_offsets
+                self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
         else:
             if True:
                 self.positions = torch.tensor(

@@ -119,7 +117,7 @@ class InputMetadata:
             )
         else:
             # Deprecated
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
             self.positions = torch.tensor(
                 np.concatenate(
                     [
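With the deprecated offsets path dead (`if True:`), decode positions reduce to `seq_lens - 1`: `prepare_for_decode` has already bumped each sequence length, so the token fed this step sits at the last 0-indexed position. A small worked example of that rule:

    import torch

    # seq_lens after prepare_for_decode's add_(1); each length already counts
    # the token being decoded this step.
    seq_lens = torch.tensor([5, 9, 17])
    positions = seq_lens - 1                 # 0-indexed position of that token
    print(positions)                         # tensor([ 4,  8, 16])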
python/sglang/srt/model_executor/model_runner.py

@@ -467,7 +467,6 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
-    @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
         if self.server_args.lora_paths is not None:
             self.lora_manager.prepare_lora_batch(batch)

@@ -481,7 +480,6 @@ class ModelRunner:
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
-    @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(self, batch)
         if self.server_args.lora_paths is not None:

@@ -500,7 +498,6 @@ class ModelRunner:
             get_embedding=True,
         )
 
-    @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
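These three decorators become redundant once `exposed_step` holds the context (see tp_worker.py above): `torch.inference_mode()` propagates to everything the decorated function calls, so the inner decorators only re-entered an already-active mode on every forward. A minimal demonstration (function names are illustrative):

    import torch

    def forward_decode(x):
        # True here: the mode is inherited from the decorated caller.
        assert torch.is_inference_mode_enabled()
        return x * 2

    @torch.inference_mode()
    def exposed_step(x):
        return forward_decode(x)             # no inner decorator needed

    print(exposed_step(torch.ones(2)))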
scripts/playground/reference_hf.py

@@ -45,7 +45,7 @@ def normal_text(args):
         "The capital of the United Kindom is",
         "Today is a sunny day and I like",
     ]
-    max_new_tokens = 17
+    max_new_tokens = 16
 
     torch.cuda.set_device(0)
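The HF reference drops from 17 to 16 generated tokens to match the corrected benchmark accounting: one token from the extend step plus `output_len - 1` decode steps is exactly `output_len` tokens per request (this assumes bench_latency's output length is 16, consistent with the loops changed above):

    output_len = 16                          # assumed benchmark setting
    total = 1 + (output_len - 1)             # extend step + decode steps
    assert total == output_len               # hence max_new_tokens = 16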