Commit 303ef888 (unverified)

Clean up logits processor (#558)

Authored by Lianmin Zheng on Jun 22, 2024; committed via GitHub on Jun 22, 2024.
Parent: 92cb93f3
Changes: 3 changed files, with 127 additions and 113 deletions.

    python/sglang/srt/layers/logits_processor.py          +51  -25
    python/sglang/srt/managers/controller/tp_worker.py     +73  -85
    python/sglang/srt/server.py                             +3   -3
python/sglang/srt/layers/logits_processor.py (view file @ 303ef888)
"""Logits processing."""
import
dataclasses
from
typing
import
List
import
torch
from
torch
import
nn
from
vllm.distributed
import
(
...
...
@@ -10,6 +13,24 @@ from vllm.distributed import (
from
sglang.srt.managers.controller.model_runner
import
ForwardMode
,
InputMetadata
@
dataclasses
.
dataclass
class
LogitProcessorOutput
:
# The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits
:
torch
.
Tensor
# The logprobs of the next tokens. shape: [#seq, vocab_size]
next_token_logprobs
:
torch
.
Tensor
# The normlaized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs
:
torch
.
Tensor
# The logprobs of prefill tokens. shape: [#token, vocab_size]
prefill_token_logprobs
:
torch
.
Tensor
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
prefill_top_logprobs
:
List
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
decode_top_logprobs
:
List
class
LogitsProcessor
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
().
__init__
()
...
...
@@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module):
return
normalized_prompt_logprobs
def
_get_top_logprobs
(
self
,
all_logprobs
,
input_metadata
:
InputMetadata
):
# TODO: vectorize the code below
if
input_metadata
.
forward_mode
==
ForwardMode
.
DECODE
:
decode_top_logprobs
=
[]
for
i
in
range
(
all_logprobs
.
shape
[
0
]):
...
...
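Note on the hunk above: _get_top_logprobs walks the log-probability matrix row by row and keeps the top-k (logprob, token_id) pairs per position, which is why the new TODO asks for vectorization. The snippet below is only a minimal, self-contained sketch of that kind of per-row top-k extraction; the helper name topk_logprob_pairs, the shapes, and the ks values are illustrative and not taken from the diff.

import torch

def topk_logprob_pairs(logprobs, ks):
    # logprobs: [num_seqs, vocab_size] log-probabilities, one row per sequence.
    # ks: requested top-k per sequence (mirrors input_metadata.top_logprobs_nums).
    out = []
    for i, k in enumerate(ks):
        values, indices = logprobs[i].topk(k)
        # Keep (logprob, token_id) pairs, matching the Tuple(logprob, token_id)
        # layout described in the LogitProcessorOutput comments.
        out.append(list(zip(values.tolist(), indices.tolist())))
    return out

# Example: 2 sequences over a toy vocabulary of 5 tokens.
logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)
print(topk_logprob_pairs(logprobs, ks=[2, 3]))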
@@ -51,7 +73,6 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
             extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:

@@ -71,18 +92,15 @@ class LogitsProcessor(nn.Module):
         return prefill_top_logprobs, decode_top_logprobs

     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
             last_index = (
                 torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long) - 1
             )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
             last_hidden = hidden_states[last_index]

         last_logits = torch.matmul(last_hidden, weight.T)

@@ -92,8 +110,14 @@ class LogitsProcessor(nn.Module):
         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
             hidden_states = None
-            return last_logits, (None, None, None, None, None)
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if input_metadata.forward_mode == ForwardMode.DECODE:

@@ -108,6 +132,7 @@ class LogitsProcessor(nn.Module):
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

+            # Get the logprob of top-k tokens
             return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(

@@ -117,16 +142,15 @@ class LogitsProcessor(nn.Module):
                 prefill_top_logprobs = decode_top_logprobs = None

             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
             else:
                 # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]

                 # Compute the logprobs and normalized logprobs for the prefill tokens.

@@ -139,12 +163,14 @@ class LogitsProcessor(nn.Module):
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     prefill_token_logprobs, input_metadata
                 )

-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
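In short, this file now returns a LogitProcessorOutput dataclass everywhere it previously returned last_logits alongside a positional five-tuple. A minimal sketch of the two calling styles with dummy tensors follows; the Optional annotations and the toy vocabulary size are my additions for illustration and are not part of the diff.

import dataclasses
from typing import List, Optional

import torch

@dataclasses.dataclass
class LogitProcessorOutput:
    next_token_logits: torch.Tensor
    next_token_logprobs: Optional[torch.Tensor]
    normalized_prompt_logprobs: Optional[torch.Tensor]
    prefill_token_logprobs: Optional[torch.Tensor]
    prefill_top_logprobs: Optional[List]
    decode_top_logprobs: Optional[List]

# Before: callers unpacked a positional 5-tuple next to the logits, e.g.
#   logits, (prefill_token_logprobs, normalized_prompt_logprobs,
#            prefill_top_logprobs, decode_top_logprobs, last_logprobs) = forward(...)

# After: callers read named fields from a single object.
output = LogitProcessorOutput(
    next_token_logits=torch.randn(2, 128),  # toy [#seq, vocab_size]
    next_token_logprobs=None,
    normalized_prompt_logprobs=None,
    prefill_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
)
print(output.next_token_logits.shape)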
python/sglang/srt/managers/controller/tp_worker.py (view file @ 303ef888)
@@ -441,33 +441,25 @@ class ModelTpServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )

+        # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
-            # Forward
-            logits, (
-                prefill_token_logprobs,
-                normalized_prompt_logprobs,
-                prefill_top_logprobs,
-                decode_top_logprobs,
-                last_logprobs,
-            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-            next_token_ids, _ = batch.sample(logits)
-
-            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-            if last_logprobs is not None:
-                last_token_logprobs = last_logprobs[
-                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)
+
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                     next_token_ids,
                 ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()

             next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

-        # Check finish condition
+        # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
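The torch.arange(...), next_token_ids indexing above selects, for each sequence in the batch, the logprob of the token it actually sampled, so only a [#seq]-sized vector is moved to the CPU in one transfer. Below is a self-contained sketch of that gather with toy shapes; the variable names are illustrative.

import torch

batch_size, vocab_size = 3, 10
next_token_logprobs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)
next_token_ids = torch.tensor([4, 7, 1])

# Advanced indexing: row i is paired with column next_token_ids[i], selecting the
# logprob of each sequence's chosen token in a single batched operation.
chosen = next_token_logprobs[
    torch.arange(batch_size, device=next_token_ids.device),
    next_token_ids,
]
print(chosen.tolist())  # one scalar logprob per sequence, transferred to CPU once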
@@ -475,57 +467,59 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                if req.normalized_prompt_logprob is None:
-                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
-                if req.prefill_token_logprobs is None:
-                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                    req.prefill_token_logprobs = list(
-                        zip(
-                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                            req.input_ids[-req.extend_input_len + 1 :],
-                        )
-                    )
-                    if req.logprob_start_len == 0:
-                        req.prefill_token_logprobs = [
-                            (None, req.input_ids[0])
-                        ] + req.prefill_token_logprobs
-
-                if req.last_update_decode_tokens != 0:
-                    req.decode_token_logprobs.extend(
-                        list(
-                            zip(
-                                prefill_token_logprobs[
-                                    pt + req.extend_input_len - req.last_update_decode_tokens : pt + req.extend_input_len - 1
-                                ],
-                                req.input_ids[-req.last_update_decode_tokens + 1 :],
-                            )
-                        )
-                    )
-
-                req.decode_token_logprobs.append(
-                    (last_token_logprobs[i], next_token_ids[i])
-                )
-
-                if req.top_logprobs_num > 0:
-                    if req.prefill_top_logprobs is None:
-                        req.prefill_top_logprobs = prefill_top_logprobs[i]
-                        if req.logprob_start_len == 0:
-                            req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-
-                    if req.last_update_decode_tokens != 0:
-                        req.decode_top_logprobs.extend(
-                            prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
-                        )
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
-
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                 pt += req.extend_input_len

         self.handle_finished_requests(batch)

+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
+
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
+                    zip(
+                        output.prefill_token_logprobs[
+                            pt + req.extend_input_len - req.last_update_decode_tokens : pt + req.extend_input_len - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
+                    )
+                )
+            )
+
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
+
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])

     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
@@ -540,7 +534,7 @@ class ModelTpServer:
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

     def forward_decode_batch(self, batch: Batch):
-        # check if decode out of memory
+        # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

@@ -559,9 +553,8 @@ class ModelTpServer:
             )

         if not self.disable_regex_jump_forward:
-            # check for jump-forward
+            # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return

@@ -570,23 +563,19 @@ class ModelTpServer:
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()

-        # Forward
-        logits, (
-            _,
-            _,
-            _,
-            decode_top_logprobs,
-            last_logprobs,
-        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.tolist()
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        if last_logprobs is not None:
-            new_token_logprobs = last_logprobs[
-                torch.arange(len(batch.reqs)), next_token_ids
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)
+
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
             ].tolist()

+        next_token_ids = next_token_ids.tolist()
+
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1

@@ -594,10 +583,9 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
-
-            if req.top_logprobs_num > 0:
-                req.decode_top_logprobs.append(decode_top_logprobs[i])
+                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+
+            if req.top_logprobs_num > 0:
+                req.decode_top_logprobs.append(output.decode_top_logprobs[i])

         self.handle_finished_requests(batch)
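Both the prefill and decode paths now call batch.sample(output.next_token_logits). That method is not shown in this diff, so the following is only an illustrative sketch of temperature-based sampling over a [#seq, vocab_size] logits matrix, not sglang's actual Batch.sample implementation.

import torch

def sample_next_tokens(logits, temperature):
    # logits: [num_seqs, vocab_size]; returns one token id per sequence.
    if temperature == 0:
        # Greedy decoding: take the highest-logit token.
        return torch.argmax(logits, dim=-1)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

logits = torch.randn(4, 128)
print(sample_next_tokens(logits, temperature=0))    # deterministic
print(sample_next_tokens(logits, temperature=0.7))  # stochastic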
python/sglang/srt/server.py (view file @ 303ef888)
@@ -253,7 +253,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
             break
-        except requests.exceptions.RequestException as e:
+        except requests.exceptions.RequestException:
             pass

     # Send a warmup request

@@ -265,14 +265,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
                 "text": "The capital city of France is",
                 "sampling_params": {
                     "temperature": 0,
-                    "max_new_tokens": 16,
+                    "max_new_tokens": 8,
                 },
             },
             headers=headers,
             timeout=600,
         )
         assert res.status_code == 200
-    except Exception:
+    except Exception as e:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(get_exception_traceback())
         print(f"Initialization failed. warmup error: {e}")
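The hunk above trims the warmup request (8 new tokens instead of 16) and binds the exception so the error message can actually reference it. For context, here is a hedged, self-contained sketch of an equivalent client-side warmup call; the /generate endpoint and the standalone warmup helper are assumptions for illustration and are not visible in the shown hunks.

import requests

def warmup(url, headers=None):
    # Hypothetical helper mirroring the warmup block in launch_server;
    # the /generate endpoint is assumed, not shown in the hunk above.
    res = requests.post(
        url + "/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        },
        headers=headers,
        timeout=600,
    )
    res.raise_for_status()

# Example (assumes a server is already listening locally):
# warmup("http://127.0.0.1:30000")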