sglang commit 303ef888 (unverified)

Clean up logits processor (#558)

Authored Jun 22, 2024 by Lianmin Zheng; committed by GitHub on Jun 22, 2024
Parent: 92cb93f3

Showing 3 changed files with 127 additions and 113 deletions:
  python/sglang/srt/layers/logits_processor.py         +51  -25
  python/sglang/srt/managers/controller/tp_worker.py   +73  -85
  python/sglang/srt/server.py                           +3   -3
python/sglang/srt/layers/logits_processor.py

"""Logits processing."""

import dataclasses
from typing import List

import torch
from torch import nn
from vllm.distributed import (
@@ -10,6 +13,24 @@ from vllm.distributed import (
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normalized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
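For context, the new LogitProcessorOutput dataclass replaces the positional 5-tuple that forward() used to return. Below is a minimal, self-contained sketch of how a caller reads it by field name; the dataclass is re-declared locally with Optional fields so the snippet runs on its own, and the argmax merely stands in for batch.sample, which is not shown in this diff.

import dataclasses
from typing import List, Optional

import torch


# Local stand-in for sglang's LogitProcessorOutput so this sketch is self-contained.
@dataclasses.dataclass
class LogitProcessorOutput:
    next_token_logits: torch.Tensor
    next_token_logprobs: Optional[torch.Tensor] = None
    normalized_prompt_logprobs: Optional[torch.Tensor] = None
    prefill_token_logprobs: Optional[torch.Tensor] = None
    prefill_top_logprobs: Optional[List] = None
    decode_top_logprobs: Optional[List] = None


# Callers now read named fields instead of unpacking a 5-tuple positionally.
output = LogitProcessorOutput(next_token_logits=torch.randn(2, 32000))
next_token_ids = output.next_token_logits.argmax(dim=-1)  # greedy stand-in for batch.sample
if output.next_token_logprobs is not None:  # optional fields stay None when logprobs are off
    print(output.next_token_logprobs.shape)

Named fields let outputs be added or reordered without breaking every call site, which is what the tp_worker.py changes further down rely on.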
@@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        # TODO: vectorize the code below
         if input_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
@@ -51,7 +73,6 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
             extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
@@ -71,18 +92,15 @@ class LogitsProcessor(nn.Module):
         return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
             last_index = (
                 torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
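The branch kept above computes the position of each request's last token in the packed prefill batch. Here is a small, self-contained illustration with toy sizes; only the cumsum-minus-one indexing pattern comes from the code above, the tensor values are made up.

import torch

# Three requests whose prompts are packed back-to-back along one dimension.
extend_seq_lens = torch.tensor([3, 5, 2])
last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
print(last_index)  # tensor([2, 7, 9]): the last token position of each request

hidden_states = torch.randn(int(extend_seq_lens.sum()), 16)  # toy packed [#token, hidden]
last_hidden = hidden_states[last_index]                      # [#seq, hidden], one row per request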
@@ -92,8 +110,14 @@ class LogitsProcessor(nn.Module):
         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
             hidden_states = None
-            return last_logits, (None, None, None, None, None)
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if input_metadata.forward_mode == ForwardMode.DECODE:
@@ -108,6 +132,7 @@ class LogitsProcessor(nn.Module):
                 del all_logits
                 all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
+            # Get the logprob of top-k tokens
             return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
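The body of _get_top_logprobs is mostly collapsed in this diff. The following toy sketch is not the actual implementation; it only shows, via torch.topk, the kind of (logprob, token_id) pairs that the dataclass fields prefill_top_logprobs and decode_top_logprobs document above.

import torch

# Toy shapes: 5 token positions over a 10-token vocabulary.
all_logprobs = torch.log_softmax(torch.randn(5, 10), dim=-1)
k = 3
values, indices = torch.topk(all_logprobs, k, dim=-1)
# One list of (logprob, token_id) pairs per position, matching the
# "Tuple(logprob, token_id)" layout noted on the dataclass fields.
top_logprobs = [
    list(zip(values[i].tolist(), indices[i].tolist()))
    for i in range(all_logprobs.shape[0])
]
print(top_logprobs[0])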
@@ -117,16 +142,15 @@ class LogitsProcessor(nn.Module):
                 prefill_top_logprobs = decode_top_logprobs = None
 
             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
             else:
                 # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -139,12 +163,14 @@ class LogitsProcessor(nn.Module):
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     prefill_token_logprobs, input_metadata
                 )
 
-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
python/sglang/srt/managers/controller/tp_worker.py
@@ -441,33 +441,25 @@ class ModelTpServer:
                 self.model_config.vocab_size, self.int_token_logit_bias
             )
 
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
-            # Forward
-            logits, (
-                prefill_token_logprobs,
-                normalized_prompt_logprobs,
-                prefill_top_logprobs,
-                decode_top_logprobs,
-                last_logprobs,
-            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-            next_token_ids, _ = batch.sample(logits)
-
-            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-            if last_logprobs is not None:
-                last_token_logprobs = last_logprobs[
-                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)
+
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                     next_token_ids,
                 ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
 
             next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
-        # Check finish condition
+        # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
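The "Move logprobs to cpu" block above transfers only one logprob per sequence. Here is a self-contained illustration of that advanced-indexing pattern with toy shapes; the values are random and nothing here comes from a real batch.

import torch

# Toy shapes: 4 sequences, vocabulary of 10 tokens.
next_token_logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)
next_token_ids = torch.tensor([1, 7, 3, 0])

# Row i is paired with column next_token_ids[i], so only 4 scalars (not 4 x 10)
# are copied to the CPU by .tolist().
selected = next_token_logprobs[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()
print(selected)  # four floats: the logprob of each sampled token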
@@ -475,57 +467,59 @@ class ModelTpServer:
             req.check_finished()
 
             if req.return_logprob:
-                if req.normalized_prompt_logprob is None:
-                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
-                if req.prefill_token_logprobs is None:
-                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                    req.prefill_token_logprobs = list(
-                        zip(
-                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                            req.input_ids[-req.extend_input_len + 1 :],
-                        )
-                    )
-                    if req.logprob_start_len == 0:
-                        req.prefill_token_logprobs = [
-                            (None, req.input_ids[0])
-                        ] + req.prefill_token_logprobs
-
-                if req.last_update_decode_tokens != 0:
-                    req.decode_token_logprobs.extend(
-                        list(
-                            zip(
-                                prefill_token_logprobs[
-                                    pt
-                                    + req.extend_input_len
-                                    - req.last_update_decode_tokens : pt
-                                    + req.extend_input_len
-                                    - 1
-                                ],
-                                req.input_ids[-req.last_update_decode_tokens + 1 :],
-                            )
-                        )
-                    )
-
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                 pt += req.extend_input_len
-                req.decode_token_logprobs.append(
-                    (last_token_logprobs[i], next_token_ids[i])
-                )
-
-                if req.top_logprobs_num > 0:
-                    if req.prefill_top_logprobs is None:
-                        req.prefill_top_logprobs = prefill_top_logprobs[i]
-                        if req.logprob_start_len == 0:
-                            req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-
-                    if req.last_update_decode_tokens != 0:
-                        req.decode_top_logprobs.extend(
-                            prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
-                        )
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 
+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
+
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
+                    zip(
+                        output.prefill_token_logprobs[
+                            pt
+                            + req.extend_input_len
+                            - req.last_update_decode_tokens : pt
+                            + req.extend_input_len
+                            - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
+                    )
+                )
+            )
+
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
+
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+
     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
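To make the slicing in add_logprob_return_values easier to follow, here is a toy, plain-Python version of the zip that pairs each prompt token with the logprob of predicting it. The numbers are invented; the first prompt token gets None because nothing predicts it, matching the logprob_start_len == 0 branch above.

# Toy 4-token prompt: token ids and the 3 logprobs produced for tokens 2..4.
input_ids = [101, 42, 7, 9]
prefill_token_logprobs = [-1.2, -0.7, -2.1]

pairs = list(zip(prefill_token_logprobs, input_ids[1:]))
# Prepend (None, first_token), as the logprob_start_len == 0 branch does.
pairs = [(None, input_ids[0])] + pairs
print(pairs)  # [(None, 101), (-1.2, 42), (-0.7, 7), (-2.1, 9)]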
@@ -540,7 +534,7 @@ class ModelTpServer:
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
 
     def forward_decode_batch(self, batch: Batch):
-        # check if decode out of memory
+        # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
@@ -559,9 +553,8 @@ class ModelTpServer:
             )
 
         if not self.disable_regex_jump_forward:
-            # check for jump-forward
+            # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
@@ -570,23 +563,19 @@ class ModelTpServer:
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()
 
-        # Forward
-        logits, (
-            _,
-            _,
-            _,
-            decode_top_logprobs,
-            last_logprobs,
-        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.tolist()
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)
 
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        if last_logprobs is not None:
-            new_token_logprobs = last_logprobs[
-                torch.arange(len(batch.reqs)), next_token_ids
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
             ].tolist()
+
+        next_token_ids = next_token_ids.tolist()
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
@@ -594,10 +583,9 @@ class ModelTpServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
+                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
                 if req.top_logprobs_num > 0:
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
+                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
 
         self.handle_finished_requests(batch)
python/sglang/srt/server.py
@@ -253,7 +253,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
             break
-        except requests.exceptions.RequestException as e:
+        except requests.exceptions.RequestException:
             pass
 
     # Send a warmup request
@@ -265,14 +265,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
                 "text": "The capital city of France is",
                 "sampling_params": {
                     "temperature": 0,
-                    "max_new_tokens": 16,
+                    "max_new_tokens": 8,
                 },
             },
            headers=headers,
             timeout=600,
         )
         assert res.status_code == 200
-    except Exception:
+    except Exception as e:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(get_exception_traceback())
         print(f"Initialization failed. warmup error: {e}")
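For reference, here is a standalone approximation of the warmup call above. The /generate path and the localhost URL are assumptions (the real values sit in the collapsed part of this hunk and are built earlier in launch_server), and headers is left empty.

import requests

url = "http://127.0.0.1:30000"  # assumed default address, not shown in this hunk
headers = {}

res = requests.post(
    url + "/generate",  # assumed endpoint, not shown in this hunk
    json={
        "text": "The capital city of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    },
    headers=headers,
    timeout=600,
)
assert res.status_code == 200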