sglang · Commit f65c13b5

Authored Jan 15, 2025 by Lianmin Zheng

Remove normalized_prompt_logprobs from the engine to make code easier to maintain (#2902)

Parent: b803b395
Showing 12 changed files with 11 additions and 153 deletions.
python/sglang/lang/backend/runtime_endpoint.py (+9, -3)
python/sglang/srt/layers/logits_processor.py (+0, -32)
python/sglang/srt/managers/detokenizer_manager.py (+0, -1)
python/sglang/srt/managers/io_struct.py (+0, -2)
python/sglang/srt/managers/schedule_batch.py (+0, -4)
python/sglang/srt/managers/schedule_policy.py (+0, -1)
python/sglang/srt/managers/scheduler.py (+1, -12)
python/sglang/srt/managers/tokenizer_manager.py (+0, -3)
python/sglang/srt/managers/tp_worker_overlap_thread.py (+0, -8)
python/sglang/test/test_programs.py (+1, -1)
scripts/deprecated/test_httpserver_classify.py (+0, -85, deleted)
scripts/deprecated/test_httpserver_decode_stream.py (+0, -1)
python/sglang/lang/backend/runtime_endpoint.py

```diff
@@ -251,11 +251,12 @@ class RuntimeEndpoint(BaseBackend):
         }
         obj = self._generate_http_request(s, data)
 
-        normalized_prompt_logprobs = [
-            r["meta_info"]["normalized_prompt_logprob"] for r in obj
-        ]
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
+        normalized_prompt_logprobs = [
+            compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
+            for r in obj
+        ]
 
         # Remove extra token if no token healing occurred
         for i in range(len(input_token_logprobs)):
@@ -319,3 +320,8 @@ class RuntimeEndpoint(BaseBackend):
     def _assert_success(self, res):
         if res.status_code != 200:
             raise RuntimeError(res.json())
+
+
+def compute_normalized_prompt_logprobs(input_logprobs):
+    values = [x[0] for x in input_logprobs if x[0]]
+    return sum(values) / len(values)
```
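With this change the engine no longer computes `normalized_prompt_logprob`; the language frontend derives it from `input_token_logprobs` instead. Below is a minimal sketch of that client-side recomputation on a hypothetical `meta_info` payload (the `(logprob, token_id, token_text)` tuple layout follows the existing meta_info format; the concrete values are made up for illustration):

```python
# Hedged sketch: recompute the normalized prompt logprob client-side,
# mirroring compute_normalized_prompt_logprobs above.
def compute_normalized_prompt_logprobs(input_logprobs):
    # Drop entries without a logprob (the first prompt token has None,
    # since it has no preceding context to be scored against).
    values = [x[0] for x in input_logprobs if x[0]]
    return sum(values) / len(values)

# Hypothetical response payload for illustration only.
meta_info = {
    "input_token_logprobs": [
        (None, 128000, None),  # first token: no logprob
        (-2.31, 9906, None),
        (-0.87, 1917, None),
    ]
}
print(compute_normalized_prompt_logprobs(meta_info["input_token_logprobs"]))
# (-2.31 + -0.87) / 2 = -1.59
```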
python/sglang/srt/layers/logits_processor.py

```diff
@@ -50,8 +50,6 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_idx: Optional[List] = None
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
-    # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
@@ -195,8 +193,6 @@ class LogitsProcessor(nn.Module):
         else:
             input_top_logprobs_val = input_top_logprobs_idx = None
 
-        # Compute the normalized logprobs for the requested tokens.
-        # Note that we pad a zero at the end for easy batching.
         input_token_logprobs = input_logprobs[
             torch.arange(input_logprobs.shape[0], device="cuda"),
             torch.cat(
@@ -206,14 +202,9 @@ class LogitsProcessor(nn.Module):
                 ]
             ),
         ]
 
-        normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-            input_token_logprobs,
-            logits_metadata,
-        )
         return LogitsProcessorOutput(
             next_token_logits=last_logits,
-            normalized_prompt_logprobs=normalized_prompt_logprobs,
             input_token_logprobs=input_token_logprobs,
             input_top_logprobs_val=input_top_logprobs_val,
             input_top_logprobs_idx=input_top_logprobs_idx,
@@ -237,8 +228,6 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather:
             logits = tensor_model_parallel_all_gather(logits)
 
-        # Compute the normalized logprobs for the requested tokens.
-        # Note that we pad a zero at the end for easy batching.
         logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
@@ -246,27 +235,6 @@ class LogitsProcessor(nn.Module):
         return logits
 
-    @staticmethod
-    def _get_normalized_prompt_logprobs(
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-        )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         max_k = max(logits_metadata.top_logprobs_nums)
```
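The deleted `_get_normalized_prompt_logprobs` helper averaged each sequence's token logprobs over a flattened batch with a single cumulative sum instead of a per-sequence loop. Here is a standalone CPU sketch of that segment-mean trick on toy data, so the removed logic stays documented (names and data are illustrative, not engine API):

```python
import torch

def normalized_prompt_logprobs(flat_logprobs, pruned_lens):
    """Mean logprob per sequence from one flat tensor, via a cumsum trick."""
    cumsum = torch.cumsum(flat_logprobs, dim=0, dtype=torch.float32)
    lens = torch.tensor(pruned_lens)
    start = torch.zeros_like(lens)
    start[1:] = torch.cumsum(lens[:-1], dim=0)  # first index of each sequence
    end = torch.clamp(start + lens - 2, min=0, max=cumsum.shape[0] - 1)
    # cumsum[end] - cumsum[start] + flat[start] == flat[start..end].sum(),
    # which skips the zero padded at each sequence's last position.
    sum_logp = cumsum[end] - cumsum[start] + flat_logprobs[start]
    return sum_logp / (lens - 1).clamp(min=1)

# Two sequences of length 3, each zero-padded at its final position.
flat = torch.tensor([-1.0, -2.0, 0.0, -3.0, -1.0, 0.0])
print(normalized_prompt_logprobs(flat, [3, 3]))  # tensor([-1.5000, -2.0000])
```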
python/sglang/srt/managers/detokenizer_manager.py

```diff
@@ -191,7 +191,6 @@ class DetokenizerManager:
                     input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
                     output_top_logprobs_val=recv_obj.output_top_logprobs_val,
                     output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
-                    normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
                 )
             )
```
python/sglang/srt/managers/io_struct.py

```diff
@@ -340,7 +340,6 @@ class BatchTokenIDOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -366,7 +365,6 @@ class BatchStrOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
```
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -280,7 +280,6 @@ class Req:
         self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
-        self.normalized_prompt_logprob = None
         self.input_token_logprobs_val = None
         self.input_token_logprobs_idx = None
         self.input_top_logprobs_val = None
@@ -344,9 +343,6 @@ class Req:
             max_prefix_len = min(max_prefix_len, input_len - 1)
 
             if self.return_logprob:
-                if self.normalized_prompt_logprob is None:
-                    # Need at least two tokens to compute normalized logprob
-                    max_prefix_len = min(max_prefix_len, input_len - 2)
                 max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
         max_prefix_len = max(max_prefix_len, 0)
```
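The deleted branch capped the reusable cache prefix at `input_len - 2` whenever a normalized logprob was still pending: the first prompt token carries no logprob, so the mean needs at least one scored token, which means at least the last two prompt tokens must go through prefill. A toy sketch of that cap (function name and signature are illustrative, not engine API):

```python
def cap_prefix_len(max_prefix_len, input_len, return_logprob,
                   normalized_logprob_pending, logprob_start_len):
    # Always recompute at least the last prompt token.
    max_prefix_len = min(max_prefix_len, input_len - 1)
    if return_logprob:
        if normalized_logprob_pending:
            # Need at least two tokens to compute a normalized logprob.
            max_prefix_len = min(max_prefix_len, input_len - 2)
        max_prefix_len = min(max_prefix_len, logprob_start_len)
    return max(max_prefix_len, 0)

print(cap_prefix_len(10, 8, True, True, 10))   # 6: two tokens left for prefill
print(cap_prefix_len(10, 8, True, False, 10))  # 7: only the last token needed
```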
python/sglang/srt/managers/schedule_policy.py

```diff
@@ -433,7 +433,6 @@ class PrefillAdder:
                 or input_tokens <= self.rem_chunk_tokens
                 or (
                     req.return_logprob
-                    and req.normalized_prompt_logprob is None
                     and req.logprob_start_len != len(req.origin_input_ids) - 1
                 )
             ):
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -1038,9 +1038,6 @@ class Scheduler:
                     logits_output.input_token_logprobs = (
                         logits_output.input_token_logprobs.tolist()
                     )
-                    logits_output.normalized_prompt_logprobs = (
-                        logits_output.normalized_prompt_logprobs.tolist()
-                    )
 
             # Check finish conditions
             logprob_pt = 0
@@ -1188,9 +1185,6 @@ class Scheduler:
         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
 
-        if req.normalized_prompt_logprob is None:
-            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
-
         if req.input_token_logprobs_val is None:
             input_token_logprobs_val = output.input_token_logprobs[
                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
@@ -1288,15 +1282,12 @@ class Scheduler:
                 input_top_logprobs_idx = []
                 output_top_logprobs_val = []
                 output_top_logprobs_idx = []
-                normalized_prompt_logprob = []
             else:
                 input_token_logprobs_val = input_token_logprobs_idx = (
                     output_token_logprobs_val
                 ) = output_token_logprobs_idx = input_top_logprobs_val = (
                     input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx = (
-                    normalized_prompt_logprob
-                ) = None
+                ) = output_top_logprobs_val = output_top_logprobs_idx = None
 
             for req in reqs:
                 if req is skip_req:
@@ -1343,7 +1334,6 @@ class Scheduler:
                     input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                     output_top_logprobs_val.append(req.output_top_logprobs_val)
                     output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                    normalized_prompt_logprob.append(req.normalized_prompt_logprob)
 
             # Send to detokenizer
             if rids:
@@ -1370,7 +1360,6 @@ class Scheduler:
                         input_top_logprobs_idx,
                         output_top_logprobs_val,
                         output_top_logprobs_idx,
-                        normalized_prompt_logprob,
                     )
                 )
         else:  # embedding or reward model
```
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -796,9 +796,6 @@ class TokenizerManager:
                 recv_obj.output_token_logprobs_idx[recv_obj_index],
                 return_text_in_logprobs,
             )
-            meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
-                recv_obj_index
-            ]
 
             if top_logprobs_num > 0:
                 meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
```
python/sglang/srt/managers/tp_worker_overlap_thread.py

```diff
@@ -151,11 +151,6 @@ class TpModelWorkerClient:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                 )
-                logits_output.normalized_prompt_logprobs = (
-                    logits_output.normalized_prompt_logprobs.to(
-                        "cpu", non_blocking=True
-                    )
-                )
         next_token_ids = next_token_ids.to("cpu", non_blocking=True)
         copy_done.record()
@@ -174,9 +169,6 @@ class TpModelWorkerClient:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.tolist()
                 )
-                logits_output.normalized_prompt_logprobs = (
-                    logits_output.normalized_prompt_logprobs.tolist()
-                )
         next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids
```
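For context on the surrounding code: the overlap worker enqueues non-blocking GPU-to-CPU copies, records a CUDA event, and converts tensors to Python lists only after that event completes, so transfers overlap with scheduling work. A minimal sketch of the pattern (names are illustrative; fully asynchronous device-to-host copies also require pinned host memory):

```python
import torch

def start_copy(tensor_gpu: torch.Tensor):
    # Enqueue the device-to-host copy without blocking the host thread.
    tensor_cpu = tensor_gpu.to("cpu", non_blocking=True)
    copy_done = torch.cuda.Event()
    copy_done.record()  # marks the stream position after the copy
    return tensor_cpu, copy_done

def finish_copy(tensor_cpu: torch.Tensor, copy_done: torch.cuda.Event):
    copy_done.synchronize()  # wait until the copy has actually landed
    return tensor_cpu.tolist()  # now safe to read on the host
```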
python/sglang/test/test_programs.py

```diff
@@ -535,7 +535,7 @@ def test_hellaswag_select():
     # Compute accuracy
     accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
-    assert np.abs(accuracy_gen - accuracy) < 0.01
+    assert np.abs(accuracy_gen - accuracy) < 0.05
     assert np.abs(latency_gen - latency) < 1
 
     return accuracy, latency
```
scripts/deprecated/test_httpserver_classify.py (deleted, 100644 → 0)

Entire file removed:

```python
"""
Usage:
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache

python3 test_httpserver_classify.py
"""

import argparse

import numpy as np
import requests


def get_logits_deprecated(url: str, prompt: str):
    response = requests.post(
        url + "/generate",
        json={
            "text": prompt,
            "sampling_params": {
                "max_new_tokens": 0,
            },
            "return_logprob": True,
        },
    )
    return response.json()["meta_info"]["normalized_prompt_logprob"]


def get_logits_batch_deprecated(url: str, prompts: list[str]):
    response = requests.post(
        url + "/generate",
        json={
            "text": prompts,
            "sampling_params": {
                "max_new_tokens": 0,
            },
            "return_logprob": True,
        },
    )
    ret = response.json()
    logits = np.array(
        list(
            ret[i]["meta_info"]["normalized_prompt_logprob"]
            for i in range(len(prompts))
        )
    )
    return logits


def get_logits(url: str, prompt: str):
    response = requests.post(
        url + "/classify",
        json={"text": prompt},
    )
    return response.json()["embedding"]


def get_logits_batch(url: str, prompts: list[str]):
    response = requests.post(
        url + "/classify",
        json={"text": prompts},
    )
    return np.array([x["embedding"] for x in response.json()])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    url = f"{args.host}:{args.port}"

    # A single request
    prompt = "This is a test prompt.<|eot_id|>"
    logits = get_logits(url, prompt)
    print(f"{logits=}")

    # A batch of requests
    prompts = [
        "This is a test prompt.<|eot_id|>",
        "This is another test prompt.<|eot_id|>",
        "This is a long long long long test prompt.<|eot_id|>",
    ]
    logits = get_logits_batch(url, prompts)
    print(f"{logits=}")
```
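The deleted script's deprecated helpers read `meta_info["normalized_prompt_logprob"]`, which this commit removes from the server response. A caller still on the `/generate` path can recompute the value from `input_token_logprobs`. A hedged migration sketch (endpoint and field names follow the code in this commit; verify against your server version before relying on them):

```python
import requests

def get_normalized_prompt_logprob(url: str, prompt: str) -> float:
    response = requests.post(
        url + "/generate",
        json={
            "text": prompt,
            "sampling_params": {"max_new_tokens": 0},
            "return_logprob": True,
        },
    )
    input_logprobs = response.json()["meta_info"]["input_token_logprobs"]
    # Same averaging as compute_normalized_prompt_logprobs in
    # runtime_endpoint.py: skip the first token, which has no logprob.
    values = [x[0] for x in input_logprobs if x[0]]
    return sum(values) / len(values)
```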
scripts/deprecated/test_httpserver_decode_stream.py

```diff
@@ -42,7 +42,6 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
         if return_logprob:
             assert data["meta_info"]["input_token_logprobs"] is not None
             assert data["meta_info"]["output_token_logprobs"] is not None
-            assert data["meta_info"]["normalized_prompt_logprob"] is not None
 
         for logprob, token_id, token_text in data["meta_info"]["output_token_logprobs"][
             prev:
         ]:
```