change / sglang / Commits / a7334aee

Commit a7334aee (unverified), authored Feb 06, 2024 by Cody Yu, committed via GitHub on Feb 06, 2024.
Parent: ee1df26a

    Support decode token logprobs (#130)

Showing 10 changed files with 228 additions and 91 deletions (+228 −91).
Changed files:

    python/sglang/srt/layers/logits_processor.py        +42  −34
    python/sglang/srt/managers/io_struct.py               +4   −0
    python/sglang/srt/managers/router/infer_batch.py      +1   −0
    python/sglang/srt/managers/router/model_rpc.py       +34  −14
    python/sglang/srt/managers/router/model_runner.py     +5   −6
    python/sglang/srt/managers/tokenizer_manager.py       +5   −0
    python/sglang/srt/server.py                          +59   −8
    test/srt/test_httpserver_decode.py                   +14  −11
    test/srt/test_httpserver_decode_stream.py            +28  −13
    test/srt/test_openai_server.py                       +36   −5
python/sglang/srt/layers/logits_processor.py  (+42 −34)

@@ -14,18 +14,25 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def forward(self, input_ids, hidden_states, weight, input_metadata):
+        last_index = None
+        # Compute the last index (the first decode token) of each requeast
+        # if we are in prefill or extend mode.
+        if input_metadata.forward_mode != ForwardMode.DECODE:
+            last_index = (
+                torch.cumsum(
+                    input_metadata.seq_lens - input_metadata.prefix_lens,
+                    dim=0,
+                    dtype=torch.long,
+                )
+                - 1
+            )
+
         if not input_metadata.return_logprob:
+            # When logprob is not requested, only compute the last logits.
             if input_metadata.forward_mode == ForwardMode.DECODE:
                 last_hidden = hidden_states
             else:
-                last_index = (
-                    torch.cumsum(
-                        input_metadata.seq_lens - input_metadata.prefix_lens,
-                        dim=0,
-                        dtype=torch.long,
-                    )
-                    - 1
-                )
                 last_hidden = hidden_states[last_index]
             hidden_states = None

@@ -33,41 +40,42 @@ class LogitsProcessor(nn.Module):
             if self.tp_size > 1:
                 last_logits = tensor_model_parallel_all_gather(last_logits)
             last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None)
+            return last_logits, (None, None, None)
         else:
-            assert input_metadata.forward_mode != ForwardMode.DECODE
-            last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
-                )
-                - 1
-            )
+            # When logprob is requested, compute the logits for all tokens.
             logits = torch.matmul(hidden_states, weight.T)
             if self.tp_size > 1:
                 logits = tensor_model_parallel_all_gather(logits)
             logits = logits[:, : self.config.vocab_size]
             all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)

-            logprobs = all_logprobs[
-                torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-            ]
-            logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-
-            start = input_metadata.extend_start_loc.clone()
-            end = start + input_metadata.extend_seq_lens - 2
-            start.clamp_(min=0, max=logprobs.shape[0] - 1)
-            end.clamp_(min=0, max=logprobs.shape[0] - 1)
-            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
-            normalized_logprobs = sum_logp / (
-                (input_metadata.extend_seq_lens - 1).clamp(min=1)
-            )
-
-            last_logits = logits[last_index]
-            return last_logits, (logprobs, normalized_logprobs)
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_logits = logits
+                last_logprobs = all_logprobs
+                prefill_logprobs = normalized_logprobs = None
+            else:
+                # Compute the logprobs for the last token of each request.
+                last_logits = logits[last_index]
+                last_logprobs = all_logprobs[last_index]
+
+                # Compute the logprobs and normalized logprobs for the prefill tokens.
+                # Note that we pad a zero at the end of each sequence for easy computation.
+                prefill_logprobs = all_logprobs[
+                    torch.arange(all_logprobs.shape[0], device="cuda"),
+                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+                ]
+                logprobs_cumsum = torch.cumsum(
+                    prefill_logprobs, dim=0, dtype=torch.float32
+                )
+
+                start = input_metadata.extend_start_loc.clone()
+                end = start + input_metadata.extend_seq_lens - 2
+                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                sum_logp = (
+                    logprobs_cumsum[end]
+                    - logprobs_cumsum[start]
+                    + prefill_logprobs[start]
+                )
+                normalized_logprobs = sum_logp / (
+                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
+                )
+
+            return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)


 if __name__ == "__main__":
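For readers following the hunk above: a minimal standalone sketch (not part of the commit; the tensor values are invented) of how the cumsum/clamp arithmetic turns the packed per-token prefill logprobs of a batch into one normalized prompt logprob per request.

import torch

# Two requests packed back to back, with 4 and 3 extend tokens respectively.
extend_seq_lens = torch.tensor([4, 3])
extend_start_loc = torch.tensor([0, 4])

# prefill_logprobs[i] holds the logprob of the token following position i;
# the last slot of each request is the zero padding mentioned in the diff.
prefill_logprobs = torch.tensor([-1.0, -2.0, -0.5, 0.0, -0.7, -0.3, 0.0])

logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)

start = extend_start_loc.clone()
end = start + extend_seq_lens - 2  # drop the padded final position
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)

# Per-request sum over positions [start, end], computed for the whole batch
# at once from the cumulative sum.
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
normalized_logprobs = sum_logp / (extend_seq_lens - 1).clamp(min=1)

print(normalized_logprobs)  # tensor([-1.1667, -0.5000])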
python/sglang/srt/managers/io_struct.py  (+4 −0)

@@ -99,3 +99,7 @@ class BatchStrOut:
 @dataclass
 class FlushCacheReq:
     pass
+
+@dataclass
+class DetokenizeReqInput:
+    input_ids: List[int]
python/sglang/srt/managers/router/infer_batch.py  (+1 −0)

@@ -48,6 +48,7 @@ class Req:
         self.last_node = None

         self.logprob = None
+        self.token_logprob = None
         self.normalized_logprob = None

         # For constrained decoding
python/sglang/srt/managers/router/model_rpc.py  (+34 −14)

@@ -388,24 +388,28 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )

+        logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_logprob
+            logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
+                self.model_runner.forward(
+                    batch, ForwardMode.EXTEND, batch.return_logprob
+                )
             )
-            # print("extend logits", logits)
-            if logprobs is not None:
-                logprobs = logprobs.cpu().tolist()
+            if prefill_logprobs is not None:
+                logprobs = prefill_logprobs.cpu().tolist()
                 normalized_logprobs = normalized_logprobs.cpu().tolist()

-            next_token_ids, next_token_probs = batch.sample(logits)
+            next_token_ids, _ = batch.sample(logits)
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logprobs = normalized_logprobs = None
+            logits = logprobs = normalized_logprobs = last_logprobs = None

-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[
+                torch.arange(len(reqs)), next_token_ids
+            ].cpu().tolist()
+
+        # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
             req.output_ids = [next_token_ids[i]]

@@ -414,6 +418,10 @@ class ModelRpcServer(rpyc.Service):
             if logprobs is not None:
                 req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
                 req.normalized_logprob = normalized_logprobs[i]
+
+                token_ids = req.input_ids + [next_token_ids[i]]
+                token_logprobs = [None] + req.logprob + [last_logprobs[i]]
+                req.token_logprob = list(zip(token_ids, token_logprobs))
                 pt += req.extend_input_len

         self.handle_finished_requests(batch)

@@ -463,15 +471,26 @@ class ModelRpcServer(rpyc.Service):
         batch.prepare_for_decode()

         # Forward
-        logits = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, next_token_probs = batch.sample(logits)
+        logits, (_, _, last_logprobs) = self.model_runner.forward(
+            batch,
+            ForwardMode.DECODE,
+            batch.return_logprob,
+        )
+        next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()

-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
-        for i in range(len(reqs)):
-            reqs[i].output_ids.append(next_token_ids[i])
-            reqs[i].check_finished()
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[
+                torch.arange(len(reqs)), next_token_ids
+            ].tolist()
+
+        # Check finish condition
+        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.output_ids.append(next_tok_id)
+            req.check_finished()
+
+            if last_logprobs is not None:
+                req.token_logprob.append((next_tok_id, last_logprobs[i]))

         self.handle_finished_requests(batch)

@@ -513,6 +532,7 @@ class ModelRpcServer(rpyc.Service):
                 }
                 if req.return_logprob:
                     meta_info["prompt_logprob"] = req.logprob
+                    meta_info["token_logprob"] = req.token_logprob
                     meta_info["normalized_prompt_logprob"] = req.normalized_logprob
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)
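A small illustration (not from the commit; shapes and values are made up) of the batched gather used above: one selected next-token logprob per request is picked with fancy indexing, so a single .cpu().tolist() call transfers len(reqs) floats instead of a whole [num_reqs, vocab] matrix.

import torch

num_reqs, vocab_size = 3, 8
# One row of next-token logprobs per running request (on the GPU in the server).
last_logprobs = torch.log_softmax(torch.randn(num_reqs, vocab_size), dim=-1)
next_token_ids = [2, 0, 7]  # token id sampled for each request

selected = last_logprobs[torch.arange(num_reqs), next_token_ids]  # shape [num_reqs]
print(selected.cpu().tolist())  # one host transfer for the whole batch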
python/sglang/srt/managers/router/model_runner.py  (+5 −6)

@@ -397,6 +397,7 @@ class ModelRunner:
         out_cache_loc,
         out_cache_cont_start,
         out_cache_cont_end,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -409,10 +410,9 @@ class ModelRunner:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            return_logprob=return_logprob,
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
-            0
-        ]
+        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

     @torch.inference_mode()
     def forward_extend_multi_modal(

@@ -460,8 +460,8 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend_multi_modal(**kwargs)
         else:
             kwargs = {

@@ -471,6 +471,7 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }

         if forward_mode == ForwardMode.DECODE:

@@ -478,10 +479,8 @@ class ModelRunner:
             kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
             return self.forward_decode(**kwargs)
         elif forward_mode == ForwardMode.EXTEND:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend(**kwargs)
         elif forward_mode == ForwardMode.PREFILL:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_prefill(**kwargs)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
python/sglang/srt/managers/tokenizer_manager.py  (+5 −0)

@@ -18,6 +18,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    DetokenizeReqInput,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,

@@ -234,6 +235,10 @@ class TokenizerManager:
                 yield output_list

+    async def detokenize(self, obj: DetokenizeReqInput):
+        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
+        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
+
     async def flush_cache(self):
         flush_cache_req = FlushCacheReq()
         self.send_to_router.send_pyobj(flush_cache_req)
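A tiny sketch (not from the commit; example values invented) of why the new detokenize() helper checks isinstance(t, bytes): convert_ids_to_tokens can hand back bytes objects for some tokenizers and plain str for others, and the helper normalizes both to str.

token_texts = [b"\xe2\x96\x81Paris", "."]  # invented return values
print([t.decode() if isinstance(t, bytes) else t for t in token_texts])
# ['▁Paris', '.']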
python/sglang/srt/server.py  (+59 −8)

@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,

@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    LogProbs,
     UsageInfo,
 )
 from sglang.srt.managers.router.manager import start_router_process

@@ -97,6 +98,23 @@ async def stream_generator(obj):
         yield out


+async def make_openai_style_logprobs(token_logprobs):
+    ret_logprobs = LogProbs()
+
+    # Detokenize
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+
+    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+        ret_logprobs.tokens.append(token_text)
+        ret_logprobs.token_logprobs.append(token_logprob)
+
+        # Not supported yet.
+        ret_logprobs.top_logprobs.append({})
+        ret_logprobs.text_offset.append(-1)
+    return ret_logprobs
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()

@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
         },
+        return_logprob=request.logprobs is not None,
         stream=request.stream,
     )
     adapted_request.post_init()

@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):
     async def gnerate_stream_resp():
         stream_buffer = ""
+        n_prev_token = 0
         async for content in stream_generator(adapted_request):
             text = content["text"]
             prompt_tokens = content["meta_info"]["prompt_tokens"]
             completion_tokens = content["meta_info"]["completion_tokens"]

+            if not stream_buffer:  # The first chunk
+                if request.echo:
+                    # Prepend prompt in response text.
+                    text = request.prompt + text
+                else:
+                    # Skip prompt tokens if echo is disabled.
+                    n_prev_token = prompt_tokens
+
+            if request.logprobs is not None:
+                logprobs = await make_openai_style_logprobs(
+                    content["meta_info"]["token_logprob"][n_prev_token:]
+                )
+                n_prev_token = len(content["meta_info"]["token_logprob"])
+            else:
+                logprobs = None
+
             delta = text[len(stream_buffer) :]
-            stream_buffer = text
+            stream_buffer = content["text"]
             choice_data = CompletionResponseStreamChoice(
                 index=0,
                 text=delta,
-                logprobs=None,
+                logprobs=logprobs,
                 finish_reason=None,
             )
             chunk = CompletionStreamResponse(

@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
     # Non-streaming response.
     ret = await generate_request(adapted_request)

+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+    token_logprob_pos = prompt_tokens
+    if request.echo:
+        token_logprob_pos = 0
+        text = request.prompt + text
+    else:
+        token_logprob_pos = prompt_tokens
+
+    logprobs = (
+        await make_openai_style_logprobs(
+            ret["meta_info"]["token_logprob"][token_logprob_pos:]
+        )
+        if request.logprobs is not None
+        else None
+    )
     choice_data = CompletionResponseChoice(
         index=0,
-        text=ret["text"],
-        logprobs=None,
+        text=text,
+        logprobs=logprobs,
         finish_reason=None,  # TODO(comaniac): Add finish reason.
     )
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
         model=request.model,

@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
         if not isinstance(m.content, str):
             raise HTTPException(
                 status_code=503,
-                detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
+                detail="Structured content requests not supported with "
+                "HuggingFace Chat Templates. "
+                "Make sure the server specifies a sglang chat template.",
             )
     prompt = tokenizer_manager.tokenizer.apply_chat_template(
         request.messages, tokenize=False, add_generation_prompt=True
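For orientation, a sketch (not part of the commit; ids, texts, and numbers are invented) of the OpenAI-style structure that make_openai_style_logprobs() builds from meta_info["token_logprob"]: tokens and token_logprobs get filled in, while top_logprobs and text_offset stay as {} and -1 placeholders because they are not supported yet.

# token_logprob entries are (token_id, logprob) pairs; the first prompt token has logprob None.
token_logprob = [(1, None), (450, -1.9), (7483, -0.3)]
token_texts = ["<s>", "The", " capital"]  # what detokenize() would return

logprobs = {
    "tokens": token_texts,
    "token_logprobs": [lp for _, lp in token_logprob],
    "top_logprobs": [{} for _ in token_logprob],  # not supported yet
    "text_offset": [-1 for _ in token_logprob],   # not supported yet
}
print(logprobs)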
test/srt/test_httpserver_decode.py  (+14 −11)

@@ -9,18 +9,10 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
 """
 import argparse
-import time

 import requests

-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=30000)
-    args = parser.parse_args()
-    url = f"{args.host}:{args.port}"

+def test_decode(url, return_logprob):
     response = requests.post(
         url + "/generate",
         json={

@@ -29,8 +21,19 @@ if __name__ == "__main__":
                 "temperature": 0,
                 "max_new_tokens": 32,
             },
-            # "return_logprob": True,
-            # "logprob_start_len": 0,
+            "return_logprob": return_logprob,
+            "logprob_start_len": 0,
         },
     )
     print(response.json())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+    url = f"{args.host}:{args.port}"
+
+    test_decode(url, False)
+    test_decode(url, True)
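As a quick reference, a sketch (not from the commit; the payload values are invented) of the meta_info keys a /generate response carries when "return_logprob" is set, matching the handle_finished_requests change in model_rpc.py above.

# Invented response payload; only the key names mirror the server code.
resp = {
    "text": " Paris.",
    "meta_info": {
        "prompt_tokens": 6,
        "completion_tokens": 2,
        "prompt_logprob": [-1.9, -0.3, -2.1, -0.4, -0.2],
        "normalized_prompt_logprob": -0.98,
        "token_logprob": [(450, -1.9), (3681, -0.3), (29889, -0.1)],  # (token_id, logprob)
    },
}
for token_id, logprob in resp["meta_info"]["token_logprob"]:
    print(token_id, logprob)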
test/srt/test_httpserver_decode_stream.py  (+28 −13)

@@ -9,27 +9,20 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
 import argparse
 import json
-import time

 import requests

-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=30000)
-    args = parser.parse_args()
-    url = f"{args.host}:{args.port}"

+def test_decode_stream(url, return_logprob):
     response = requests.post(
         url + "/generate",
         json={
             "text": "The capital of France is",
             "sampling_params": {
                 "temperature": 0,
-                "max_new_tokens": 512,
+                "max_new_tokens": 128,
             },
             "stream": True,
+            "return_logprob": return_logprob,
         },
         stream=True,
     )

@@ -41,7 +34,29 @@ if __name__ == "__main__":
         if chunk == "data: [DONE]":
             break
         data = json.loads(chunk[5:].strip("\n"))
-        output = data["text"].strip()
-        print(output[prev:], end="", flush=True)
-        prev = len(output)
+
+        if return_logprob:
+            assert data["meta_info"]["prompt_logprob"] is not None
+            assert data["meta_info"]["token_logprob"] is not None
+            assert data["meta_info"]["normalized_prompt_logprob"] is not None
+
+            if prev == 0:
+                # Skip prompt logprobs
+                prev = data["meta_info"]["prompt_tokens"]
+
+            for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
+                print(f"{token_txt}\t{logprob}", flush=True)
+            prev = len(data["meta_info"]["token_logprob"])
+        else:
+            output = data["text"].strip()
+            print(output[prev:], end="", flush=True)
+            prev = len(output)
     print("")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+    url = f"{args.host}:{args.port}"
+
+    test_decode_stream(url, False)
+    test_decode_stream(url, True)
test/srt/test_openai_server.py  (+36 −5)

@@ -18,15 +18,26 @@ import argparse
 import openai


-def test_completion(args):
+def test_completion(args, echo, logprobs):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)
     response = client.completions.create(
         model="default",
         prompt="The capital of France is",
         temperature=0,
         max_tokens=32,
+        echo=echo,
+        logprobs=logprobs,
     )
+    text = response.choices[0].text
     print(response.choices[0].text)
+    if echo:
+        assert text.startswith("The capital of France is")
+    if logprobs:
+        assert response.choices[0].logprobs
+        if echo:
+            assert response.choices[0].logprobs.token_logprobs[0] == None
+        else:
+            assert response.choices[0].logprobs.token_logprobs[0] != None
     assert response.id
     assert response.created
     assert response.usage.prompt_tokens > 0

@@ -34,7 +45,7 @@ def test_completion(args):
     assert response.usage.total_tokens > 0


-def test_completion_stream(args):
+def test_completion_stream(args, echo, logprobs):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)
     response = client.completions.create(
         model="default",

@@ -42,9 +53,23 @@ def test_completion_stream(args):
         temperature=0,
         max_tokens=32,
         stream=True,
+        echo=echo,
+        logprobs=logprobs,
     )
+    first = True
     for r in response:
-        print(r.choices[0].text, end="", flush=True)
+        if first:
+            if echo:
+                assert r.choices[0].text.startswith("The capital of France is")
+            first = False
+
+        if logprobs:
+            print(
+                f"{r.choices[0].text:12s}\t"
+                f"{r.choices[0].logprobs.token_logprobs}",
+                flush=True,
+            )
+        else:
+            print(r.choices[0].text, end="", flush=True)
+
         assert r.id
         assert r.usage.prompt_tokens > 0
         assert r.usage.completion_tokens > 0

@@ -135,8 +160,14 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    test_completion(args)
-    test_completion_stream(args)
+    test_completion(args, echo=False, logprobs=False)
+    test_completion(args, echo=True, logprobs=False)
+    test_completion(args, echo=False, logprobs=True)
+    test_completion(args, echo=True, logprobs=True)
+    test_completion_stream(args, echo=False, logprobs=False)
+    test_completion_stream(args, echo=True, logprobs=False)
+    test_completion_stream(args, echo=False, logprobs=True)
+    test_completion_stream(args, echo=True, logprobs=True)
     test_chat_completion(args)
     test_chat_completion_stream(args)
     if args.test_image: