Commit a7334aee, authored Feb 06, 2024 by Cody Yu, committed by GitHub on Feb 06, 2024

Support decode token logprobs (#130)
parent ee1df26a

Showing 10 changed files with 228 additions and 91 deletions
python/sglang/srt/layers/logits_processor.py       +42 -34
python/sglang/srt/managers/io_struct.py            +4  -0
python/sglang/srt/managers/router/infer_batch.py   +1  -0
python/sglang/srt/managers/router/model_rpc.py     +34 -14
python/sglang/srt/managers/router/model_runner.py  +5  -6
python/sglang/srt/managers/tokenizer_manager.py    +5  -0
python/sglang/srt/server.py                        +59 -8
test/srt/test_httpserver_decode.py                 +14 -11
test/srt/test_httpserver_decode_stream.py          +28 -13
test/srt/test_openai_server.py                     +36 -5
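Taken together, the change threads per-token log probabilities from the logits processor through the router and out to both the native /generate endpoint and the OpenAI-compatible completions endpoint. As a rough orientation, here is a client-side sketch modeled on test/srt/test_httpserver_decode.py from this commit; the server address and port are assumptions about a locally running instance.

import requests

# Sketch: ask a locally running sglang server (assumed at port 30000) for the
# generated text plus per-token logprobs.
response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
        "return_logprob": True,
        "logprob_start_len": 0,
    },
)
meta = response.json()["meta_info"]
# With return_logprob set, meta_info carries prompt_logprob, token_logprob,
# and normalized_prompt_logprob in addition to the usual token counts.
print(meta["token_logprob"])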
python/sglang/srt/layers/logits_processor.py

@@ -14,10 +14,11 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def forward(self, input_ids, hidden_states, weight, input_metadata):
-        if not input_metadata.return_logprob:
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_hidden = hidden_states
-            else:
-                last_index = (
-                    torch.cumsum(
-                        input_metadata.seq_lens - input_metadata.prefix_lens,
+        last_index = None
+        # Compute the last index (the first decode token) of each requeast
+        # if we are in prefill or extend mode.
+        if input_metadata.forward_mode != ForwardMode.DECODE:
+            last_index = (
+                torch.cumsum(
+                    input_metadata.seq_lens - input_metadata.prefix_lens,
@@ -26,6 +27,12 @@ class LogitsProcessor(nn.Module):
                 )
                 - 1
             )
 
+        if not input_metadata.return_logprob:
+            # When logprob is not requested, only compute the last logits.
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_hidden = hidden_states
+            else:
                 last_hidden = hidden_states[last_index]
                 hidden_states = None
@@ -33,41 +40,42 @@ class LogitsProcessor(nn.Module):
             if self.tp_size > 1:
                 last_logits = tensor_model_parallel_all_gather(last_logits)
             last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None)
+            return last_logits, (None, None, None)
         else:
-            assert input_metadata.forward_mode != ForwardMode.DECODE
-            last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
-                )
-                - 1
-            )
+            # When logprob is requested, compute the logits for all tokens.
             logits = torch.matmul(hidden_states, weight.T)
             if self.tp_size > 1:
                 logits = tensor_model_parallel_all_gather(logits)
             logits = logits[:, : self.config.vocab_size]
             all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
 
-            logprobs = all_logprobs[
-                torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-            ]
-            logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-
-            start = input_metadata.extend_start_loc.clone()
-            end = start + input_metadata.extend_seq_lens - 2
-            start.clamp_(min=0, max=logprobs.shape[0] - 1)
-            end.clamp_(min=0, max=logprobs.shape[0] - 1)
-            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
-            normalized_logprobs = sum_logp / (
-                (input_metadata.extend_seq_lens - 1).clamp(min=1)
-            )
-
-            last_logits = logits[last_index]
-            return last_logits, (logprobs, normalized_logprobs)
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_logits = logits
+                last_logprobs = all_logprobs
+                prefill_logprobs = normalized_logprobs = None
+            else:
+                # Compute the logprobs for the last token of each request.
+                last_logits = logits[last_index]
+                last_logprobs = all_logprobs[last_index]
+
+                # Compute the logprobs and normalized logprobs for the prefill tokens.
+                # Note that we pad a zero at the end of each sequence for easy computation.
+                prefill_logprobs = all_logprobs[
+                    torch.arange(all_logprobs.shape[0], device="cuda"),
+                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+                ]
+                logprobs_cumsum = torch.cumsum(
+                    prefill_logprobs, dim=0, dtype=torch.float32
+                )
+
+                start = input_metadata.extend_start_loc.clone()
+                end = start + input_metadata.extend_seq_lens - 2
+                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                sum_logp = (
+                    logprobs_cumsum[end]
+                    - logprobs_cumsum[start]
+                    + prefill_logprobs[start]
+                )
+                normalized_logprobs = sum_logp / (
+                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
+                )
+
+            return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
 
 
 if __name__ == "__main__":
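For context, the normalized prompt logprob above is obtained with a cumulative-sum trick: the per-token prefill logprobs of the whole batch are cumsum'ed once, and each request's sum is then recovered from the difference of two endpoints. A minimal standalone sketch of that idea follows (plain PyTorch with made-up example values; it is not the module itself).

import torch

# Sketch of the cumsum trick used in LogitsProcessor.forward above.
# prefill_logprobs holds the logprob of each prompt token for the whole batch;
# here: two requests with 3 prompt tokens each (illustrative values).
prefill_logprobs = torch.tensor([-0.5, -1.0, -0.2, -0.3, -0.7, -0.1])
extend_start_loc = torch.tensor([0, 3])  # first token index of each request
extend_seq_lens = torch.tensor([3, 3])   # prompt tokens per request

logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)

# Sum of prefill_logprobs[start..end] per request, recovered from two cumsum
# endpoints, then length-normalized.
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
normalized_logprobs = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized_logprobs)  # tensor([-0.7500, -0.5000])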
python/sglang/srt/managers/io_struct.py

@@ -99,3 +99,7 @@ class BatchStrOut:
 @dataclass
 class FlushCacheReq:
     pass
+
+
+@dataclass
+class DetokenizeReqInput:
+    input_ids: List[int]
python/sglang/srt/managers/router/infer_batch.py

@@ -48,6 +48,7 @@ class Req:
         self.last_node = None
 
         self.logprob = None
+        self.token_logprob = None
         self.normalized_logprob = None
 
         # For constrained decoding
python/sglang/srt/managers/router/model_rpc.py

@@ -388,24 +388,28 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-        logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_logprob
-            )
+            logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
+                self.model_runner.forward(
+                    batch, ForwardMode.EXTEND, batch.return_logprob
+                )
+            )
             # print("extend logits", logits)
-            if logprobs is not None:
-                logprobs = logprobs.cpu().tolist()
+            if prefill_logprobs is not None:
+                logprobs = prefill_logprobs.cpu().tolist()
                 normalized_logprobs = normalized_logprobs.cpu().tolist()
 
-            next_token_ids, next_token_probs = batch.sample(logits)
+            next_token_ids, _ = batch.sample(logits)
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logprobs = normalized_logprobs = None
+            logits = logprobs = normalized_logprobs = last_logprobs = None
 
-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
+        reqs = batch.reqs
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[
+                torch.arange(len(reqs)), next_token_ids
+            ].cpu().tolist()
+
+        # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
             req.output_ids = [next_token_ids[i]]
@@ -414,6 +418,10 @@ class ModelRpcServer(rpyc.Service):
             if logprobs is not None:
                 req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
                 req.normalized_logprob = normalized_logprobs[i]
+
+                token_ids = req.input_ids + [next_token_ids[i]]
+                token_logprobs = [None] + req.logprob + [last_logprobs[i]]
+                req.token_logprob = list(zip(token_ids, token_logprobs))
                 pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
@@ -463,15 +471,26 @@ class ModelRpcServer(rpyc.Service):
         batch.prepare_for_decode()
 
         # Forward
-        logits = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, next_token_probs = batch.sample(logits)
+        logits, (_, _, last_logprobs) = self.model_runner.forward(
+            batch,
+            ForwardMode.DECODE,
+            batch.return_logprob,
+        )
+        next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()
 
-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
-        for i in range(len(reqs)):
-            reqs[i].output_ids.append(next_token_ids[i])
-            reqs[i].check_finished()
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[
+                torch.arange(len(reqs)), next_token_ids
+            ].tolist()
+
+        # Check finish condition
+        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.output_ids.append(next_tok_id)
+            req.check_finished()
+
+            if last_logprobs is not None:
+                req.token_logprob.append((next_tok_id, last_logprobs[i]))
 
         self.handle_finished_requests(batch)
@@ -513,6 +532,7 @@ class ModelRpcServer(rpyc.Service):
                 }
                 if req.return_logprob:
                     meta_info["prompt_logprob"] = req.logprob
+                    meta_info["token_logprob"] = req.token_logprob
                     meta_info["normalized_prompt_logprob"] = req.normalized_logprob
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)
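The net effect on each request is a token_logprob list of (token_id, logprob) pairs: None for the first prompt token, the prefill logprobs for the remaining prompt tokens, the sampled token's logprob taken from last_logprobs, and one appended pair per decode step. A small sketch of the resulting shape, with invented values purely for illustration:

# Sketch of the per-request structure assembled in model_rpc.py above.
token_logprob = [
    (128, None),    # first prompt token: no conditioning context, so no logprob
    (1543, -0.42),  # second prompt token (from prefill logprobs)
    (291, -1.10),   # third prompt token (from prefill logprobs)
    (7, -0.05),     # first sampled token (from last_logprobs)
]
# Each subsequent decode step appends (next_token_id, logprob_of_that_token):
token_logprob.append((1024, -0.31))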
python/sglang/srt/managers/router/model_runner.py

@@ -397,6 +397,7 @@ class ModelRunner:
         out_cache_loc,
         out_cache_cont_start,
         out_cache_cont_end,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,
@@ -409,10 +410,9 @@ class ModelRunner:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            return_logprob=return_logprob,
         )
-        return self.model.forward(
-            input_ids, input_metadata.positions, input_metadata
-        )[0]
+        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
     def forward_extend_multi_modal(
@@ -460,8 +460,8 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend_multi_modal(**kwargs)
         else:
             kwargs = {
@@ -471,6 +471,7 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }
 
         if forward_mode == ForwardMode.DECODE:
@@ -478,10 +479,8 @@ class ModelRunner:
             kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
             return self.forward_decode(**kwargs)
         elif forward_mode == ForwardMode.EXTEND:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend(**kwargs)
         elif forward_mode == ForwardMode.PREFILL:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_prefill(**kwargs)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
python/sglang/srt/managers/tokenizer_manager.py

@@ -18,6 +18,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    DetokenizeReqInput,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
@@ -234,6 +235,10 @@ class TokenizerManager:
                 yield output_list
 
+    async def detokenize(self, obj: DetokenizeReqInput):
+        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
+        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
+
     async def flush_cache(self):
         flush_cache_req = FlushCacheReq()
         self.send_to_router.send_pyobj(flush_cache_req)
python/sglang/srt/server.py

@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    LogProbs,
     UsageInfo,
 )
 from sglang.srt.managers.router.manager import start_router_process
@@ -97,6 +98,23 @@ async def stream_generator(obj):
         yield out
 
 
+async def make_openai_style_logprobs(token_logprobs):
+    ret_logprobs = LogProbs()
+
+    # Detokenize
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+
+    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+        ret_logprobs.tokens.append(token_text)
+        ret_logprobs.token_logprobs.append(token_logprob)
+
+        # Not supported yet.
+        ret_logprobs.top_logprobs.append({})
+        ret_logprobs.text_offset.append(-1)
+    return ret_logprobs
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
             "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
         },
+        return_logprob=request.logprobs is not None,
         stream=request.stream,
     )
     adapted_request.post_init()
@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):
         async def gnerate_stream_resp():
             stream_buffer = ""
+            n_prev_token = 0
             async for content in stream_generator(adapted_request):
                 text = content["text"]
+                prompt_tokens = content["meta_info"]["prompt_tokens"]
+                completion_tokens = content["meta_info"]["completion_tokens"]
+
+                if not stream_buffer:  # The first chunk
+                    if request.echo:
+                        # Prepend prompt in response text.
+                        text = request.prompt + text
+                    else:
+                        # Skip prompt tokens if echo is disabled.
+                        n_prev_token = prompt_tokens
+
+                if request.logprobs is not None:
+                    logprobs = await make_openai_style_logprobs(
+                        content["meta_info"]["token_logprob"][n_prev_token:]
+                    )
+                    n_prev_token = len(content["meta_info"]["token_logprob"])
+                else:
+                    logprobs = None
+
                 delta = text[len(stream_buffer) :]
-                stream_buffer = content["text"]
+                stream_buffer = text
                 choice_data = CompletionResponseStreamChoice(
                     index=0,
                     text=delta,
-                    logprobs=None,
+                    logprobs=logprobs,
                     finish_reason=None,
                 )
                 chunk = CompletionStreamResponse(
@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
     # Non-streaming response.
     ret = await generate_request(adapted_request)
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+
+    token_logprob_pos = prompt_tokens
+    if request.echo:
+        token_logprob_pos = 0
+        text = request.prompt + text
+    else:
+        token_logprob_pos = prompt_tokens
+
+    logprobs = (
+        await make_openai_style_logprobs(
+            ret["meta_info"]["token_logprob"][token_logprob_pos:]
+        )
+        if request.logprobs is not None
+        else None
+    )
+
     choice_data = CompletionResponseChoice(
         index=0,
-        text=ret["text"],
-        logprobs=None,
+        text=text,
+        logprobs=logprobs,
         finish_reason=None,  # TODO(comaniac): Add finish reason.
     )
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
         model=request.model,
@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
         if not isinstance(m.content, str):
             raise HTTPException(
                 status_code=503,
-                detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
+                detail="Structured content requests not supported with "
+                "HuggingFace Chat Templates. "
+                "Make sure the server specifies a sglang chat template.",
             )
     prompt = tokenizer_manager.tokenizer.apply_chat_template(
         request.messages, tokenize=False, add_generation_prompt=True
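With the server changes above, logprobs can also be requested through the OpenAI-compatible completions endpoint. Below is a minimal client sketch mirroring test/srt/test_openai_server.py from this commit; the base_url and the logprobs value are assumptions about a locally running server.

import openai

# Sketch: request logprobs (and echo the prompt) from a locally running sglang
# server through its OpenAI-compatible completions endpoint.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
    echo=True,
    logprobs=1,
)
lp = response.choices[0].logprobs
# tokens and token_logprobs are populated; top_logprobs and text_offset are
# placeholders ({} and -1) in this commit, as noted in make_openai_style_logprobs.
print(list(zip(lp.tokens, lp.token_logprobs)))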
test/srt/test_httpserver_decode.py

@@ -9,18 +9,10 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
 """
 
 import argparse
-import time
 
 import requests
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=30000)
-    args = parser.parse_args()
-
-    url = f"{args.host}:{args.port}"
 
+def test_decode(url, return_logprob):
     response = requests.post(
         url + "/generate",
         json={
@@ -29,8 +21,19 @@ if __name__ == "__main__":
                 "temperature": 0,
                 "max_new_tokens": 32,
             },
-            # "return_logprob": True,
-            # "logprob_start_len": 0,
+            "return_logprob": return_logprob,
+            "logprob_start_len": 0,
         },
     )
     print(response.json())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+    url = f"{args.host}:{args.port}"
+
+    test_decode(url, False)
+    test_decode(url, True)
test/srt/test_httpserver_decode_stream.py

@@ -9,27 +9,20 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
 import argparse
 import json
-import time
 
 import requests
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=30000)
-    args = parser.parse_args()
-
-    url = f"{args.host}:{args.port}"
 
+def test_decode_stream(url, return_logprob):
     response = requests.post(
         url + "/generate",
         json={
             "text": "The capital of France is",
             "sampling_params": {
                 "temperature": 0,
-                "max_new_tokens": 512,
+                "max_new_tokens": 128,
             },
             "stream": True,
+            "return_logprob": return_logprob,
         },
         stream=True,
     )
@@ -41,7 +34,29 @@ if __name__ == "__main__":
         if chunk == "data: [DONE]":
             break
         data = json.loads(chunk[5:].strip("\n"))
-        output = data["text"].strip()
-        print(output[prev:], end="", flush=True)
-        prev = len(output)
+
+        if return_logprob:
+            assert data["meta_info"]["prompt_logprob"] is not None
+            assert data["meta_info"]["token_logprob"] is not None
+            assert data["meta_info"]["normalized_prompt_logprob"] is not None
+
+            if prev == 0:
+                # Skip prompt logprobs
+                prev = data["meta_info"]["prompt_tokens"]
+            for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
+                print(f"{token_txt}\t{logprob}", flush=True)
+            prev = len(data["meta_info"]["token_logprob"])
+        else:
+            output = data["text"].strip()
+            print(output[prev:], end="", flush=True)
+            prev = len(output)
 
     print("")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+    url = f"{args.host}:{args.port}"
+
+    test_decode_stream(url, False)
+    test_decode_stream(url, True)
test/srt/test_openai_server.py

@@ -18,15 +18,26 @@ import argparse
 import openai
 
 
-def test_completion(args):
+def test_completion(args, echo, logprobs):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)
     response = client.completions.create(
         model="default",
         prompt="The capital of France is",
         temperature=0,
         max_tokens=32,
+        echo=echo,
+        logprobs=logprobs,
     )
+    text = response.choices[0].text
     print(response.choices[0].text)
+
+    if echo:
+        assert text.startswith("The capital of France is")
+    if logprobs:
+        assert response.choices[0].logprobs
+        if echo:
+            assert response.choices[0].logprobs.token_logprobs[0] == None
+        else:
+            assert response.choices[0].logprobs.token_logprobs[0] != None
     assert response.id
     assert response.created
     assert response.usage.prompt_tokens > 0
@@ -34,7 +45,7 @@ def test_completion(args):
     assert response.usage.total_tokens > 0
 
 
-def test_completion_stream(args):
+def test_completion_stream(args, echo, logprobs):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)
     response = client.completions.create(
         model="default",
@@ -42,8 +53,22 @@ def test_completion_stream(args):
         temperature=0,
         max_tokens=32,
         stream=True,
+        echo=echo,
+        logprobs=logprobs,
     )
+
+    first = True
     for r in response:
-        print(r.choices[0].text, end="", flush=True)
+        if first:
+            if echo:
+                assert r.choices[0].text.startswith("The capital of France is")
+            first = False
+
+        if logprobs:
+            print(
+                f"{r.choices[0].text:12s}\t"
+                f"{r.choices[0].logprobs.token_logprobs}",
+                flush=True,
+            )
+        else:
+            print(r.choices[0].text, end="", flush=True)
         assert r.id
         assert r.usage.prompt_tokens > 0
@@ -135,8 +160,14 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    test_completion(args)
-    test_completion_stream(args)
+    test_completion(args, echo=False, logprobs=False)
+    test_completion(args, echo=True, logprobs=False)
+    test_completion(args, echo=False, logprobs=True)
+    test_completion(args, echo=True, logprobs=True)
+    test_completion_stream(args, echo=False, logprobs=False)
+    test_completion_stream(args, echo=True, logprobs=False)
+    test_completion_stream(args, echo=False, logprobs=True)
+    test_completion_stream(args, echo=True, logprobs=True)
     test_chat_completion(args)
     test_chat_completion_stream(args)
 
     if args.test_image: