sglang, commit 303ef888 (unverified)

Clean up logits processor (#558)

Authored Jun 22, 2024 by Lianmin Zheng; committed via GitHub on Jun 22, 2024.
Parent: 92cb93f3
Showing 3 changed files with 127 additions and 113 deletions (+127, -113):

  python/sglang/srt/layers/logits_processor.py         +51, -25
  python/sglang/srt/managers/controller/tp_worker.py   +73, -85
  python/sglang/srt/server.py                           +3,  -3
python/sglang/srt/layers/logits_processor.py
"""Logits processing."""
"""Logits processing."""
import
dataclasses
from
typing
import
List
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.distributed
import
(
from
vllm.distributed
import
(
...
@@ -10,6 +13,24 @@ from vllm.distributed import (
...
@@ -10,6 +13,24 @@ from vllm.distributed import (
from
sglang.srt.managers.controller.model_runner
import
ForwardMode
,
InputMetadata
from
sglang.srt.managers.controller.model_runner
import
ForwardMode
,
InputMetadata
@
dataclasses
.
dataclass
class
LogitProcessorOutput
:
# The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits
:
torch
.
Tensor
# The logprobs of the next tokens. shape: [#seq, vocab_size]
next_token_logprobs
:
torch
.
Tensor
# The normlaized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs
:
torch
.
Tensor
# The logprobs of prefill tokens. shape: [#token, vocab_size]
prefill_token_logprobs
:
torch
.
Tensor
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
prefill_top_logprobs
:
List
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
decode_top_logprobs
:
List
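
The dataclass replaces the positional tuples that LogitsProcessor.forward used to return, so downstream code reads the fields by name instead of remembering their order. A minimal sketch of the difference from a caller's point of view, using toy tensors rather than the real model runner:

import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class LogitProcessorOutput:
    # Same fields as the dataclass added in this commit.
    next_token_logits: torch.Tensor
    next_token_logprobs: torch.Tensor
    normalized_prompt_logprobs: torch.Tensor
    prefill_token_logprobs: torch.Tensor
    prefill_top_logprobs: List
    decode_top_logprobs: List


# Before: a 2-tuple whose second element was itself a positional tuple,
# so every call site had to remember the field order.
last_logits = torch.randn(2, 8)  # toy [#seq, vocab_size] tensor
old_style = last_logits, (None, None, None, None, None)
logits, (_, _, _, _, _) = old_style

# After: fields are constructed and read by name, and unused ones are
# explicitly None.
output = LogitProcessorOutput(
    next_token_logits=last_logits,
    next_token_logprobs=None,
    normalized_prompt_logprobs=None,
    prefill_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
)
print(output.next_token_logits.shape)  # torch.Size([2, 8])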

class LogitsProcessor(nn.Module):
    def __init__(self, config):
        super().__init__()
...
@@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module):
        return normalized_prompt_logprobs

    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        # TODO: vectorize the code below
        if input_metadata.forward_mode == ForwardMode.DECODE:
            decode_top_logprobs = []
            for i in range(all_logprobs.shape[0]):
...
@@ -51,7 +73,6 @@ class LogitsProcessor(nn.Module):
        else:
            prefill_top_logprobs, decode_top_logprobs = [], []
            pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                if extend_seq_len == 0:
...
@@ -71,18 +92,15 @@ class LogitsProcessor(nn.Module):
        return prefill_top_logprobs, decode_top_logprobs
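
The new TODO refers to the Python loop inside _get_top_logprobs that walks over the sequences one by one; the loop body is elided above. The sketch below only illustrates what per-sequence top-k extraction over a [#seq, vocab_size] logprob tensor looks like with torch.topk, with made-up shapes and a hypothetical handling of k == 0 (the function's actual behavior for that case is not visible in the diff):

import torch

# Toy logprobs for one decode step: 3 sequences over a 10-token vocabulary.
all_logprobs = torch.log_softmax(torch.randn(3, 10), dim=-1)
top_logprobs_nums = [2, 0, 3]  # hypothetical per-request k; 0 means "not requested"

decode_top_logprobs = []
for i in range(all_logprobs.shape[0]):
    k = top_logprobs_nums[i]
    if k == 0:
        # Hypothetical placeholder; the diff does not show how k == 0 is handled.
        decode_top_logprobs.append(None)
        continue
    values, indices = torch.topk(all_logprobs[i], k=k)
    # (logprob, token_id) tuples, matching the LogitProcessorOutput comment.
    decode_top_logprobs.append(list(zip(values.tolist(), indices.tolist())))

print(decode_top_logprobs[0])  # two (logprob, token_id) pairs for request 0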

    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
            last_index = (
                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                - 1
            )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
            last_hidden = hidden_states[last_index]

        last_logits = torch.matmul(last_hidden, weight.T)
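
For EXTEND (prefill) batches, the sequences of a batch are packed back to back, so the cumulative sum of extend_seq_lens minus one gives the flat index of each sequence's last token; DECODE batches already hold exactly one token per sequence, which is why last_index is only needed in the else branch. A small sketch with a toy extend_seq_lens, not taken from the diff:

import torch

# Toy example: three requests with prefill lengths 3, 2, and 4 tokens,
# packed back-to-back into one [9, hidden] tensor.
extend_seq_lens = torch.tensor([3, 2, 4])
hidden_states = torch.randn(int(extend_seq_lens.sum()), 16)

# The cumulative sum gives the end offset of each sequence; subtracting 1
# turns it into the index of each sequence's last token.
last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
print(last_index.tolist())  # [2, 4, 8]

# Selecting those rows yields one hidden state per request, which is all
# that is needed to predict each request's next token.
last_hidden = hidden_states[last_index]
print(last_hidden.shape)  # torch.Size([3, 16])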
...
@@ -92,8 +110,14 @@ class LogitsProcessor(nn.Module):
        # Return only last_logits if logprob is not requested
        if not input_metadata.return_logprob:
            hidden_states = None
-            return last_logits, (None, None, None, None, None)
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
        else:
            # When logprob is requested, compute the logits for all tokens.
            if input_metadata.forward_mode == ForwardMode.DECODE:
...
@@ -108,6 +132,7 @@ class LogitsProcessor(nn.Module):
                del all_logits
            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

+            # Get the logprob of top-k tokens
            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
            if return_top_logprob:
                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
...
@@ -117,16 +142,15 @@ class LogitsProcessor(nn.Module):
                prefill_top_logprobs = decode_top_logprobs = None

            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
-                )
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
+                )
            else:
+                # Compute the logprobs for the last token of each request.
                last_logprobs = all_logprobs[last_index]

                # Compute the logprobs and normalized logprobs for the prefill tokens.
...
@@ -139,12 +163,14 @@ class LogitsProcessor(nn.Module):
                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                    prefill_token_logprobs, input_metadata
                )

-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
-                )
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
+                )
...
python/sglang/srt/managers/controller/tp_worker.py
...
@@ -441,33 +441,25 @@ class ModelTpServer:
                self.model_config.vocab_size, self.int_token_logit_bias
            )

+        # Forward and sample the next tokens
        if batch.extend_num_tokens != 0:
-            # Forward
-            logits, (
-                prefill_token_logprobs,
-                normalized_prompt_logprobs,
-                prefill_top_logprobs,
-                decode_top_logprobs,
-                last_logprobs,
-            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-            next_token_ids, _ = batch.sample(logits)
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)

-            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-            if last_logprobs is not None:
-                last_token_logprobs = last_logprobs[
-                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
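
The rewritten prefill path keeps the full [#seq, vocab_size] logprob tensor on the GPU and copies only the logprobs of the sampled tokens to the CPU, via paired row/column advanced indexing. A minimal sketch of that gather with toy shapes (the variable names mirror the diff, the tensors are made up):

import torch

# Toy batch: 4 sequences over a 10-token vocabulary.
next_token_logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)
next_token_ids = torch.tensor([7, 1, 3, 9])

# Indexing with (row indices, column indices) picks exactly one logprob per
# sequence: the logprob of the token that was just sampled for it.
selected = next_token_logprobs[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()

print(len(selected))  # 4 Python floats, ready to attach to each request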

        # Check finish condition
        pt = 0
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
...
@@ -475,57 +467,59 @@ class ModelTpServer:
            req.check_finished()

            if req.return_logprob:
-                if req.normalized_prompt_logprob is None:
-                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
-                if req.prefill_token_logprobs is None:
-                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                    req.prefill_token_logprobs = list(
-                        zip(
-                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                            req.input_ids[-req.extend_input_len + 1 :],
-                        )
-                    )
-                    if req.logprob_start_len == 0:
-                        req.prefill_token_logprobs = [
-                            (None, req.input_ids[0])
-                        ] + req.prefill_token_logprobs
-
-                if req.last_update_decode_tokens != 0:
-                    req.decode_token_logprobs.extend(
-                        list(
-                            zip(
-                                prefill_token_logprobs[
-                                    pt + req.extend_input_len - req.last_update_decode_tokens : pt + req.extend_input_len - 1
-                                ],
-                                req.input_ids[-req.last_update_decode_tokens + 1 :],
-                            )
-                        )
-                    )
-
-                req.decode_token_logprobs.append(
-                    (last_token_logprobs[i], next_token_ids[i])
-                )
-
-                if req.top_logprobs_num > 0:
-                    if req.prefill_top_logprobs is None:
-                        req.prefill_top_logprobs = prefill_top_logprobs[i]
-                        if req.logprob_start_len == 0:
-                            req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-
-                    if req.last_update_decode_tokens != 0:
-                        req.decode_top_logprobs.extend(
-                            prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
-                        )
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
-
-                pt += req.extend_input_len
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                pt += req.extend_input_len

        self.handle_finished_requests(batch)

+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
+
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
+                    zip(
+                        output.prefill_token_logprobs[
+                            pt + req.extend_input_len - req.last_update_decode_tokens : pt + req.extend_input_len - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
+                    )
+                )
+            )
+
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
+
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])
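
Inside add_logprob_return_values, each prompt token from the second one onward is paired with the logprob the model assigned to it given the preceding tokens; the first prompt token has no such logprob, so it gets a None placeholder when logprob_start_len == 0. A small worked sketch with made-up numbers:

# Toy request: 4 prompt tokens and the 3 logprobs computed for tokens 1..3.
input_ids = [101, 42, 7, 99]
prefill_token_logprobs = [-0.5, -1.2, -0.1]   # logprobs for tokens 42, 7, 99
pt, extend_input_len = 0, 4

# Same pairing as in add_logprob_return_values: slice the flat per-batch
# logprob list for this request and zip it with the prompt shifted by one.
pairs = list(
    zip(
        prefill_token_logprobs[pt : pt + extend_input_len - 1],
        input_ids[-extend_input_len + 1 :],
    )
)
# The first prompt token has no logprob, so it is prepended with None.
pairs = [(None, input_ids[0])] + pairs
print(pairs)  # [(None, 101), (-0.5, 42), (-1.2, 7), (-0.1, 99)]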

    def cache_filled_batch(self, batch: Batch):
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
...
@@ -540,7 +534,7 @@ class ModelTpServer:
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

    def forward_decode_batch(self, batch: Batch):
-        # check if decode out of memory
+        # Check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
...
@@ -559,9 +553,8 @@ class ModelTpServer:
            )

        if not self.disable_regex_jump_forward:
-            # check for jump-forward
+            # Check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
            self.forward_queue.extend(jump_forward_reqs)

        if batch.is_empty():
            return
...
@@ -570,23 +563,19 @@ class ModelTpServer:
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

-        # Forward
-        logits, (
-            _,
-            _,
-            _,
-            decode_top_logprobs,
-            last_logprobs,
-        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.tolist()
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)

-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        if last_logprobs is not None:
-            new_token_logprobs = last_logprobs[
-                torch.arange(len(batch.reqs)), next_token_ids
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
            ].tolist()

+        next_token_ids = next_token_ids.tolist()
+
        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
...
@@ -594,10 +583,9 @@ class ModelTpServer:
            req.check_finished()

            if req.return_logprob:
-                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
+                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))

                if req.top_logprobs_num > 0:
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
+                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])

        self.handle_finished_requests(batch)
...
python/sglang/srt/server.py
...
@@ -253,7 +253,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
        try:
            requests.get(url + "/get_model_info", timeout=5, headers=headers)
            break
-        except requests.exceptions.RequestException as e:
+        except requests.exceptions.RequestException:
            pass

    # Send a warmup request
...
@@ -265,14 +265,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
                "text": "The capital city of France is",
                "sampling_params": {
                    "temperature": 0,
-                    "max_new_tokens": 16,
+                    "max_new_tokens": 8,
                },
            },
            headers=headers,
            timeout=600,
        )
        assert res.status_code == 200
-    except Exception:
+    except Exception as e:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(get_exception_traceback())
        print(f"Initialization failed. warmup error: {e}")
...
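
The warmup block above (partially elided) sends one short generation request to the freshly launched server and now asks for 8 new tokens instead of 16. A standalone sketch of an equivalent request, assuming the default local address and the /generate endpoint (neither appears in the visible part of the diff):

import requests

# Assumption: the sglang server was launched locally on its default port and
# exposes a /generate endpoint for plain text generation requests.
url = "http://localhost:30000"
res = requests.post(
    url + "/generate",
    json={
        "text": "The capital city of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 8,  # the new, smaller warmup budget
        },
    },
    timeout=600,
)
assert res.status_code == 200
print(res.json())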