Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
23196d52
Unverified
Commit 23196d52, authored Jan 18, 2025 by Lianmin Zheng; committed by GitHub on Jan 18, 2025
Browse files
Simplify logits processor (#2974)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent
93b77c8e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 44 additions and 27 deletions
+44
-27
python/sglang/srt/layers/logits_processor.py
python/sglang/srt/layers/logits_processor.py
+44
-27
No files found.
python/sglang/srt/layers/logits_processor.py
View file @
23196d52
...
...
@@ -14,6 +14,7 @@
"""Logits processing."""
import
dataclasses
import
logging
from
typing
import
List
,
Optional
,
Union
import
torch
...
...
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode
,
)
logger
=
logging
.
getLogger
(
__name__
)
@
dataclasses
.
dataclass
class
LogitsProcessorOutput
:
...
...
@@ -136,50 +139,61 @@ class LogitsProcessor(nn.Module):
logits_metadata
.
forward_mode
.
is_decode_or_idle
()
or
logits_metadata
.
forward_mode
.
is_target_verify
()
):
last_index
=
None
last_hidden
=
hidden_states
else
:
pruned_states
=
hidden_states
sample_indices
=
None
elif
(
logits_metadata
.
forward_mode
.
is_extend
()
and
not
logits_metadata
.
extend_return_logprob
):
# Prefill without input logprobs.
last_index
=
torch
.
cumsum
(
logits_metadata
.
extend_seq_lens
,
dim
=
0
)
-
1
last_hidden
=
hidden_states
[
last_index
]
pruned_states
=
hidden_states
[
last_index
]
sample_indices
=
None
else
:
# Slice the requested tokens to compute logprob
sample_index_pt
=
-
1
sample_indices
=
[]
pt
,
pruned_states
,
pruned_input_ids
=
0
,
[],
[]
for
start_len
,
extend_len
in
zip
(
logits_metadata
.
extend_logprob_start_lens_cpu
,
logits_metadata
.
extend_seq_lens_cpu
,
):
pruned_states
.
append
(
hidden_states
[
pt
+
start_len
:
pt
+
extend_len
])
sample_index_pt
+=
extend_len
-
start_len
sample_indices
.
append
(
sample_index_pt
)
pruned_input_ids
.
append
(
input_ids
[
pt
+
start_len
:
pt
+
extend_len
])
pt
+=
extend_len
pruned_states
=
torch
.
cat
(
pruned_states
)
# Compute logits for both input and sampled tokens.
logits
=
self
.
_get_logits
(
pruned_states
,
lm_head
,
logits_metadata
)
sampled_logits
=
(
logits
[
sample_indices
]
if
sample_indices
is
not
None
else
logits
)
# Compute logits
last_logits
=
self
.
_get_logits
(
last_hidden
,
lm_head
)
if
(
not
logits_metadata
.
extend_return_logprob
or
logits_metadata
.
capture_hidden_mode
.
need_capture
()
):
# Decode mode or extend mode without return_logprob.
return
LogitsProcessorOutput
(
next_token_logits
=
last
_logits
,
next_token_logits
=
sampled
_logits
,
hidden_states
=
(
hidden_states
if
logits_metadata
.
capture_hidden_mode
.
is_full
()
else
(
last_hidden
pruned_states
if
logits_metadata
.
capture_hidden_mode
.
is_last
()
else
None
)
),
)
else
:
# Slice the requested tokens to compute logprob
pt
,
pruned_states
,
pruned_input_ids
=
0
,
[],
[]
for
start_len
,
extend_len
in
zip
(
logits_metadata
.
extend_logprob_start_lens_cpu
,
logits_metadata
.
extend_seq_lens_cpu
,
):
pruned_states
.
append
(
hidden_states
[
pt
+
start_len
:
pt
+
extend_len
])
pruned_input_ids
.
append
(
input_ids
[
pt
+
start_len
:
pt
+
extend_len
])
pt
+=
extend_len
# Compute the logits of all required tokens
pruned_states
=
torch
.
cat
(
pruned_states
)
del
hidden_states
input_token_logits
=
self
.
_get_logits
(
pruned_states
,
lm_head
)
del
pruned_states
input_logprobs
=
logits
del
hidden_states
,
logits
# Normalize the logprob w/o temperature, top-p
input_logprobs
=
input_token_logits
input_logprobs
=
self
.
compute_temp_top_p_normalized_logprobs
(
input_logprobs
,
logits_metadata
)
...
...
@@ -194,17 +208,17 @@ class LogitsProcessor(nn.Module):
input_top_logprobs_val
=
input_top_logprobs_idx
=
None
input_token_logprobs
=
input_logprobs
[
torch
.
arange
(
input_logprobs
.
shape
[
0
],
device
=
"cuda"
),
torch
.
arange
(
input_logprobs
.
shape
[
0
],
device
=
input_logprobs
.
device
),
torch
.
cat
(
[
torch
.
cat
(
pruned_input_ids
)[
1
:],
torch
.
tensor
([
0
],
device
=
"cuda"
),
torch
.
tensor
([
0
],
device
=
input_logprobs
.
device
),
]
),
]
return
LogitsProcessorOutput
(
next_token_logits
=
last
_logits
,
next_token_logits
=
sampled
_logits
,
input_token_logprobs
=
input_token_logprobs
,
input_top_logprobs_val
=
input_top_logprobs_val
,
input_top_logprobs_idx
=
input_top_logprobs_idx
,
...
...
@@ -214,8 +228,11 @@ class LogitsProcessor(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
lm_head
:
VocabParallelEmbedding
,
logits_metadata
:
LogitsMetadata
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Get logits from hidden_states."""
if
hasattr
(
lm_head
,
"weight"
):
logits
=
torch
.
matmul
(
hidden_states
,
lm_head
.
weight
.
T
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.