sglang / Commits / 3842eba5

Commit 3842eba5 (unverified), authored Mar 28, 2024 by Liangsheng Yin, committed by GitHub on Mar 28, 2024.

    Logprobs Refractor (#331)

Parent: 24e59f53

Showing 14 changed files with 385 additions and 152 deletions (+385 / -152).
Changed files (14):

    docs/sampling_params.md                             +5    -1
    examples/usage/choices_logprob.py                   +7    -6
    python/sglang/backend/runtime_endpoint.py           +12   -5
    python/sglang/lang/interpreter.py                   +9    -5
    python/sglang/srt/layers/logits_processor.py        +102  -52
    python/sglang/srt/managers/io_struct.py             +11   -0
    python/sglang/srt/managers/router/infer_batch.py    +12   -3
    python/sglang/srt/managers/router/model_rpc.py      +58   -29
    python/sglang/srt/managers/router/model_runner.py   +8    -0
    python/sglang/srt/managers/tokenizer_manager.py     +2    -0
    python/sglang/srt/server.py                         +121  -34
    test/srt/test_httpserver_decode.py                  +11   -4
    test/srt/test_httpserver_decode_stream.py           +15   -11
    test/srt/test_openai_server.py                      +12   -2
docs/sampling_params.md  (view file @ 3842eba5)

@@ -14,10 +14,14 @@ class GenerateReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # The request id
     rid: Optional[Union[List[str], str]] = None
-    # Whether return logprobs of the prompts
+    # Whether to return logprobs
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # The number of top logprobs to return
+    top_logprobs_num: Optional[Union[List[int], int]] = None
+    # Whether to detokenize tokens in logprobs
+    return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
 ```
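Taken together, the new request fields can be exercised with a plain HTTP call to the native endpoint. A minimal sketch, assuming a server launched locally on port 30000; the prompt and sampling parameters below are illustrative, not taken from the commit:

    import requests

    # Hypothetical /generate call exercising the fields documented above.
    response = requests.post(
        "http://localhost:30000/generate",  # assumed local sglang server
        json={
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 16},
            "return_logprob": True,           # return per-token logprobs
            "logprob_start_len": 0,           # score the whole prompt
            "top_logprobs_num": 3,            # top-3 alternatives per position
            "return_text_in_logprobs": True,  # detokenize token ids to text
        },
    )
    print(response.json()["meta_info"].keys())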
examples/usage/choices_logprob.py  (view file @ 3842eba5)

@@ -3,6 +3,7 @@ Usage:
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
 python choices_logprob.py
 """
+
 import sglang as sgl

@@ -19,9 +20,9 @@ def main():
     print("questions:", question)
     print("choice:", state["tool"])
     meta_info = state.get_meta_info("tool")
-    print("logprobs of choice 1", meta_info["prompt_logprob"][0])
-    print("logprobs of choice 2", meta_info["prompt_logprob"][1])
-    print('-' * 50)
+    print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
+    print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
+    print("-" * 50)

     # Run a batch
     questions = [

@@ -33,9 +34,9 @@ def main():
         print("questions:", question)
         print("choice:", state["tool"])
         meta_info = state.get_meta_info("tool")
-        print("logprobs of choice 1", meta_info["prompt_logprob"][0])
-        print("logprobs of choice 2", meta_info["prompt_logprob"][1])
-        print('-' * 50)
+        print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
+        print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
+        print("-" * 50)


 if __name__ == "__main__":
python/sglang/backend/runtime_endpoint.py  (view file @ 3842eba5)

@@ -213,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
             "sampling_params": {"max_new_tokens": 0},
             "return_logprob": True,
             "logprob_start_len": max(prompt_len - 2, 0),
+            "return_text_in_logprobs": True,
         }
         self._add_images(s, data)
         res = http_request(

@@ -224,13 +225,19 @@ class RuntimeEndpoint(BaseBackend):
         )
         assert res.status_code == 200
         obj = res.json()
-        normalized_prompt_logprob = [
+        normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-        prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]
-        decision = choices[np.argmax(normalized_prompt_logprob)]
-        return decision, normalized_prompt_logprob, prompt_logprob
+        decision = choices[np.argmax(normalized_prompt_logprobs)]
+        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
+        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
+
+        return (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        )

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
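The selection rule behind `select` is unchanged in spirit: the candidate with the highest length-normalized prompt log probability wins, and the per-token details are now returned alongside it. A tiny self-contained sketch of that argmax step; the candidate strings and scores are made up for illustration:

    import numpy as np

    # One normalized prompt logprob per candidate, as found in
    # r["meta_info"]["normalized_prompt_logprob"] of each sub-request.
    choices = ["Paris", "London", "Berlin"]          # hypothetical candidates
    normalized_prompt_logprobs = [-1.2, -2.7, -3.1]  # made-up scores

    decision = choices[np.argmax(normalized_prompt_logprobs)]
    print(decision)  # -> "Paris"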
python/sglang/lang/interpreter.py  (view file @ 3842eba5)

@@ -454,15 +454,19 @@ class StreamExecutor:
             self.stream_var_event[name].set()

     def _execute_select(self, expr: SglSelect):
-        decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
-            self, expr.choices, expr.temperature
-        )
+        (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        ) = self.backend.select(self, expr.choices, expr.temperature)
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
             self.meta_info[name] = {
-                "normalized_prompt_logprob": normalized_prompt_logprob,
-                "prompt_logprob": prompt_logprob,
+                "normalized_prompt_logprobs": normalized_prompt_logprobs,
+                "prefill_token_logprobs": prefill_token_logprobs,
+                "decode_token_logprobs": decode_token_logprobs,
             }
             self.variable_event[name].set()
         self.text_ += decision
python/sglang/srt/layers/logits_processor.py  (view file @ 3842eba5)

@@ -13,76 +13,127 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()

-    def forward(self, input_ids, hidden_states, weight, input_metadata):
-        last_index = None
-
-        # Compute the last index (the first decode token) of each requeast
-        # if we are in prefill or extend mode.
-        if input_metadata.forward_mode != ForwardMode.DECODE:
-            last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
-                )
-                - 1
-            )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
-            last_hidden = hidden_states[last_index]
-        last_logits = torch.matmul(last_hidden, weight.T)
-        if self.tp_size > 1:
-            last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size]
-
-        # Return only last_logits if logprob is not requested
-        if not input_metadata.return_logprob:
-            hidden_states = None
-            return last_logits, (None, None, None)
-        else:
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                all_logits = last_logits
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size]
-
-            all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6)
-
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                prefill_logprobs = normalized_logprobs = None
-            else:
-                # Compute the logprobs for the last token of each request.
-                last_logprobs = all_logprobs[last_index]
-
-                # Compute the logprobs and normalized logprobs for the prefill tokens.
-                # Note that we pad a zero at the end of each sequence for easy computation.
-                prefill_logprobs = all_logprobs[
-                    torch.arange(all_logprobs.shape[0], device="cuda"),
-                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-                ]
-                logprobs_cumsum = torch.cumsum(
-                    prefill_logprobs, dim=0, dtype=torch.float32
-                )
-
-                start = input_metadata.extend_start_loc.clone()
-                end = start + input_metadata.extend_seq_lens - 2
-                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                sum_logp = (
-                    logprobs_cumsum[end]
-                    - logprobs_cumsum[start]
-                    + prefill_logprobs[start]
-                )
-                normalized_logprobs = sum_logp / (
-                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
-                )
-
-            return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
+    def _get_normalized_prompt_logprobs(
+        self, prefill_token_logprobs, input_metadata: InputMetadata
+    ):
+        logprobs_cumsum = torch.cumsum(
+            prefill_token_logprobs, dim=0, dtype=torch.float32
+        )
+
+        start = input_metadata.extend_start_loc.clone()
+        end = start + input_metadata.extend_seq_lens - 2
+        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        sum_logp = (
+            logprobs_cumsum[end]
+            - logprobs_cumsum[start]
+            + prefill_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (
+            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+        )
+        return normalized_prompt_logprobs
+
+    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            decode_top_logprobs = []
+            for i in range(all_logprobs.shape[0]):
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[i].topk(k)
+                v_cpu = t.values.cpu().tolist()
+                p_cpu = t.indices.cpu().tolist()
+                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, decode_top_logprobs
+        else:
+            prefill_top_logprobs, decode_top_logprobs = [], []
+            pt = 0
+            # NOTE: the GPU-CPU overhead can be reduced
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens
+            for i in range(len(input_metadata.extend_seq_lens)):
+                if extend_seq_lens_cpu[i] == 0:
+                    continue
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+                vs_cpu = t.values.cpu().tolist()
+                ps_cpu = t.indices.cpu().tolist()
+                prefill_top_logprobs.append(
+                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
+                )
+                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_lens_cpu[i]
+            return prefill_top_logprobs, decode_top_logprobs
+
+    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
+        # Get last index for next token prediction, except for DECODE mode.
+        last_index = None
+        if input_metadata.forward_mode != ForwardMode.DECODE:
+            last_index = (
+                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
+                - 1
+            )
+
+        if not input_metadata.return_logprob:
+            # When logprob is not requested, only compute the last logits.
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_hidden = hidden_states
+            else:
+                last_hidden = hidden_states[last_index]
+                hidden_states = None
+
+            last_logits = torch.matmul(last_hidden, weight.T)
+            if self.tp_size > 1:
+                last_logits = tensor_model_parallel_all_gather(last_logits)
+            last_logits = last_logits[:, : self.config.vocab_size]
+            return last_logits, (None, None, None, None, None)
+        else:
+            # When logprob is requested, compute the logits for all tokens.
+            logits = torch.matmul(hidden_states, weight.T)
+            if self.tp_size > 1:
+                logits = tensor_model_parallel_all_gather(logits)
+            logits = logits[:, : self.config.vocab_size]
+            all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
+
+            prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                all_logprobs, input_metadata
+            )
+
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_logits = logits
+                last_logprobs = all_logprobs
+                return last_logits, (
+                    None,
+                    None,
+                    decode_top_logprobs,
+                    None,
+                    last_logprobs,
+                )
+            else:
+                # Compute the logprobs for the last token of each request.
+                last_logits = logits[last_index]
+                last_logprobs = all_logprobs[last_index]
+
+                # Compute the logprobs and normalized logprobs for the prefill tokens.
+                # Note that we pad a zero at the end of each sequence for easy computation.
+                prefill_token_logprobs = all_logprobs[
+                    torch.arange(all_logprobs.shape[0], device="cuda"),
+                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+                ]
+
+                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                    prefill_token_logprobs, input_metadata
+                )
+
+                return last_logits, (
+                    prefill_token_logprobs,
+                    prefill_top_logprobs,
+                    decode_top_logprobs,
+                    normalized_prompt_logprobs,
+                    last_logprobs,
+                )


 if __name__ == "__main__":
     all_logprobs = torch.tensor(

@@ -93,23 +144,22 @@ if __name__ == "__main__":
     )
     seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
     input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")

-    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
-    logprobs = all_logprobs[
+    token_logprobs = all_logprobs[
         torch.arange(all_logprobs.shape[0], device="cuda"),
         torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
     ]
-    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)

     len_cumsum = torch.cumsum(seq_lens, dim=0)
     start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
     end = start + seq_lens - 2
-    start.clamp_(min=0, max=logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]

     # assert logprobs == [2, _, 2, 4, _]
-    print("logprobs", logprobs)
+    print("token logprobs", token_logprobs)
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
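The new `_get_normalized_prompt_logprobs` helper scores each request's prompt with a single cumulative sum over the concatenated per-token logprobs, then divides by the number of scored positions. A CPU-only sketch of the same arithmetic, with made-up values and two prompts of lengths 3 and 4 (mirroring the module's `__main__` demo):

    import torch

    # Per-token logprobs for two concatenated prompts of lengths 3 and 4;
    # a padded zero follows each sequence, as in the module above.
    token_logprobs = torch.tensor([-0.5, -1.0, 0.0, -0.2, -0.4, -0.6, 0.0])
    extend_start_loc = torch.tensor([0, 3])
    extend_seq_lens = torch.tensor([3, 4])

    cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
    start = extend_start_loc.clone()
    end = (start + extend_seq_lens - 2).clamp(min=0, max=token_logprobs.shape[0] - 1)
    start = start.clamp(min=0, max=token_logprobs.shape[0] - 1)

    # Sum of each prompt's scored logprobs, recovered from the cumulative sum.
    sum_logp = cumsum[end] - cumsum[start] + token_logprobs[start]
    normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
    print(normalized)  # length-normalized prompt logprob per request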
python/sglang/srt/managers/io_struct.py  (view file @ 3842eba5)

@@ -19,10 +19,13 @@ class GenerateReqInput:
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # The number of top logprobs to return
+    top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in logprobs
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
+    # TODO: make all parameters a Union[List[T], T] to allow for batched requests

     def post_init(self):
         is_single = isinstance(self.text, str)

@@ -36,6 +39,8 @@ class GenerateReqInput:
                 self.return_logprob = False
             if self.logprob_start_len is None:
                 self.logprob_start_len = 0
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = 0
         else:
             num = len(self.text)

@@ -64,6 +69,11 @@ class GenerateReqInput:
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = [0] * num
+            elif not isinstance(self.top_logprobs_num, list):
+                self.top_logprobs_num = [self.top_logprobs_num] * num


 @dataclass
 class TokenizedGenerateReqInput:

@@ -76,6 +86,7 @@ class TokenizedGenerateReqInput:
     sampling_params: SamplingParams
     return_logprob: bool
     logprob_start_len: int
+    top_logprobs_num: int
     stream: bool
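The `post_init` handling added here follows the file's existing pattern: a scalar per-request option is broadcast to one value per prompt when the request is batched, and `None` falls back to a default. A standalone sketch of that rule; the helper name is hypothetical and not part of the code:

    # Hypothetical illustration of the broadcasting rule used in post_init().
    def broadcast_option(value, num, default):
        """Return a per-prompt list of the option for a batched request."""
        if value is None:
            return [default] * num
        if not isinstance(value, list):
            return [value] * num
        return value

    prompts = ["Q1", "Q2", "Q3"]
    print(broadcast_option(None, len(prompts), 0))       # [0, 0, 0]
    print(broadcast_option(3, len(prompts), 0))          # [3, 3, 3]
    print(broadcast_option([1, 2, 3], len(prompts), 0))  # [1, 2, 3]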
python/sglang/srt/managers/router/infer_batch.py  (view file @ 3842eba5)

@@ -43,6 +43,7 @@ class Req:
         self.sampling_params = None
         self.return_logprob = False
         self.logprob_start_len = 0
+        self.top_logprobs_num = 0
         self.stream = False

         self.tokenizer = None

@@ -54,9 +55,11 @@ class Req:
         self.prefix_indices = []
         self.last_node = None

-        self.logprob = None
-        self.token_logprob = None
-        self.normalized_logprob = None
+        self.prefill_token_logprobs = None
+        self.decode_token_logprobs = None
+        self.normalized_prompt_logprob = None
+        self.prefill_top_logprobs = None
+        self.decode_top_logprobs = None

         # For constrained decoding
         self.regex_fsm = None

@@ -159,6 +162,9 @@ class Batch:
     out_cache_loc: torch.Tensor = None
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None

+    # for processing logprobs
+    top_logprobs_nums: List[int] = None
     return_logprob: bool = False

     # for multimodal

@@ -266,6 +272,7 @@ class Batch:
         self.position_ids_offsets = position_ids_offsets
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]

         self.temperatures = torch.tensor(
             [r.sampling_params.temperature for r in reqs],

@@ -415,6 +422,7 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [

@@ -439,6 +447,7 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
python/sglang/srt/managers/router/model_rpc.py  (view file @ 3842eba5)

@@ -260,6 +260,7 @@ class ModelRpcServer:
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
+        req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

@@ -400,28 +401,36 @@ class ModelRpcServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )

-        logprobs = None
+        prefill_token_logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
             logits, (
-                prefill_logprobs,
-                normalized_logprobs,
+                prefill_token_logprobs,
+                prefill_top_logprobs,
+                decode_top_logprobs,
+                normalized_prompt_logprobs,
                 last_logprobs,
             ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            if prefill_logprobs is not None:
-                logprobs = prefill_logprobs.cpu().tolist()
-                normalized_logprobs = normalized_logprobs.cpu().tolist()
+            if prefill_token_logprobs is not None:
+                prefill_token_logprobs = prefill_token_logprobs.cpu().tolist()
+                normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist()

             next_token_ids, _ = batch.sample(logits)
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logits = logprobs = normalized_logprobs = last_logprobs = None
+            (
+                logits,
+                prefill_token_logprobs,
+                normalized_prompt_logprobs,
+                last_logprobs,
+            ) = (None,) * 4

         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
+        last_token_logprobs = None
         if last_logprobs is not None:
-            last_logprobs = (
+            last_token_logprobs = (
                 last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
             )

@@ -432,18 +441,26 @@ class ModelRpcServer:
             req.output_ids = [next_token_ids[i]]
             req.check_finished()

-            if logprobs is not None:
-                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
-                req.normalized_logprob = normalized_logprobs[i]
-
-                # If logprob_start_len > 0, then first logprob_start_len prompt tokens
-                # will be ignored.
-                prompt_token_len = len(req.logprob)
-                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
-                token_logprobs = req.logprob + [last_logprobs[i]]
-                req.token_logprob = list(zip(token_ids, token_logprobs))
-                if req.logprob_start_len == 0:
-                    req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
+            if prefill_token_logprobs is not None:
+                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                req.prefill_token_logprobs = list(
+                    zip(
+                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                        req.input_ids[-req.extend_input_len + 1 :],
+                    )
+                )
+                if req.logprob_start_len == 0:
+                    req.prefill_token_logprobs = [
+                        (None, req.input_ids[0])
+                    ] + req.prefill_token_logprobs
+                req.decode_token_logprobs = [
+                    (last_token_logprobs[i], next_token_ids[i])
+                ]
+                req.prefill_top_logprobs = prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+                req.decode_top_logprobs = [decode_top_logprobs[i]]
+                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

             pt += req.extend_input_len

         self.handle_finished_requests(batch)

@@ -493,27 +510,29 @@ class ModelRpcServer:
         batch.prepare_for_decode()

         # Forward
-        logits, (_, _, last_logprobs) = self.model_runner.forward(
-            batch, ForwardMode.DECODE
+        logits, (_, _, decode_top_logprobs, _, last_logprobs) = (
+            self.model_runner.forward(batch, ForwardMode.DECODE)
         )
         next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()

         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
+        new_token_logprobs = None
         if last_logprobs is not None:
-            last_logprobs = last_logprobs[
+            new_token_logprobs = last_logprobs[
                 torch.arange(len(reqs)), next_token_ids
             ].tolist()

         # Check finish condition
-        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+        for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_tok_id)
+            req.output_ids.append(next_token_id)
             req.check_finished()

-            if last_logprobs is not None:
-                req.token_logprob.append((next_tok_id, last_logprobs[i]))
+            if new_token_logprobs is not None:
+                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
+                req.decode_top_logprobs.append(decode_top_logprobs[i])

         self.handle_finished_requests(batch)

@@ -558,9 +577,19 @@ class ModelRpcServer:
                 "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
             }
             if req.return_logprob:
-                meta_info["prompt_logprob"] = req.logprob
-                meta_info["token_logprob"] = req.token_logprob
-                meta_info["normalized_prompt_logprob"] = req.normalized_logprob
+                (
+                    meta_info["prefill_token_logprobs"],
+                    meta_info["decode_token_logprobs"],
+                    meta_info["prefill_top_logprobs"],
+                    meta_info["decode_top_logprobs"],
+                    meta_info["normalized_prompt_logprob"],
+                ) = (
+                    req.prefill_token_logprobs,
+                    req.decode_token_logprobs,
+                    req.prefill_top_logprobs,
+                    req.decode_top_logprobs,
+                    req.normalized_prompt_logprob,
+                )
             output_meta_info.append(meta_info)
             output_finished.append(req.finished)
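After this refactor, a finished request carries five logprob-related entries in its `meta_info`. A sketch of the expected layout, with made-up logprob values and token ids; when `return_text_in_logprobs` is set, the server-side detokenization described further down extends each pair with the token text:

    # Illustrative shape of meta_info for one finished request (values invented).
    meta_info = {
        # (logprob, token_id) per prompt token; the first entry has no logprob.
        "prefill_token_logprobs": [(None, 450), (-2.31, 7483), (-0.05, 310)],
        # (logprob, token_id) per generated token.
        "decode_token_logprobs": [(-0.62, 3681), (-0.11, 338)],
        # Per position, the top-k (logprob, token_id) pairs, or None where
        # no top logprobs were computed.
        "prefill_top_logprobs": [None, [(-2.31, 7483), (-2.90, 1576)]],
        "decode_top_logprobs": [[(-0.62, 3681), (-1.95, 12345)]],
        # Scalar: prompt logprob normalized by the prompt length.
        "normalized_prompt_logprob": -1.18,
    }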
python/sglang/srt/managers/router/model_runner.py  (view file @ 3842eba5)

@@ -5,6 +5,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
+from typing import List

 import numpy as np
 import torch

@@ -81,6 +82,7 @@ class InputMetadata:
     out_cache_cont_end: torch.Tensor = None

     other_kv_index: torch.Tensor = None
+    top_logprobs_nums: List[int] = None
     return_logprob: bool = False

     # for flashinfer

@@ -181,6 +183,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
+        top_logprobs_nums=None,
         return_logprob=False,
     ):
         batch_size = len(req_pool_indices)

@@ -229,6 +232,7 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            top_logprobs_nums=top_logprobs_nums,
             return_logprob=return_logprob,
             other_kv_index=other_kv_index,
         )

@@ -377,6 +381,7 @@ class ModelRunner:
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
         return self.model.forward(

@@ -394,6 +399,7 @@ class ModelRunner:
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
         return self.model.forward(

@@ -413,6 +419,7 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             out_cache_cont_start=batch.out_cache_cont_start,
             out_cache_cont_end=batch.out_cache_cont_end,
+            top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
         return self.model.forward(

@@ -430,6 +437,7 @@ class ModelRunner:
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
         return self.model.forward(
python/sglang/srt/managers/tokenizer_manager.py  (view file @ 3842eba5)

@@ -173,6 +173,7 @@ class TokenizerManager:
                 sampling_params=sampling_params,
                 return_logprob=obj.return_logprob,
                 logprob_start_len=obj.logprob_start_len,
+                top_logprobs_num=obj.top_logprobs_num,
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)

@@ -215,6 +216,7 @@ class TokenizerManager:
                 sampling_params=sampling_params,
                 return_logprob=obj.return_logprob[i],
                 logprob_start_len=obj.logprob_start_len[i],
+                top_logprobs_num=obj.top_logprobs_num[i],
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)
python/sglang/srt/server.py  (view file @ 3842eba5)

@@ -123,31 +123,97 @@ async def flush_cache():
     )


-async def detokenize_logprob_tokens(token_logprobs):
-    token_ids = [tid for tid, _ in token_logprobs]
+async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
+    if not decode_to_text:
+        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
+
+    token_ids = [tid for _, tid in token_logprobs]
     token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
+    return [
+        (logprob, token_id, token_text)
+        for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+    ]
+
+
+async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text):
+    for i, t in enumerate(top_logprobs):
+        if top_logprobs[i] is not None:
+            top_logprobs[i] = await detokenize_logprob_tokens(t, decode_to_text)
+    return top_logprobs
+
+
+async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
+    """Handle the token logprobs results, convert token ids to text if needed.
+
+    Args:
+        obj (GenerateReqInput): The request object.
+        ret (Union[Dict, List[Dict]]): The response object.
+    """
+
+    # NOTE: This is because the multiple requests in one http request.
+    async def convert_style(r, return_text):
+        r["meta_info"]["prefill_token_logprobs"] = await detokenize_logprob_tokens(
+            r["meta_info"]["prefill_token_logprobs"], return_text
+        )
+        r["meta_info"]["decode_token_logprobs"] = await detokenize_logprob_tokens(
+            r["meta_info"]["decode_token_logprobs"], return_text
+        )
+        r["meta_info"]["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens(
+            r["meta_info"]["prefill_top_logprobs"], return_text
+        )
+        r["meta_info"]["decode_top_logprobs"] = await detokenize_top_logprobs_tokens(
+            r["meta_info"]["decode_top_logprobs"], return_text
+        )
+
+    if isinstance(obj.text, str):
+        if obj.return_logprob:
+            await convert_style(ret, obj.return_text_in_logprobs)
+    else:
+        for i, r in enumerate(ret):
+            if obj.return_logprob[i]:
+                await convert_style(r, obj.return_text_in_logprobs)


 async def stream_generator(obj: GenerateReqInput):
     async for out in tokenizer_manager.generate_request(obj):
-        if obj.return_logprob and obj.return_text_in_logprobs:
-            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
-                out["meta_info"]["token_logprob"]
-            )
+        await handle_token_logprobs_results(obj, out)
         yield out


-async def make_openai_style_logprobs(token_logprobs):
+async def make_openai_style_logprobs(
+    prefill_token_logprobs=None,
+    decode_token_logprobs=None,
+    prefill_top_logprobs=None,
+    decode_top_logprobs=None,
+):
     ret_logprobs = LogProbs()

-    for token_text, token_logprob in token_logprobs:
-        ret_logprobs.tokens.append(token_text)
-        ret_logprobs.token_logprobs.append(token_logprob)
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)

-        # Not supported yet.
-        ret_logprobs.top_logprobs.append({})
-        ret_logprobs.text_offset.append(-1)
+            # Not Supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if prefill_token_logprobs is not None:
+        append_token_logprobs(prefill_token_logprobs)
+    if decode_token_logprobs is not None:
+        append_token_logprobs(decode_token_logprobs)
+    if prefill_top_logprobs is not None:
+        append_top_logprobs(prefill_top_logprobs)
+    if decode_top_logprobs is not None:
+        append_top_logprobs(decode_top_logprobs)

     return ret_logprobs

@@ -165,10 +231,7 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")

     ret = await tokenizer_manager.generate_request(obj).__anext__()
-    if obj.return_logprob and obj.return_text_in_logprobs:
-        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
-            ret["meta_info"]["token_logprob"]
-        )
+    await handle_token_logprobs_results(obj, ret)
     return ret

@@ -192,7 +255,8 @@ async def v1_completions(raw_request: Request):
             "frequency_penalty": request.frequency_penalty,
             "regex": request.regex,
         },
-        return_logprob=request.logprobs is not None,
+        return_logprob=request.logprobs is not None and request.logprobs > 0,
+        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
         return_text_in_logprobs=True,
         stream=request.stream,
     )

@@ -212,15 +276,32 @@ async def v1_completions(raw_request: Request):
                 if request.echo:
                     # Prepend prompt in response text.
                     text = request.prompt + text
-                if request.logprobs is not None:
-                    if not stream_buffer and not request.echo:
-                        # Skip prompt tokens if echo is disabled.
-                        n_prev_token = prompt_tokens
-                    logprobs = await make_openai_style_logprobs(
-                        content["meta_info"]["token_logprob"][n_prev_token:]
-                    )
-                    n_prev_token = len(content["meta_info"]["token_logprob"])
+                if request.logprobs:
+                    # The first chunk and echo is enabled.
+                    if not stream_buffer and request.echo:
+                        prefill_token_logprobs = content["meta_info"][
+                            "prefill_token_logprobs"
+                        ]
+                        prefill_top_logprobs = content["meta_info"]["prefill_top_logprobs"]
+                    else:
+                        prefill_token_logprobs = None
+                        prefill_top_logprobs = None
+
+                    logprobs = await make_openai_style_logprobs(
+                        prefill_token_logprobs=prefill_token_logprobs,
+                        prefill_top_logprobs=prefill_top_logprobs,
+                        decode_token_logprobs=content["meta_info"][
+                            "decode_token_logprobs"
+                        ][n_prev_token:],
+                        decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
+                            n_prev_token:
+                        ],
+                    )
+                    n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
                 else:
                     logprobs = None

@@ -255,20 +336,26 @@ async def v1_completions(raw_request: Request):
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     text = ret["text"]
     if request.echo:
-        token_logprob_pos = 0
         text = request.prompt + text
-    else:
-        token_logprob_pos = prompt_tokens
-
-    logprobs = (
-        await make_openai_style_logprobs(
-            ret["meta_info"]["token_logprob"][token_logprob_pos:]
-        )
-        if request.logprobs is not None
-        else None
-    )
+    if request.logprobs:
+        if request.echo:
+            prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"]
+            prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"]
+        else:
+            prefill_token_logprobs = None
+            prefill_top_logprobs = None
+        logprobs = await make_openai_style_logprobs(
+            prefill_token_logprobs=prefill_token_logprobs,
+            prefill_top_logprobs=prefill_top_logprobs,
+            decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"],
+            decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"],
+        )
+    else:
+        logprobs = None

     choice_data = CompletionResponseChoice(
         index=0,
         text=text,
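On the OpenAI-compatible endpoint, `logprobs` is now treated as the number of top alternatives per position and `echo` decides whether prompt (prefill) logprobs are included. A brief sketch with the `openai` Python client, mirroring the tests further down; the base URL, API key, and model name are assumptions:

    import openai

    # Assumed local sglang server exposing the OpenAI-compatible API.
    client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        max_tokens=16,
        temperature=0,
        echo=True,    # include prompt tokens (prefill logprobs) in the response
        logprobs=3,   # return the top-3 alternatives per position
    )
    print(response.choices[0].logprobs.token_logprobs)
    print(response.choices[0].logprobs.top_logprobs)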
test/srt/test_httpserver_decode.py  (view file @ 3842eba5)

@@ -9,11 +9,12 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
 """

 import argparse
+import json

 import requests


-def test_decode(url, return_logprob):
+def test_decode(url, return_logprob, top_logprobs_num, return_text):
     response = requests.post(
         url + "/generate",
         json={

@@ -23,10 +24,13 @@ def test_decode(url, return_logprob):
                 "max_new_tokens": 32,
             },
             "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+            "return_text_in_logprobs": return_text,
+            "logprob_start_len": 0,
         },
     )
-    print(response.json())
+    print(json.dumps(response.json()))
     print("=" * 100)


 if __name__ == "__main__":

@@ -37,5 +41,8 @@ if __name__ == "__main__":
     url = f"{args.host}:{args.port}"

-    test_decode(url, False)
-    test_decode(url, True)
+    test_decode(url, False, 0, False)
+    test_decode(url, True, 0, False)
+    test_decode(url, True, 0, True)
+    test_decode(url, True, 3, False)
+    test_decode(url, True, 3, True)
test/srt/test_httpserver_decode_stream.py  (view file @ 3842eba5)

@@ -13,7 +13,7 @@ import json
 import requests


-def test_decode_stream(url, return_logprob):
+def test_decode_stream(url, return_logprob, top_logprobs_num):
     response = requests.post(
         url + "/generate",
         json={

@@ -24,6 +24,8 @@ def test_decode_stream(url, return_logprob):
             },
             "stream": True,
             "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+            "return_text_in_logprobs": True,
         },
         stream=True,
     )

@@ -37,19 +39,20 @@ def test_decode_stream(url, return_logprob):
         data = json.loads(chunk[5:].strip("\n"))

         if return_logprob:
-            assert data["meta_info"]["prompt_logprob"] is not None
-            assert data["meta_info"]["token_logprob"] is not None
+            assert data["meta_info"]["prefill_token_logprobs"] is not None
+            assert data["meta_info"]["decode_token_logprobs"] is not None
             assert data["meta_info"]["normalized_prompt_logprob"] is not None

-            if prev == 0:
-                # Skip prompt logprobs
-                prev = data["meta_info"]["prompt_tokens"]
-            for token_txt, logprob in data["meta_info"]["token_logprob"][prev:]:
-                print(f"{token_txt}\t{logprob}", flush=True)
-            prev = len(data["meta_info"]["token_logprob"])
+            for logprob, token_id, token_text in data["meta_info"][
+                "decode_token_logprobs"
+            ][prev:]:
+                print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
+            prev = len(data["meta_info"]["decode_token_logprobs"])
         else:
             output = data["text"].strip()
             print(output[prev:], end="", flush=True)
             prev = len(output)

     print("")
     print("=" * 100)


 if __name__ == "__main__":

@@ -60,5 +63,6 @@ if __name__ == "__main__":
     url = f"{args.host}:{args.port}"

-    test_decode_stream(url, False)
-    test_decode_stream(url, True)
+    test_decode_stream(url, False, 0)
+    test_decode_stream(url, True, 0)
+    test_decode_stream(url, True, 3)
test/srt/test_openai_server.py  (view file @ 3842eba5)

@@ -34,6 +34,7 @@ def test_completion(args, echo, logprobs):
     if echo:
         assert text.startswith("The capital of France is")
     if logprobs:
+        print(response.choices[0].logprobs.top_logprobs)
         assert response.choices[0].logprobs
         if echo:
             assert response.choices[0].logprobs.token_logprobs[0] == None

@@ -44,6 +45,7 @@ def test_completion(args, echo, logprobs):
     assert response.usage.prompt_tokens > 0
     assert response.usage.completion_tokens > 0
     assert response.usage.total_tokens > 0
+    print("=" * 100)


 def test_completion_stream(args, echo, logprobs):

@@ -68,13 +70,14 @@ def test_completion_stream(args, echo, logprobs):
                 f"{r.choices[0].text:12s}\t"
                 f"{r.choices[0].logprobs.token_logprobs}",
                 flush=True,
             )
+            print(r.choices[0].logprobs.top_logprobs)
         else:
             print(r.choices[0].text, end="", flush=True)
         assert r.id
         assert r.usage.prompt_tokens > 0
         assert r.usage.completion_tokens > 0
         assert r.usage.total_tokens > 0
     print()
     print("=" * 100)


 def test_chat_completion(args):

@@ -94,6 +97,7 @@ def test_chat_completion(args):
     assert response.usage.prompt_tokens > 0
     assert response.usage.completion_tokens > 0
     assert response.usage.total_tokens > 0
+    print("=" * 100)


 def test_chat_completion_image(args):

@@ -124,6 +128,7 @@ def test_chat_completion_image(args):
     assert response.usage.prompt_tokens > 0
     assert response.usage.completion_tokens > 0
     assert response.usage.total_tokens > 0
+    print("=" * 100)


 def test_chat_completion_stream(args):

@@ -149,7 +154,7 @@ def test_chat_completion_stream(args):
         if not data.content:
             continue
         print(data.content, end="", flush=True)
     print()
+    print("=" * 100)


 def test_regex(args):

@@ -174,6 +179,7 @@ def test_regex(args):
     )
     text = response.choices[0].message.content
     print(json.loads(text))
+    print("=" * 100)


 if __name__ == "__main__":

@@ -188,10 +194,14 @@ if __name__ == "__main__":
     test_completion(args, echo=True, logprobs=False)
     test_completion(args, echo=False, logprobs=True)
     test_completion(args, echo=True, logprobs=True)
+    test_completion(args, echo=False, logprobs=3)
+    test_completion(args, echo=True, logprobs=3)
     test_completion_stream(args, echo=False, logprobs=False)
     test_completion_stream(args, echo=True, logprobs=False)
     test_completion_stream(args, echo=False, logprobs=True)
     test_completion_stream(args, echo=True, logprobs=True)
+    test_completion_stream(args, echo=False, logprobs=3)
+    test_completion_stream(args, echo=True, logprobs=3)
     test_chat_completion(args)
     test_chat_completion_stream(args)
     test_regex(args)