sglang · Commits · 9a16fea0

Return logprob for choices (#87)

Authored Jan 23, 2024 by Lianmin Zheng; committed by GitHub on Jan 23, 2024 (unverified signature). Parent: 9e037c82.
Showing 15 changed files with 161 additions and 112 deletions:

- README.md (+4 -0)
- docs/sampling_params.md (+2 -2)
- examples/usage/async_io.py (+0 -0)
- examples/usage/choices_logprob.py (+42 -0)
- python/sglang/backend/openai.py (+1 -1)
- python/sglang/backend/runtime_endpoint.py (+9 -5)
- python/sglang/lang/chat_template.py (+2 -2)
- python/sglang/lang/interpreter.py (+8 -2)
- python/sglang/srt/layers/logits_processor.py (+16 -23)
- python/sglang/srt/managers/io_struct.py (+16 -18)
- python/sglang/srt/managers/router/infer_batch.py (+10 -13)
- python/sglang/srt/managers/router/model_rpc.py (+32 -27)
- python/sglang/srt/managers/router/model_runner.py (+13 -15)
- python/sglang/srt/managers/tokenizer_manager.py (+4 -4)
- test/srt/test_httpserver_decode.py (+2 -0)
README.md (+4 -0):

````diff
@@ -69,6 +69,8 @@ state = multi_turn_question.run(
 for m in state.messages():
     print(m["role"], ":", m["content"])
+
+print(state["answer_1"])
 ```

 ### Using Local Models
@@ -99,6 +101,8 @@ state = multi_turn_question.run(
 for m in state.messages():
     print(m["role"], ":", m["content"])
+
+print(state["answer_1"])
 ```

 ### More Examples
````
docs/sampling_params.md (+2 -2):

````diff
@@ -9,8 +9,8 @@ class GenerateReqInput:
     image_data: Optional[Union[List[str], str]] = None
     sampling_params: Union[List[Dict], Dict] = None
     rid: Optional[Union[List[str], str]] = None
-    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
-    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
+    return_logprob: Optional[Union[List[bool], bool]] = None
+    logprob_start_len: Optional[Union[List[int], int]] = None
     stream: bool = False
 ```
````
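For reference, a minimal sketch of what a caller sends after this rename (an illustration, not part of the commit; it assumes a server launched locally on port 30000, and the exact response shape is not shown in this diff):

```python
import requests

# Hedged sketch: assumes a server started with
#   python -m sglang.launch_server --model-path <model> --port 30000
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,   # was: return_normalized_logprob
        "logprob_start_len": 0,   # was: normalized_logprob_start_len
    },
)
# Per the model_rpc.py change below, meta_info should include
# "prompt_logprob" and "normalized_prompt_logprob" when return_logprob is set.
print(response.json()["meta_info"])
```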
examples/usage/async.py → examples/usage/async_io.py (+0 -0, file moved)
examples/usage/choices_logprob.py (new file, +42 -0):

```python
"""
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""
import sglang as sgl


@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ", "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])


def main():
    # Run one case
    question = "What is 5 + 5?"
    state = tool_use.run(question)
    print("questions:", question)
    print("choice:", state["tool"])
    meta_info = state.get_meta_info("tool")
    print("logprobs of choice 1", meta_info["prompt_logprob"][0])
    print("logprobs of choice 2", meta_info["prompt_logprob"][1])
    print('-' * 50)

    # Run a batch
    questions = [
        "What is 5 + 6?",
        "Who is Michael Jordan?",
    ]
    states = tool_use.run_batch([{"question": q} for q in questions])
    for question, state in zip(questions, states):
        print("questions:", question)
        print("choice:", state["tool"])
        meta_info = state.get_meta_info("tool")
        print("logprobs of choice 1", meta_info["prompt_logprob"][0])
        print("logprobs of choice 2", meta_info["prompt_logprob"][1])
        print('-' * 50)


if __name__ == "__main__":
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    main()
```
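This new example doubles as an end-to-end demonstration of the feature: `sgl.gen(..., choices=...)` now records per-choice prompt logprobs in the program state, and `state.get_meta_info("tool")` exposes them as `prompt_logprob` (one list of token logprobs per choice) alongside `normalized_prompt_logprob`; the selected choice is the argmax of the normalized scores, as implemented in runtime_endpoint.py below.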
python/sglang/backend/openai.py (+1 -1):

```diff
@@ -209,7 +209,7 @@ class OpenAI(BaseBackend):
             prompt_tokens.append(ret_token)

         decision = choices[np.argmax(scores)]
-        return decision, scores
+        return decision, scores, scores


 def openai_completion(client, is_chat=None, prompt=None, **kwargs):
```
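Returning `scores` twice is deliberate: the OpenAI backend has no separate per-token prompt logprobs at this point, so it reuses the same scores to satisfy the new three-value `select` contract (decision, normalized logprobs, per-choice logprobs) that `_execute_select` in interpreter.py now unpacks.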
python/sglang/backend/runtime_endpoint.py (+9 -5):

```diff
@@ -150,16 +150,20 @@ class RuntimeEndpoint(BaseBackend):
         data = {
             "text": [s.text_ + c for c in choices],
             "sampling_params": {"max_new_tokens": 0},
-            "return_normalized_logprob": True,
-            "normalized_logprob_start_len": prompt_len,
+            "return_logprob": True,
+            "logprob_start_len": max(prompt_len - 2, 0),
         }
         self._add_images(s, data)
         res = http_request(self.base_url + "/generate", json=data)
         assert res.status_code == 200
-        logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]
+        obj = res.json()
+        normalized_prompt_logprob = [
+            r["meta_info"]["normalized_prompt_logprob"] for r in obj
+        ]
+        prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]

-        decision = choices[np.argmax(logps)]
-        return decision, logps
+        decision = choices[np.argmax(normalized_prompt_logprob)]
+        return decision, normalized_prompt_logprob, prompt_logprob

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
```
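Two details worth noting. `logprob_start_len` is set to `max(prompt_len - 2, 0)` rather than `prompt_len` because the router needs at least two non-cached tokens per request to compute a normalized logprob (see the "Need at least two tokens" adjustment in model_rpc.py below). The decision rule is unchanged in spirit: each choice is scored by its length-normalized prompt logprob. A rough standalone sketch of that scoring (illustrative values only, not the server's actual code path, which normalizes by sequence length minus one):

```python
# Hedged sketch: score each choice by the mean of its token logprobs,
# then pick the argmax, mirroring the decision rule above.
prompt_logprob = {
    "calculator": [-0.2, -0.1, -0.4],            # illustrative values
    "search engine": [-0.9, -1.3, -0.7, -0.5],
}
normalized = {c: sum(lp) / len(lp) for c, lp in prompt_logprob.items()}
decision = max(normalized, key=normalized.get)
print(decision)  # "calculator"
```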
python/sglang/lang/chat_template.py (+2 -2):

```diff
 from dataclasses import dataclass, field
 from enum import Enum, auto
-from typing import Callable, Dict, List, Tuple, Optional
+from typing import Callable, Dict, List, Optional, Tuple


 class ChatTemplateStyle(Enum):
@@ -111,7 +111,7 @@ register_chat_template(
             "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=('<|im_end|>',)
+        stop_str=("<|im_end|>",),
     )
 )
```
python/sglang/lang/interpreter.py (+8 -2):

```diff
@@ -80,7 +80,7 @@ def run_program_batch(
     # Run all programs
     if num_threads == "auto":
-        num_threads = max(64, multiprocessing.cpu_count() * 8)
+        num_threads = max(96, multiprocessing.cpu_count() * 16)
     num_threads = min(num_threads, len(batch_arguments))
     if num_threads == 1:
@@ -364,10 +364,16 @@ class StreamExecutor:
         self.stream_var_event[name].set()

     def _execute_select(self, expr: SglSelect):
-        decision, scores = self.backend.select(self, expr.choices, expr.temperature)
+        decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
+            self, expr.choices, expr.temperature
+        )
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
+            self.meta_info[name] = {
+                "normalized_prompt_logprob": normalized_prompt_logprob,
+                "prompt_logprob": prompt_logprob,
+            }
             self.variable_event[name].set()
         self.text_ += decision
```
python/sglang/srt/layers/logits_processor.py (+16 -23):

```diff
@@ -14,7 +14,7 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def forward(self, input_ids, hidden_states, weight, input_metadata):
-        if not input_metadata.return_normalized_logprob:
+        if not input_metadata.return_logprob:
             if input_metadata.forward_mode == ForwardMode.DECODE:
                 last_hidden = hidden_states
             else:
@@ -33,7 +33,7 @@ class LogitsProcessor(nn.Module):
             if self.tp_size > 1:
                 last_logits = tensor_model_parallel_all_gather(last_logits)
             last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, None
+            return last_logits, (None, None)
         else:
             assert input_metadata.forward_mode != ForwardMode.DECODE
             last_index = (
@@ -51,30 +51,23 @@ class LogitsProcessor(nn.Module):
             logits = logits[:, : self.config.vocab_size]
             all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
-            normalized_logprobs = compute_normalized_logprobs(
-                all_logprobs,
-                input_ids,
-                input_metadata.extend_seq_lens,
-                input_metadata.extend_start_loc,
-            )
+            logprobs = all_logprobs[
+                torch.arange(all_logprobs.shape[0], device="cuda"),
+                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+            ]
+            logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+            start = input_metadata.extend_start_loc.clone()
+            end = start + input_metadata.extend_seq_lens - 2
+            start.clamp_(min=0, max=logprobs.shape[0] - 1)
+            end.clamp_(min=0, max=logprobs.shape[0] - 1)
+            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+            normalized_logprobs = sum_logp / ((input_metadata.extend_seq_lens - 1).clamp(min=1))

             last_logits = logits[last_index]
-            return last_logits, normalized_logprobs
+            return last_logits, (logprobs, normalized_logprobs)


-def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
-    logprobs = all_logprobs[
-        torch.arange(all_logprobs.shape[0], device="cuda"),
-        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-    ]
-    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-
-    start = start_loc.clone()
-    end = start + seq_lens - 2
-    start.clamp_(min=0, max=logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
-    return sum_logp / ((seq_lens - 1).clamp(min=1))
-
-
 if __name__ == "__main__":
```
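The helper is inlined and the processor now returns the raw per-token `logprobs` alongside the normalized score. The non-obvious piece is the cumulative-sum indexing, which recovers each sequence's logprob sum in O(1) after one `cumsum`. A minimal standalone sketch of the same arithmetic (my own illustration: CPU tensors, a single sequence, made-up values):

```python
import torch

# One sequence of 5 token logprobs; the last position is excluded, as above.
logprobs = torch.tensor([-0.5, -1.0, -0.2, -0.8, -0.3])
seq_lens = torch.tensor([5])
start_loc = torch.tensor([0])

cumsum = torch.cumsum(logprobs, dim=0)
start = start_loc.clone()
end = (start + seq_lens - 2).clamp(min=0, max=logprobs.shape[0] - 1)
start.clamp_(min=0, max=logprobs.shape[0] - 1)

# Inclusive sum of logprobs[start .. end] via two cumsum lookups.
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
normalized = sum_logp / (seq_lens - 1).clamp(min=1)

# Matches a direct sum over the same slice.
assert torch.allclose(sum_logp, logprobs[0:4].sum())
print(normalized)  # tensor([-0.6250])
```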
python/sglang/srt/managers/io_struct.py (+16 -18):

```diff
@@ -11,8 +11,8 @@ class GenerateReqInput:
     image_data: Optional[Union[List[str], str]] = None
     sampling_params: Union[List[Dict], Dict] = None
     rid: Optional[Union[List[str], str]] = None
-    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
-    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
+    return_logprob: Optional[Union[List[bool], bool]] = None
+    logprob_start_len: Optional[Union[List[int], int]] = None
     stream: bool = False

     def post_init(self):
@@ -23,10 +23,10 @@ class GenerateReqInput:
                 self.sampling_params = {}
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
-            if self.return_normalized_logprob is None:
-                self.return_normalized_logprob = False
-            if self.normalized_logprob_start_len is None:
-                self.normalized_logprob_start_len = 0
+            if self.return_logprob is None:
+                self.return_logprob = False
+            if self.logprob_start_len is None:
+                self.logprob_start_len = 0
         else:
             num = len(self.text)
@@ -45,17 +45,15 @@ class GenerateReqInput:
             else:
                 assert isinstance(self.rid, list)

-            if self.return_normalized_logprob is None:
-                self.return_normalized_logprob = [False] * num
-            elif not isinstance(self.return_normalized_logprob, list):
-                self.return_normalized_logprob = [self.return_normalized_logprob] * num
+            if self.return_logprob is None:
+                self.return_logprob = [False] * num
+            elif not isinstance(self.return_logprob, list):
+                self.return_logprob = [self.return_logprob] * num

-            if self.normalized_logprob_start_len is None:
-                self.normalized_logprob_start_len = [0] * num
-            elif not isinstance(self.normalized_logprob_start_len, list):
-                self.normalized_logprob_start_len = [
-                    self.normalized_logprob_start_len
-                ] * num
+            if self.logprob_start_len is None:
+                self.logprob_start_len = [0] * num
+            elif not isinstance(self.logprob_start_len, list):
+                self.logprob_start_len = [self.logprob_start_len] * num


 @dataclass
@@ -65,8 +63,8 @@ class TokenizedGenerateReqInput:
     pixel_values: List[float]
     image_hash: int
     sampling_params: SamplingParams
-    return_normalized_logprob: bool
-    normalized_logprob_start_len: int
+    return_logprob: bool
+    logprob_start_len: int
     stream: bool
```
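In the batched branch, `post_init` broadcasts scalar settings into per-request lists. A small sketch of the intended behavior (hedged: this assumes the remaining dataclass fields keep their defaults and is not taken from the commit):

```python
from sglang.srt.managers.io_struct import GenerateReqInput

# Two prompts, one scalar flag: post_init expands it to one entry per prompt.
req = GenerateReqInput(text=["Hello", "World"], return_logprob=True)
req.post_init()
print(req.return_logprob)     # [True, True]
print(req.logprob_start_len)  # [0, 0]
```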
python/sglang/srt/managers/router/infer_batch.py (+10 -13):

```diff
@@ -28,8 +28,8 @@ class Req:
         self.pixel_values = None
         self.image_offset = 0
         self.sampling_params = None
-        self.return_normalized_logprob = False
-        self.normalized_logprob_start_len = 0
+        self.return_logprob = False
+        self.logprob_start_len = 0
         self.stream = False
         self.tokenizer = None
@@ -37,10 +37,11 @@ class Req:
         self.finish_reason = None
         self.hit_stop_str = None
-        self.adjust_input_len = 0
+        self.extend_input_len = 0
         self.prefix_indices = []
         self.last_node = None
+        self.logprob = None
         self.normalized_logprob = None

         # for constrained decoding
@@ -99,7 +100,7 @@ class Batch:
     out_cache_loc: torch.Tensor = None
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
-    return_normalized_logprob: bool = False
+    return_logprob: bool = False

     # for multimodal
     pixel_values: List[torch.Tensor] = None
@@ -119,14 +120,14 @@ class Batch:
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
+        return_logprob = any(req.return_logprob for req in reqs)

         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )

     def is_empty(self):
@@ -257,7 +258,7 @@ class Batch:
             self.tree_cache.dec_ref_counter(req.last_node)
             req.prefix_indices = None
             req.last_node = None
-            req.adjust_input_len = 0
+            req.extend_input_len = 0
             req.output_ids = []

             # TODO: apply more fine-grained retraction
@@ -310,9 +311,7 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in self.reqs
-        )
+        self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
             "temperatures",
@@ -336,9 +335,7 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in self.reqs
-        )
+        self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
             "temperatures",
```
python/sglang/srt/managers/router/model_rpc.py (+32 -27):

```diff
@@ -214,8 +214,8 @@ class ModelRpcServer(rpyc.Service):
                 req.input_ids, pad_value
             )
         req.sampling_params = recv_req.sampling_params
-        req.return_normalized_logprob = recv_req.return_normalized_logprob
-        req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
+        req.return_logprob = recv_req.return_logprob
+        req.logprob_start_len = recv_req.logprob_start_len
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer
@@ -240,9 +240,9 @@ class ModelRpcServer(rpyc.Service):
         for req in self.forward_queue:
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
-            if req.return_normalized_logprob:
-                prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
-            req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
+            if req.return_logprob:
+                prefix_indices = prefix_indices[: req.logprob_start_len]
+            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
             req.prefix_indices = prefix_indices
             req.last_node = last_node
@@ -267,32 +267,32 @@ class ModelRpcServer(rpyc.Service):
             )

         for req in self.forward_queue:
-            if req.return_normalized_logprob:
+            if req.return_logprob:
                 # Need at least two tokens to compute normalized logprob
-                if req.adjust_input_len < 2:
-                    delta = 2 - req.adjust_input_len
-                    req.adjust_input_len += delta
+                if req.extend_input_len < 2:
+                    delta = 2 - req.extend_input_len
+                    req.extend_input_len += delta
                     req.prefix_indices = req.prefix_indices[:-delta]
                     if req.image_offset is not None:
                         req.image_offset += delta
-            if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
+            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                 # Need at least one token to compute logits
-                req.adjust_input_len = 1
+                req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
                 if req.image_offset is not None:
                     req.image_offset += 1

             if (
-                req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
+                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and req.adjust_input_len + new_batch_input_tokens
+                and req.extend_input_len + new_batch_input_tokens
                 < self.max_prefill_num_token
             ):
                 delta = self.tree_cache.inc_ref_counter(req.last_node)
                 available_size += delta

                 if not (
-                    req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
+                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
                     delta = self.tree_cache.dec_ref_counter(req.last_node)
@@ -301,9 +301,9 @@ class ModelRpcServer(rpyc.Service):
                     self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
-                        req.adjust_input_len + req.max_new_tokens()
+                        req.extend_input_len + req.max_new_tokens()
                     )
-                    new_batch_input_tokens += req.adjust_input_len
+                    new_batch_input_tokens += req.extend_input_len

         if len(can_run_list) == 0:
             return None
@@ -339,27 +339,31 @@ class ModelRpcServer(rpyc.Service):
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, normalized_logprobs = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_normalized_logprob
+            logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
+                batch, ForwardMode.EXTEND, batch.return_logprob
             )
             # print("extend logits", logits)
-            if normalized_logprobs is not None:
+            if logprobs is not None:
+                logprobs = logprobs.cpu().tolist()
                 normalized_logprobs = normalized_logprobs.cpu().tolist()

             next_token_ids, next_token_probs = batch.sample(logits)
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            normalized_logprobs = None
+            logprobs = normalized_logprobs = None

         # Check finish condition
         reqs = batch.reqs
-        for i in range(len(reqs)):
-            reqs[i].output_ids = [next_token_ids[i]]
-            reqs[i].check_finished()
+        pt = 0
+        for i, req in enumerate(reqs):
+            req.output_ids = [next_token_ids[i]]
+            req.check_finished()

-            if normalized_logprobs is not None:
-                reqs[i].normalized_logprob = normalized_logprobs[i]
+            if logprobs is not None:
+                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
+                req.normalized_logprob = normalized_logprobs[i]
+                pt += req.extend_input_len

         self.handle_finished_requests(batch)
@@ -427,8 +431,9 @@ class ModelRpcServer(rpyc.Service):
                     "prompt_tokens": len(req.input_ids),
                     "completion_tokens": len(req.output_ids),
                 }
-                if req.return_normalized_logprob:
-                    meta_info["normalized_logprob"] = req.normalized_logprob
+                if req.return_logprob:
+                    meta_info["prompt_logprob"] = req.logprob
+                    meta_info["normalized_prompt_logprob"] = req.normalized_logprob
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)
```
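The subtle part of this change is slicing the flat `logprobs` list back into per-request pieces: the extend batch concatenates every request's tokens, each request contributes `extend_input_len` entries, and the last entry of each request is dropped because it would predict across a request boundary. A standalone sketch of the pointer arithmetic (illustrative numbers, not the commit's code):

```python
# Flat per-token logprobs for a batch of two requests (3 and 4 tokens).
logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7]
extend_input_lens = [3, 4]

pt = 0
per_request = []
for n in extend_input_lens:
    # Keep n - 1 entries; the last one spans into the next request.
    per_request.append(logprobs[pt : pt + n - 1])
    pt += n

print(per_request)  # [[-0.1, -0.2], [-0.4, -0.5, -0.6]]
```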
python/sglang/srt/managers/router/model_runner.py (+13 -15):

```diff
@@ -45,7 +45,7 @@ class InputMetadata:
     out_cache_cont_end: torch.Tensor = None
     other_kv_index: torch.Tensor = None
-    return_normalized_logprob: bool = False
+    return_logprob: bool = False

     # for flashinfer
     use_flashinfer: bool = False
@@ -127,7 +127,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
-        return_normalized_logprob=False,
+        return_logprob=False,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -175,7 +175,7 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
         )
@@ -337,7 +337,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,
@@ -348,7 +348,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(input_ids, input_metadata.positions, input_metadata)
@@ -361,7 +361,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,
@@ -372,7 +372,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(input_ids, input_metadata.positions, input_metadata)
@@ -415,7 +415,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,
@@ -426,7 +426,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(
             input_ids,
@@ -436,9 +436,7 @@ class ModelRunner:
             image_offsets,
         )

-    def forward(
-        self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
-    ):
+    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             kwargs = {
                 "input_ids": batch.input_ids,
@@ -450,7 +448,7 @@ class ModelRunner:
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
             }
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_extend_multi_modal(**kwargs)
         else:
             kwargs = {
@@ -467,10 +465,10 @@ class ModelRunner:
                 kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
             return self.forward_decode(**kwargs)
         elif forward_mode == ForwardMode.EXTEND:
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_extend(**kwargs)
         elif forward_mode == ForwardMode.PREFILL:
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_prefill(**kwargs)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
```
python/sglang/srt/managers/tokenizer_manager.py (+4 -4):

```diff
@@ -132,8 +132,8 @@ class TokenizerManager:
                 pixel_values=pixel_values,
                 image_hash=image_hash,
                 sampling_params=sampling_params,
-                return_normalized_logprob=obj.return_normalized_logprob,
-                normalized_logprob_start_len=obj.normalized_logprob_start_len,
+                return_logprob=obj.return_logprob,
+                logprob_start_len=obj.logprob_start_len,
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)
@@ -173,8 +173,8 @@ class TokenizerManager:
                 pixel_values=pixel_values,
                 image_hash=image_hash,
                 sampling_params=sampling_params,
-                return_normalized_logprob=obj.return_normalized_logprob[i],
-                normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
+                return_logprob=obj.return_logprob[i],
+                logprob_start_len=obj.logprob_start_len[i],
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)
```
test/srt/test_httpserver_decode.py (+2 -0):

```diff
@@ -26,6 +26,8 @@ if __name__ == "__main__":
             "temperature": 0,
             "max_new_tokens": 32,
         },
+        # "return_logprob": True,
+        # "logprob_start_len": 0,
     },
 )
 print(response.json())
```
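The commented-out fields document how to exercise the new path from this decode test: uncommenting them should add `prompt_logprob` and `normalized_prompt_logprob` to `meta_info` in the response, per the model_rpc.py change above.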