sglang / Commits

Commit 9a16fea0 (Unverified), authored Jan 23, 2024 by Lianmin Zheng, committed by GitHub on Jan 23, 2024

Return logprob for choices (#87)

Parent: 9e037c82

Showing 15 changed files with 161 additions and 112 deletions.

README.md                                            +4  -0
docs/sampling_params.md                              +2  -2
examples/usage/async_io.py                           +0  -0
examples/usage/choices_logprob.py                    +42 -0
python/sglang/backend/openai.py                      +1  -1
python/sglang/backend/runtime_endpoint.py            +9  -5
python/sglang/lang/chat_template.py                  +2  -2
python/sglang/lang/interpreter.py                    +8  -2
python/sglang/srt/layers/logits_processor.py         +16 -23
python/sglang/srt/managers/io_struct.py              +16 -18
python/sglang/srt/managers/router/infer_batch.py     +10 -13
python/sglang/srt/managers/router/model_rpc.py       +32 -27
python/sglang/srt/managers/router/model_runner.py    +13 -15
python/sglang/srt/managers/tokenizer_manager.py      +4  -4
test/srt/test_httpserver_decode.py                   +2  -0

README.md

@@ -69,6 +69,8 @@ state = multi_turn_question.run(
 for m in state.messages():
     print(m["role"], ":", m["content"])
+
+print(state["answer_1"])
 ```

 ### Using Local Models

@@ -99,6 +101,8 @@ state = multi_turn_question.run(
 for m in state.messages():
     print(m["role"], ":", m["content"])
+
+print(state["answer_1"])
 ```

 ### More Examples

docs/sampling_params.md

@@ -9,8 +9,8 @@ class GenerateReqInput:
     image_data: Optional[Union[List[str], str]] = None
     sampling_params: Union[List[Dict], Dict] = None
     rid: Optional[Union[List[str], str]] = None
-    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
-    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
+    return_logprob: Optional[Union[List[bool], bool]] = None
+    logprob_start_len: Optional[Union[List[int], int]] = None
     stream: bool = False
 ```
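
For reference, these renamed fields are what a client sends to /generate when it wants prompt logprobs back. A minimal sketch of such a request, assuming a local server on port 30000 (the prompts are made up; the field names and the meta_info keys follow the runtime_endpoint.py and model_rpc.py changes below):

import requests

# Score two candidate completions of the same prefix without generating new tokens.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": ["I need to use a calculator", "I need to use a search engine"],
        "sampling_params": {"max_new_tokens": 0},
        "return_logprob": True,
        "logprob_start_len": 0,
    },
)
for r in response.json():
    meta = r["meta_info"]
    print(meta["normalized_prompt_logprob"], meta["prompt_logprob"])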

examples/usage/async.py → examples/usage/async_io.py

File moved.

examples/usage/choices_logprob.py (new file, mode 100644)

"""
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""
import sglang as sgl


@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ", "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])


def main():
    # Run one case
    question = "What is 5 + 5?"
    state = tool_use.run(question)
    print("questions:", question)
    print("choice:", state["tool"])
    meta_info = state.get_meta_info("tool")
    print("logprobs of choice 1", meta_info["prompt_logprob"][0])
    print("logprobs of choice 2", meta_info["prompt_logprob"][1])
    print('-' * 50)

    # Run a batch
    questions = [
        "What is 5 + 6?",
        "Who is Michael Jordan?",
    ]
    states = tool_use.run_batch([{"question": q} for q in questions])

    for question, state in zip(questions, states):
        print("questions:", question)
        print("choice:", state["tool"])
        meta_info = state.get_meta_info("tool")
        print("logprobs of choice 1", meta_info["prompt_logprob"][0])
        print("logprobs of choice 2", meta_info["prompt_logprob"][1])
        print('-' * 50)


if __name__ == "__main__":
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    main()

python/sglang/backend/openai.py

@@ -209,7 +209,7 @@ class OpenAI(BaseBackend):
             prompt_tokens.append(ret_token)

         decision = choices[np.argmax(scores)]
-        return decision, scores
+        return decision, scores, scores


 def openai_completion(client, is_chat=None, prompt=None, **kwargs):
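
The OpenAI backend now returns its per-choice scores twice so that select() has the same three-value shape on every backend; the interpreter (python/sglang/lang/interpreter.py below) unpacks it uniformly. A rough sketch of that contract, with hypothetical variable names:

# decision: the chosen string among `choices`
# normalized_prompt_logprob: one score per choice, used for the argmax
# prompt_logprob: per-choice detail (token-level logprobs on the runtime backend,
#                 the same scores again on the OpenAI backend)
decision, normalized_prompt_logprob, prompt_logprob = backend.select(
    s, choices, temperature
)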

python/sglang/backend/runtime_endpoint.py

@@ -150,16 +150,20 @@ class RuntimeEndpoint(BaseBackend):
         data = {
             "text": [s.text_ + c for c in choices],
             "sampling_params": {"max_new_tokens": 0},
-            "return_normalized_logprob": True,
-            "normalized_logprob_start_len": prompt_len,
+            "return_logprob": True,
+            "logprob_start_len": max(prompt_len - 2, 0),
         }
         self._add_images(s, data)
         res = http_request(self.base_url + "/generate", json=data)
         assert res.status_code == 200
-        logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]
+        obj = res.json()
+        normalized_prompt_logprob = [
+            r["meta_info"]["normalized_prompt_logprob"] for r in obj
+        ]
+        prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]

-        decision = choices[np.argmax(logps)]
-        return decision, logps
+        decision = choices[np.argmax(normalized_prompt_logprob)]
+        return decision, normalized_prompt_logprob, prompt_logprob

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
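
Two details above are easy to miss: logprob_start_len is set to max(prompt_len - 2, 0) because the scheduler needs at least two non-cached tokens to compute a normalized logprob (see the comment in model_rpc.py below), and the decision is an argmax over the length-normalized scores, so a choice is not penalized simply for having more tokens. A rough, self-contained illustration with made-up numbers:

import numpy as np

choices = ["calculator", "search engine"]
# Hypothetical per-token prompt logprobs for each choice, as returned in
# meta_info["prompt_logprob"]; the second choice has more tokens.
prompt_logprob = [
    [-0.8, -1.2],
    [-0.9, -1.1, -0.7],
]
# Length-normalizing (roughly, averaging) keeps the comparison fair; the server's
# exact computation is in logits_processor.py.
normalized = [np.mean(lp) for lp in prompt_logprob]
decision = choices[int(np.argmax(normalized))]
print(decision, normalized)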

python/sglang/lang/chat_template.py

 from dataclasses import dataclass, field
 from enum import Enum, auto
-from typing import Callable, Dict, List, Tuple, Optional
+from typing import Callable, Dict, List, Optional, Tuple


 class ChatTemplateStyle(Enum):

@@ -111,7 +111,7 @@ register_chat_template(
             "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=('<|im_end|>',)
+        stop_str=("<|im_end|>",),
     )
 )

python/sglang/lang/interpreter.py

@@ -80,7 +80,7 @@ def run_program_batch(
     # Run all programs
     if num_threads == "auto":
-        num_threads = max(64, multiprocessing.cpu_count() * 8)
+        num_threads = max(96, multiprocessing.cpu_count() * 16)
     num_threads = min(num_threads, len(batch_arguments))
     if num_threads == 1:

@@ -364,10 +364,16 @@ class StreamExecutor:
         self.stream_var_event[name].set()

     def _execute_select(self, expr: SglSelect):
-        decision, scores = self.backend.select(self, expr.choices, expr.temperature)
+        decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
+            self, expr.choices, expr.temperature
+        )
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
+            self.meta_info[name] = {
+                "normalized_prompt_logprob": normalized_prompt_logprob,
+                "prompt_logprob": prompt_logprob,
+            }
             self.variable_event[name].set()
         self.text_ += decision
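
With the new meta_info entry, a program can read both the per-choice normalized score and the raw token logprobs after a select. A small sketch, assuming a local server on port 30000 as in examples/usage/choices_logprob.py (the question text is made up):

import sglang as sgl


@sgl.function
def pick(s, question):
    s += question + " Answer: " + sgl.gen("answer", choices=["yes", "no"])


sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = pick.run(question="Is 7 a prime number?")
info = state.get_meta_info("answer")
print(info["normalized_prompt_logprob"])  # one float per choice
print(info["prompt_logprob"])             # one list of token logprobs per choice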

python/sglang/srt/layers/logits_processor.py

@@ -14,7 +14,7 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def forward(self, input_ids, hidden_states, weight, input_metadata):
-        if not input_metadata.return_normalized_logprob:
+        if not input_metadata.return_logprob:
             if input_metadata.forward_mode == ForwardMode.DECODE:
                 last_hidden = hidden_states
             else:

@@ -33,7 +33,7 @@ class LogitsProcessor(nn.Module):
             if self.tp_size > 1:
                 last_logits = tensor_model_parallel_all_gather(last_logits)
             last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, None
+            return last_logits, (None, None)
         else:
             assert input_metadata.forward_mode != ForwardMode.DECODE
             last_index = (

@@ -51,30 +51,23 @@ class LogitsProcessor(nn.Module):
             logits = logits[:, : self.config.vocab_size]
             all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)

-            normalized_logprobs = compute_normalized_logprobs(
-                all_logprobs,
-                input_ids,
-                input_metadata.extend_seq_lens,
-                input_metadata.extend_start_loc,
-            )
+            logprobs = all_logprobs[
+                torch.arange(all_logprobs.shape[0], device="cuda"),
+                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+            ]
+            logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+            start = input_metadata.extend_start_loc.clone()
+            end = start + input_metadata.extend_seq_lens - 2
+            start.clamp_(min=0, max=logprobs.shape[0] - 1)
+            end.clamp_(min=0, max=logprobs.shape[0] - 1)
+            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+            normalized_logprobs = sum_logp / (
+                (input_metadata.extend_seq_lens - 1).clamp(min=1)
+            )

             last_logits = logits[last_index]
-            return last_logits, normalized_logprobs
-
-
-def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
-    logprobs = all_logprobs[
-        torch.arange(all_logprobs.shape[0], device="cuda"),
-        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-    ]
-    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-    start = start_loc.clone()
-    end = start + seq_lens - 2
-    start.clamp_(min=0, max=logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
-    return sum_logp / ((seq_lens - 1).clamp(min=1))
+            return last_logits, (logprobs, normalized_logprobs)


 if __name__ == "__main__":
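
The removed helper is now inlined, but the arithmetic is unchanged: pick the logprob of each actual next token, build a running cumulative sum, and take cumsum[end] - cumsum[start] + logprobs[start] to get the sum over each request's window in one vectorized step, then divide by the number of scored tokens. A standalone, CPU-only illustration of that trick with made-up numbers:

import torch

# token_logprobs holds logprob(next token) for every position of every request in
# the flattened extend batch; start_loc/seq_lens describe where each request lives.
token_logprobs = torch.tensor([-0.5, -1.0, -0.2, -0.3, -2.0, -0.1, -0.4])
start_loc = torch.tensor([0, 3])  # request 0 starts at index 0, request 1 at index 3
seq_lens = torch.tensor([3, 4])   # tokens per request

cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
start = start_loc.clone()
end = start + seq_lens - 2
start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=token_logprobs.shape[0] - 1)

# Inclusive sum over [start, end], then length-normalize by (seq_len - 1).
sum_logp = cumsum[end] - cumsum[start] + token_logprobs[start]
normalized = sum_logp / (seq_lens - 1).clamp(min=1)
print(normalized)  # tensor([-0.7500, -0.8000])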

python/sglang/srt/managers/io_struct.py

@@ -11,8 +11,8 @@ class GenerateReqInput:
     image_data: Optional[Union[List[str], str]] = None
     sampling_params: Union[List[Dict], Dict] = None
     rid: Optional[Union[List[str], str]] = None
-    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
-    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
+    return_logprob: Optional[Union[List[bool], bool]] = None
+    logprob_start_len: Optional[Union[List[int], int]] = None
     stream: bool = False

     def post_init(self):

@@ -23,10 +23,10 @@ class GenerateReqInput:
                 self.sampling_params = {}
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
-            if self.return_normalized_logprob is None:
-                self.return_normalized_logprob = False
-            if self.normalized_logprob_start_len is None:
-                self.normalized_logprob_start_len = 0
+            if self.return_logprob is None:
+                self.return_logprob = False
+            if self.logprob_start_len is None:
+                self.logprob_start_len = 0
         else:
             num = len(self.text)

@@ -45,17 +45,15 @@ class GenerateReqInput:
             else:
                 assert isinstance(self.rid, list)
-            if self.return_normalized_logprob is None:
-                self.return_normalized_logprob = [False] * num
-            elif not isinstance(self.return_normalized_logprob, list):
-                self.return_normalized_logprob = [self.return_normalized_logprob] * num
-            if self.normalized_logprob_start_len is None:
-                self.normalized_logprob_start_len = [0] * num
-            elif not isinstance(self.normalized_logprob_start_len, list):
-                self.normalized_logprob_start_len = [
-                    self.normalized_logprob_start_len
-                ] * num
+            if self.return_logprob is None:
+                self.return_logprob = [False] * num
+            elif not isinstance(self.return_logprob, list):
+                self.return_logprob = [self.return_logprob] * num
+            if self.logprob_start_len is None:
+                self.logprob_start_len = [0] * num
+            elif not isinstance(self.logprob_start_len, list):
+                self.logprob_start_len = [self.logprob_start_len] * num


 @dataclass

@@ -65,8 +63,8 @@ class TokenizedGenerateReqInput:
     pixel_values: List[float]
     image_hash: int
     sampling_params: SamplingParams
-    return_normalized_logprob: bool
-    normalized_logprob_start_len: int
+    return_logprob: bool
+    logprob_start_len: int
     stream: bool
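
post_init() broadcasts the two new fields the same way it already handles rid and sampling_params: a missing value becomes a per-prompt default and a scalar becomes a per-prompt list. A rough sketch of that normalization with plain values (the helper name is made up):

def broadcast(value, default, num):
    # None -> per-prompt default; scalar -> repeated per prompt; list -> unchanged.
    if value is None:
        return [default] * num
    if not isinstance(value, list):
        return [value] * num
    return value

num_prompts = 3
print(broadcast(None, False, num_prompts))  # [False, False, False]
print(broadcast(True, False, num_prompts))  # [True, True, True]
print(broadcast(5, 0, num_prompts))         # [5, 5, 5]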

python/sglang/srt/managers/router/infer_batch.py

@@ -28,8 +28,8 @@ class Req:
         self.pixel_values = None
         self.image_offset = 0
         self.sampling_params = None
-        self.return_normalized_logprob = False
-        self.normalized_logprob_start_len = 0
+        self.return_logprob = False
+        self.logprob_start_len = 0
         self.stream = False
         self.tokenizer = None

@@ -37,10 +37,11 @@ class Req:
         self.finish_reason = None
         self.hit_stop_str = None
-        self.adjust_input_len = 0
+        self.extend_input_len = 0
         self.prefix_indices = []
         self.last_node = None
+        self.logprob = None
         self.normalized_logprob = None

         # for constrained decoding

@@ -99,7 +100,7 @@ class Batch:
     out_cache_loc: torch.Tensor = None
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
-    return_normalized_logprob: bool = False
+    return_logprob: bool = False

     # for multimodal
     pixel_values: List[torch.Tensor] = None

@@ -119,14 +120,14 @@ class Batch:
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
+        return_logprob = any(req.return_logprob for req in reqs)

         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )

     def is_empty(self):

@@ -257,7 +258,7 @@ class Batch:
             self.tree_cache.dec_ref_counter(req.last_node)
             req.prefix_indices = None
             req.last_node = None
-            req.adjust_input_len = 0
+            req.extend_input_len = 0
             req.output_ids = []

             # TODO: apply more fine-grained retraction

@@ -310,9 +311,7 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in self.reqs
-        )
+        self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
             "temperatures",

@@ -336,9 +335,7 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in self.reqs
-        )
+        self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
             "temperatures",

python/sglang/srt/managers/router/model_rpc.py

@@ -214,8 +214,8 @@ class ModelRpcServer(rpyc.Service):
                 req.input_ids, pad_value
             )
         req.sampling_params = recv_req.sampling_params
-        req.return_normalized_logprob = recv_req.return_normalized_logprob
-        req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
+        req.return_logprob = recv_req.return_logprob
+        req.logprob_start_len = recv_req.logprob_start_len
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

@@ -240,9 +240,9 @@ class ModelRpcServer(rpyc.Service):
         for req in self.forward_queue:
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
-            if req.return_normalized_logprob:
-                prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
-            req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
+            if req.return_logprob:
+                prefix_indices = prefix_indices[: req.logprob_start_len]
+            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
             req.prefix_indices = prefix_indices
             req.last_node = last_node

@@ -267,32 +267,32 @@ class ModelRpcServer(rpyc.Service):
         )

         for req in self.forward_queue:
-            if req.return_normalized_logprob:
+            if req.return_logprob:
                 # Need at least two tokens to compute normalized logprob
-                if req.adjust_input_len < 2:
-                    delta = 2 - req.adjust_input_len
-                    req.adjust_input_len += delta
+                if req.extend_input_len < 2:
+                    delta = 2 - req.extend_input_len
+                    req.extend_input_len += delta
                     req.prefix_indices = req.prefix_indices[:-delta]
                     if req.image_offset is not None:
                         req.image_offset += delta
-            if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
+            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                 # Need at least one token to compute logits
-                req.adjust_input_len = 1
+                req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
                 if req.image_offset is not None:
                     req.image_offset += 1

             if (
-                req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
+                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and req.adjust_input_len + new_batch_input_tokens
+                and req.extend_input_len + new_batch_input_tokens
                 < self.max_prefill_num_token
             ):
                 delta = self.tree_cache.inc_ref_counter(req.last_node)
                 available_size += delta

                 if not (
-                    req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
+                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
                     delta = self.tree_cache.dec_ref_counter(req.last_node)

@@ -301,9 +301,9 @@ class ModelRpcServer(rpyc.Service):
                     self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
-                        req.adjust_input_len + req.max_new_tokens()
+                        req.extend_input_len + req.max_new_tokens()
                     )
-                    new_batch_input_tokens += req.adjust_input_len
+                    new_batch_input_tokens += req.extend_input_len

         if len(can_run_list) == 0:
             return None

@@ -339,27 +339,31 @@ class ModelRpcServer(rpyc.Service):
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, normalized_logprobs = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_normalized_logprob
+            logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
+                batch, ForwardMode.EXTEND, batch.return_logprob
             )
             # print("extend logits", logits)
-            if normalized_logprobs is not None:
+            if logprobs is not None:
+                logprobs = logprobs.cpu().tolist()
                 normalized_logprobs = normalized_logprobs.cpu().tolist()

             next_token_ids, next_token_probs = batch.sample(logits)
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            normalized_logprobs = None
+            logprobs = normalized_logprobs = None

         # Check finish condition
         reqs = batch.reqs
-        for i in range(len(reqs)):
-            reqs[i].output_ids = [next_token_ids[i]]
-            reqs[i].check_finished()
-            if normalized_logprobs is not None:
-                reqs[i].normalized_logprob = normalized_logprobs[i]
+        pt = 0
+        for i, req in enumerate(reqs):
+            req.output_ids = [next_token_ids[i]]
+            req.check_finished()
+            if logprobs is not None:
+                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
+                req.normalized_logprob = normalized_logprobs[i]
+                pt += req.extend_input_len

         self.handle_finished_requests(batch)

@@ -427,8 +431,9 @@ class ModelRpcServer(rpyc.Service):
                 "prompt_tokens": len(req.input_ids),
                 "completion_tokens": len(req.output_ids),
             }
-            if req.return_normalized_logprob:
-                meta_info["normalized_logprob"] = req.normalized_logprob
+            if req.return_logprob:
+                meta_info["prompt_logprob"] = req.logprob
+                meta_info["normalized_prompt_logprob"] = req.normalized_logprob
             output_meta_info.append(meta_info)
             output_finished.append(req.finished)
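
After the forward pass, logprobs is one flat list covering every extended token of every request in the batch, so the loop above walks it with a running pointer and keeps extend_input_len - 1 values per request (next-token logprobs have no entry for the final token). A short standalone sketch of that slicing with made-up numbers:

# Flat per-token logprobs for a two-request extend batch.
flat_logprobs = [-0.5, -1.0, -0.2, -0.3, -2.0, -0.1, -0.4]
extend_input_lens = [3, 4]  # hypothetical extend lengths per request

pt = 0
per_request = []
for extend_len in extend_input_lens:
    per_request.append(flat_logprobs[pt : pt + extend_len - 1])
    pt += extend_len
print(per_request)  # [[-0.5, -1.0], [-0.3, -2.0, -0.1]]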

python/sglang/srt/managers/router/model_runner.py

@@ -45,7 +45,7 @@ class InputMetadata:
     out_cache_cont_end: torch.Tensor = None
     other_kv_index: torch.Tensor = None
-    return_normalized_logprob: bool = False
+    return_logprob: bool = False

     # for flashinfer
     use_flashinfer: bool = False

@@ -127,7 +127,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
-        return_normalized_logprob=False,
+        return_logprob=False,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

@@ -175,7 +175,7 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
         )

@@ -337,7 +337,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -348,7 +348,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(input_ids, input_metadata.positions, input_metadata)

@@ -361,7 +361,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -372,7 +372,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(input_ids, input_metadata.positions, input_metadata)

@@ -415,7 +415,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -426,7 +426,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(
             input_ids,

@@ -436,9 +436,7 @@ class ModelRunner:
             image_offsets,
         )

-    def forward(
-        self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
-    ):
+    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             kwargs = {
                 "input_ids": batch.input_ids,

@@ -450,7 +448,7 @@ class ModelRunner:
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
             }
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_extend_multi_modal(**kwargs)
         else:
             kwargs = {

@@ -467,10 +465,10 @@ class ModelRunner:
             kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
             return self.forward_decode(**kwargs)
         elif forward_mode == ForwardMode.EXTEND:
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_extend(**kwargs)
         elif forward_mode == ForwardMode.PREFILL:
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_prefill(**kwargs)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")

python/sglang/srt/managers/tokenizer_manager.py

@@ -132,8 +132,8 @@ class TokenizerManager:
                 pixel_values=pixel_values,
                 image_hash=image_hash,
                 sampling_params=sampling_params,
-                return_normalized_logprob=obj.return_normalized_logprob,
-                normalized_logprob_start_len=obj.normalized_logprob_start_len,
+                return_logprob=obj.return_logprob,
+                logprob_start_len=obj.logprob_start_len,
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)

@@ -173,8 +173,8 @@ class TokenizerManager:
                 pixel_values=pixel_values,
                 image_hash=image_hash,
                 sampling_params=sampling_params,
-                return_normalized_logprob=obj.return_normalized_logprob[i],
-                normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
+                return_logprob=obj.return_logprob[i],
+                logprob_start_len=obj.logprob_start_len[i],
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)

test/srt/test_httpserver_decode.py

@@ -26,6 +26,8 @@ if __name__ == "__main__":
             "temperature": 0,
             "max_new_tokens": 32,
         },
+        # "return_logprob": True,
+        # "logprob_start_len": 0,
     },
 )
 print(response.json())
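
The commented-out fields show how to exercise the new logprob path from this test by hand. A sketch of the same request with them enabled, assuming the server from the test is running locally (URL and prompt are placeholders):

import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
        "return_logprob": True,
        "logprob_start_len": 0,
    },
)
print(response.json())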