sglang / Commits / 9a16fea0

Commit 9a16fea0 (Unverified), authored Jan 23, 2024 by Lianmin Zheng, committed by GitHub on Jan 23, 2024.
Return logprob for choices (#87)
Parent: 9e037c82
Showing 15 changed files with 161 additions and 112 deletions (+161, -112).
- README.md (+4, -0)
- docs/sampling_params.md (+2, -2)
- examples/usage/async_io.py (+0, -0)
- examples/usage/choices_logprob.py (+42, -0)
- python/sglang/backend/openai.py (+1, -1)
- python/sglang/backend/runtime_endpoint.py (+9, -5)
- python/sglang/lang/chat_template.py (+2, -2)
- python/sglang/lang/interpreter.py (+8, -2)
- python/sglang/srt/layers/logits_processor.py (+16, -23)
- python/sglang/srt/managers/io_struct.py (+16, -18)
- python/sglang/srt/managers/router/infer_batch.py (+10, -13)
- python/sglang/srt/managers/router/model_rpc.py (+32, -27)
- python/sglang/srt/managers/router/model_runner.py (+13, -15)
- python/sglang/srt/managers/tokenizer_manager.py (+4, -4)
- test/srt/test_httpserver_decode.py (+2, -0)
README.md

@@ -69,6 +69,8 @@ state = multi_turn_question.run(
 for m in state.messages():
     print(m["role"], ":", m["content"])
+
+print(state["answer_1"])
 ```

 ### Using Local Models

@@ -99,6 +101,8 @@ state = multi_turn_question.run(
 for m in state.messages():
     print(m["role"], ":", m["content"])
+
+print(state["answer_1"])
 ```

 ### More Examples
docs/sampling_params.md

@@ -9,8 +9,8 @@ class GenerateReqInput:
     image_data: Optional[Union[List[str], str]] = None
     sampling_params: Union[List[Dict], Dict] = None
     rid: Optional[Union[List[str], str]] = None
-    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
-    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
+    return_logprob: Optional[Union[List[bool], bool]] = None
+    logprob_start_len: Optional[Union[List[int], int]] = None
     stream: bool = False
 ```
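As context for the renamed fields, here is a minimal sketch of a /generate request that opts into prompt log probabilities. It is an illustration, not part of the commit: it assumes a local sglang server launched as in examples/usage/choices_logprob.py and uses only fields defined on GenerateReqInput after this change.

```python
import requests

# Minimal sketch (assumption: an sglang server is running at localhost:30000,
# launched as in examples/usage/choices_logprob.py). The payload uses the
# renamed fields from this commit: return_logprob / logprob_start_len.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,
        "logprob_start_len": 0,
    },
)
meta_info = response.json()["meta_info"]
print(meta_info["prompt_logprob"])             # per-token prompt log probabilities
print(meta_info["normalized_prompt_logprob"])  # length-normalized score
```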
examples/usage/async.py → examples/usage/async_io.py (file moved, no content changes)
examples/usage/choices_logprob.py (new file, 0 → 100644)

"""
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""
import sglang as sgl


@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ", "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])


def main():
    # Run one case
    question = "What is 5 + 5?"
    state = tool_use.run(question)
    print("questions:", question)
    print("choice:", state["tool"])
    meta_info = state.get_meta_info("tool")
    print("logprobs of choice 1", meta_info["prompt_logprob"][0])
    print("logprobs of choice 2", meta_info["prompt_logprob"][1])
    print('-' * 50)

    # Run a batch
    questions = [
        "What is 5 + 6?",
        "Who is Michael Jordan?",
    ]
    states = tool_use.run_batch([{"question": q} for q in questions])
    for question, state in zip(questions, states):
        print("questions:", question)
        print("choice:", state["tool"])
        meta_info = state.get_meta_info("tool")
        print("logprobs of choice 1", meta_info["prompt_logprob"][0])
        print("logprobs of choice 2", meta_info["prompt_logprob"][1])
        print('-' * 50)


if __name__ == "__main__":
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    main()
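The keys read through state.get_meta_info("tool") are populated by the interpreter change later in this commit: _execute_select now stores both "normalized_prompt_logprob" (one score per choice, used to pick the winner) and "prompt_logprob" (the per-token log probabilities of each choice continuation) under the variable's name.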
python/sglang/backend/openai.py

@@ -209,7 +209,7 @@ class OpenAI(BaseBackend):
             prompt_tokens.append(ret_token)

         decision = choices[np.argmax(scores)]
-        return decision, scores
+        return decision, scores, scores


 def openai_completion(client, is_chat=None, prompt=None, **kwargs):
python/sglang/backend/runtime_endpoint.py

@@ -150,16 +150,20 @@ class RuntimeEndpoint(BaseBackend):
         data = {
             "text": [s.text_ + c for c in choices],
             "sampling_params": {"max_new_tokens": 0},
-            "return_normalized_logprob": True,
-            "normalized_logprob_start_len": prompt_len,
+            "return_logprob": True,
+            "logprob_start_len": max(prompt_len - 2, 0),
         }
         self._add_images(s, data)
         res = http_request(self.base_url + "/generate", json=data)
         assert res.status_code == 200
-        logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]
+        obj = res.json()
+        normalized_prompt_logprob = [
+            r["meta_info"]["normalized_prompt_logprob"] for r in obj
+        ]
+        prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]

-        decision = choices[np.argmax(logps)]
-        return decision, logps
+        decision = choices[np.argmax(normalized_prompt_logprob)]
+        return decision, normalized_prompt_logprob, prompt_logprob

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
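Taken together with the OpenAI backend change above, select now returns a triple instead of a pair: the chosen string, one normalized prompt log probability per choice, and the per-token prompt log probabilities per choice. A minimal sketch of that contract for a hypothetical custom backend (the _score_tokens helper is invented for illustration and is not part of sglang):

```python
import numpy as np


class MyBackend:
    def _score_tokens(self, text):
        # Hypothetical helper: return a list of per-token log probabilities
        # for `text`. Shown only to illustrate the new return shape.
        raise NotImplementedError

    def select(self, s, choices, temperature):
        # Score each candidate continuation appended to the current program text.
        prompt_logprob = [self._score_tokens(s.text_ + c) for c in choices]

        # Length-normalize so longer choices are not penalized for having more tokens.
        normalized_prompt_logprob = [
            sum(lp) / max(len(lp), 1) for lp in prompt_logprob
        ]

        decision = choices[int(np.argmax(normalized_prompt_logprob))]
        # New three-value return shape expected by StreamExecutor._execute_select.
        return decision, normalized_prompt_logprob, prompt_logprob
```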
python/sglang/lang/chat_template.py

 from dataclasses import dataclass, field
 from enum import Enum, auto
-from typing import Callable, Dict, List, Tuple, Optional
+from typing import Callable, Dict, List, Optional, Tuple


 class ChatTemplateStyle(Enum):

@@ -111,7 +111,7 @@ register_chat_template(
             "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=('<|im_end|>',)
+        stop_str=("<|im_end|>",),
     )
 )
python/sglang/lang/interpreter.py

@@ -80,7 +80,7 @@ def run_program_batch(
     # Run all programs
     if num_threads == "auto":
-        num_threads = max(64, multiprocessing.cpu_count() * 8)
+        num_threads = max(96, multiprocessing.cpu_count() * 16)
     num_threads = min(num_threads, len(batch_arguments))

     if num_threads == 1:

@@ -364,10 +364,16 @@ class StreamExecutor:
             self.stream_var_event[name].set()

     def _execute_select(self, expr: SglSelect):
-        decision, scores = self.backend.select(self, expr.choices, expr.temperature)
+        decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
+            self, expr.choices, expr.temperature
+        )
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
+            self.meta_info[name] = {
+                "normalized_prompt_logprob": normalized_prompt_logprob,
+                "prompt_logprob": prompt_logprob,
+            }
             self.variable_event[name].set()
         self.text_ += decision
python/sglang/srt/layers/logits_processor.py

@@ -14,7 +14,7 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def forward(self, input_ids, hidden_states, weight, input_metadata):
-        if not input_metadata.return_normalized_logprob:
+        if not input_metadata.return_logprob:
             if input_metadata.forward_mode == ForwardMode.DECODE:
                 last_hidden = hidden_states
             else:

@@ -33,7 +33,7 @@ class LogitsProcessor(nn.Module):
             if self.tp_size > 1:
                 last_logits = tensor_model_parallel_all_gather(last_logits)
             last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, None
+            return last_logits, (None, None)
         else:
             assert input_metadata.forward_mode != ForwardMode.DECODE
             last_index = (

@@ -51,30 +51,23 @@ class LogitsProcessor(nn.Module):
             logits = logits[:, : self.config.vocab_size]
             all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
-            normalized_logprobs = compute_normalized_logprobs(
-                all_logprobs,
-                input_ids,
-                input_metadata.extend_seq_lens,
-                input_metadata.extend_start_loc,
-            )
+            logprobs = all_logprobs[
+                torch.arange(all_logprobs.shape[0], device="cuda"),
+                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+            ]
+            logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+            start = input_metadata.extend_start_loc.clone()
+            end = start + input_metadata.extend_seq_lens - 2
+            start.clamp_(min=0, max=logprobs.shape[0] - 1)
+            end.clamp_(min=0, max=logprobs.shape[0] - 1)
+            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+            normalized_logprobs = sum_logp / (
+                (input_metadata.extend_seq_lens - 1).clamp(min=1)
+            )

             last_logits = logits[last_index]
-            return last_logits, normalized_logprobs
-
-
-def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
-    logprobs = all_logprobs[
-        torch.arange(all_logprobs.shape[0], device="cuda"),
-        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-    ]
-    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-
-    start = start_loc.clone()
-    end = start + seq_lens - 2
-    start.clamp_(min=0, max=logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
-    return sum_logp / ((seq_lens - 1).clamp(min=1))
+            return last_logits, (logprobs, normalized_logprobs)


 if __name__ == "__main__":
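In equation form, the inlined computation above produces, for each request with extend length L starting at offset s in the packed batch, a length-normalized prompt log probability (a restatement of the code, not text from the commit):

\[
\text{normalized\_logprob} \;=\; \frac{1}{\max(L-1,\,1)} \sum_{t=s}^{s+L-2} \log p\!\left(x_{t+1} \mid x_{\le t}\right)
\]

The unreduced per-token values \(\log p(x_{t+1} \mid x_{\le t})\) are also returned now, and are what the manager later slices into each request's prompt_logprob.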
python/sglang/srt/managers/io_struct.py

@@ -11,8 +11,8 @@ class GenerateReqInput:
     image_data: Optional[Union[List[str], str]] = None
     sampling_params: Union[List[Dict], Dict] = None
     rid: Optional[Union[List[str], str]] = None
-    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
-    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
+    return_logprob: Optional[Union[List[bool], bool]] = None
+    logprob_start_len: Optional[Union[List[int], int]] = None
     stream: bool = False

     def post_init(self):

@@ -23,10 +23,10 @@ class GenerateReqInput:
                 self.sampling_params = {}
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
-            if self.return_normalized_logprob is None:
-                self.return_normalized_logprob = False
-            if self.normalized_logprob_start_len is None:
-                self.normalized_logprob_start_len = 0
+            if self.return_logprob is None:
+                self.return_logprob = False
+            if self.logprob_start_len is None:
+                self.logprob_start_len = 0
         else:
             num = len(self.text)

@@ -45,17 +45,15 @@ class GenerateReqInput:
             else:
                 assert isinstance(self.rid, list)
-            if self.return_normalized_logprob is None:
-                self.return_normalized_logprob = [False] * num
-            elif not isinstance(self.return_normalized_logprob, list):
-                self.return_normalized_logprob = [self.return_normalized_logprob] * num
+            if self.return_logprob is None:
+                self.return_logprob = [False] * num
+            elif not isinstance(self.return_logprob, list):
+                self.return_logprob = [self.return_logprob] * num

-            if self.normalized_logprob_start_len is None:
-                self.normalized_logprob_start_len = [0] * num
-            elif not isinstance(self.normalized_logprob_start_len, list):
-                self.normalized_logprob_start_len = [self.normalized_logprob_start_len] * num
+            if self.logprob_start_len is None:
+                self.logprob_start_len = [0] * num
+            elif not isinstance(self.logprob_start_len, list):
+                self.logprob_start_len = [self.logprob_start_len] * num


 @dataclass

@@ -65,8 +63,8 @@ class TokenizedGenerateReqInput:
     pixel_values: List[float]
     image_hash: int
     sampling_params: SamplingParams
-    return_normalized_logprob: bool
-    normalized_logprob_start_len: int
+    return_logprob: bool
+    logprob_start_len: int
     stream: bool
python/sglang/srt/managers/router/infer_batch.py

@@ -28,8 +28,8 @@ class Req:
         self.pixel_values = None
         self.image_offset = 0
         self.sampling_params = None
-        self.return_normalized_logprob = False
-        self.normalized_logprob_start_len = 0
+        self.return_logprob = False
+        self.logprob_start_len = 0
         self.stream = False

         self.tokenizer = None

@@ -37,10 +37,11 @@ class Req:
         self.finish_reason = None
         self.hit_stop_str = None
-        self.adjust_input_len = 0
+        self.extend_input_len = 0
         self.prefix_indices = []
         self.last_node = None
+        self.logprob = None
         self.normalized_logprob = None

         # for constrained decoding

@@ -99,7 +100,7 @@ class Batch:
     out_cache_loc: torch.Tensor = None
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
-    return_normalized_logprob: bool = False
+    return_logprob: bool = False

     # for multimodal
     pixel_values: List[torch.Tensor] = None

@@ -119,14 +120,14 @@ class Batch:
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
+        return_logprob = any(req.return_logprob for req in reqs)

         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )

     def is_empty(self):

@@ -257,7 +258,7 @@ class Batch:
             self.tree_cache.dec_ref_counter(req.last_node)
             req.prefix_indices = None
             req.last_node = None
-            req.adjust_input_len = 0
+            req.extend_input_len = 0
             req.output_ids = []

             # TODO: apply more fine-grained retraction

@@ -310,9 +311,7 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in self.reqs
-        )
+        self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
             "temperatures",

@@ -336,9 +335,7 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in self.reqs
-        )
+        self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
             "temperatures",
python/sglang/srt/managers/router/model_rpc.py

@@ -214,8 +214,8 @@ class ModelRpcServer(rpyc.Service):
                 req.input_ids, pad_value
             )
         req.sampling_params = recv_req.sampling_params
-        req.return_normalized_logprob = recv_req.return_normalized_logprob
-        req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
+        req.return_logprob = recv_req.return_logprob
+        req.logprob_start_len = recv_req.logprob_start_len
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

@@ -240,9 +240,9 @@ class ModelRpcServer(rpyc.Service):
         for req in self.forward_queue:
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
-            if req.return_normalized_logprob:
-                prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
-            req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
+            if req.return_logprob:
+                prefix_indices = prefix_indices[: req.logprob_start_len]
+            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
             req.prefix_indices = prefix_indices
             req.last_node = last_node

@@ -267,32 +267,32 @@ class ModelRpcServer(rpyc.Service):
         )

         for req in self.forward_queue:
-            if req.return_normalized_logprob:
+            if req.return_logprob:
                 # Need at least two tokens to compute normalized logprob
-                if req.adjust_input_len < 2:
-                    delta = 2 - req.adjust_input_len
-                    req.adjust_input_len += delta
+                if req.extend_input_len < 2:
+                    delta = 2 - req.extend_input_len
+                    req.extend_input_len += delta
                     req.prefix_indices = req.prefix_indices[:-delta]
                     if req.image_offset is not None:
                         req.image_offset += delta
-            if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
+            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                 # Need at least one token to compute logits
-                req.adjust_input_len = 1
+                req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
                 if req.image_offset is not None:
                     req.image_offset += 1

             if (
-                req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
+                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and req.adjust_input_len + new_batch_input_tokens
+                and req.extend_input_len + new_batch_input_tokens
                 < self.max_prefill_num_token
             ):
                 delta = self.tree_cache.inc_ref_counter(req.last_node)
                 available_size += delta

                 if not (
-                    req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
+                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
                     delta = self.tree_cache.dec_ref_counter(req.last_node)

@@ -301,9 +301,9 @@ class ModelRpcServer(rpyc.Service):
                     self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
-                        req.adjust_input_len + req.max_new_tokens()
+                        req.extend_input_len + req.max_new_tokens()
                     )
-                    new_batch_input_tokens += req.adjust_input_len
+                    new_batch_input_tokens += req.extend_input_len

         if len(can_run_list) == 0:
             return None

@@ -339,27 +339,31 @@ class ModelRpcServer(rpyc.Service):
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, normalized_logprobs = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_normalized_logprob
+            logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
+                batch, ForwardMode.EXTEND, batch.return_logprob
             )
             # print("extend logits", logits)
-            if normalized_logprobs is not None:
+            if logprobs is not None:
+                logprobs = logprobs.cpu().tolist()
                 normalized_logprobs = normalized_logprobs.cpu().tolist()

             next_token_ids, next_token_probs = batch.sample(logits)
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            normalized_logprobs = None
+            logprobs = normalized_logprobs = None

         # Check finish condition
         reqs = batch.reqs
-        for i in range(len(reqs)):
-            reqs[i].output_ids = [next_token_ids[i]]
-            reqs[i].check_finished()
+        pt = 0
+        for i, req in enumerate(reqs):
+            req.output_ids = [next_token_ids[i]]
+            req.check_finished()

-            if normalized_logprobs is not None:
-                reqs[i].normalized_logprob = normalized_logprobs[i]
+            if logprobs is not None:
+                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
+                req.normalized_logprob = normalized_logprobs[i]
+                pt += req.extend_input_len

         self.handle_finished_requests(batch)

@@ -427,8 +431,9 @@ class ModelRpcServer(rpyc.Service):
                     "prompt_tokens": len(req.input_ids),
                     "completion_tokens": len(req.output_ids),
                 }
-                if req.return_normalized_logprob:
-                    meta_info["normalized_logprob"] = req.normalized_logprob
+                if req.return_logprob:
+                    meta_info["prompt_logprob"] = req.logprob
+                    meta_info["normalized_prompt_logprob"] = req.normalized_logprob
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)
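The pt pointer in the loop above slices each request's share out of the flat, batch-concatenated logprobs list: each request contributes extend_input_len tokens to the batch, and its prompt logprobs are the first extend_input_len - 1 transitions of that span. A toy sketch of the same bookkeeping, with made-up numbers:

```python
# Toy illustration of the slicing done in model_rpc.py above; the values are
# invented. Two requests contribute 3 and 4 tokens to the packed batch.
logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7]
extend_input_lens = [3, 4]

pt = 0
per_request = []
for length in extend_input_lens:
    # A request of length L has L - 1 token-to-token transitions.
    per_request.append(logprobs[pt : pt + length - 1])
    pt += length

print(per_request)  # [[-0.1, -0.2], [-0.4, -0.5, -0.6]]
```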
python/sglang/srt/managers/router/model_runner.py

@@ -45,7 +45,7 @@ class InputMetadata:
     out_cache_cont_end: torch.Tensor = None

     other_kv_index: torch.Tensor = None
-    return_normalized_logprob: bool = False
+    return_logprob: bool = False

     # for flashinfer
     use_flashinfer: bool = False

@@ -127,7 +127,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
-        return_normalized_logprob=False,
+        return_logprob=False,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

@@ -175,7 +175,7 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
         )

@@ -337,7 +337,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -348,7 +348,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(input_ids, input_metadata.positions, input_metadata)

@@ -361,7 +361,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -372,7 +372,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(input_ids, input_metadata.positions, input_metadata)

@@ -415,7 +415,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -426,7 +426,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(
             input_ids,

@@ -436,9 +436,7 @@ class ModelRunner:
             image_offsets,
         )

-    def forward(self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False):
+    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             kwargs = {
                 "input_ids": batch.input_ids,

@@ -450,7 +448,7 @@ class ModelRunner:
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
             }
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_extend_multi_modal(**kwargs)
         else:
             kwargs = {

@@ -467,10 +465,10 @@ class ModelRunner:
             kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
             return self.forward_decode(**kwargs)
         elif forward_mode == ForwardMode.EXTEND:
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_extend(**kwargs)
         elif forward_mode == ForwardMode.PREFILL:
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_prefill(**kwargs)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
python/sglang/srt/managers/tokenizer_manager.py

@@ -132,8 +132,8 @@ class TokenizerManager:
                 pixel_values=pixel_values,
                 image_hash=image_hash,
                 sampling_params=sampling_params,
-                return_normalized_logprob=obj.return_normalized_logprob,
-                normalized_logprob_start_len=obj.normalized_logprob_start_len,
+                return_logprob=obj.return_logprob,
+                logprob_start_len=obj.logprob_start_len,
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)

@@ -173,8 +173,8 @@ class TokenizerManager:
                     pixel_values=pixel_values,
                     image_hash=image_hash,
                     sampling_params=sampling_params,
-                    return_normalized_logprob=obj.return_normalized_logprob[i],
-                    normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
+                    return_logprob=obj.return_logprob[i],
+                    logprob_start_len=obj.logprob_start_len[i],
                     stream=obj.stream,
                 )
                 self.send_to_router.send_pyobj(tokenized_obj)
test/srt/test_httpserver_decode.py

@@ -26,6 +26,8 @@ if __name__ == "__main__":
             "temperature": 0,
             "max_new_tokens": 32,
         },
+        # "return_logprob": True,
+        # "logprob_start_len": 0,
     },
 )
 print(response.json())