sglang · Commits · 61d4c939

Commit 61d4c939 (unverified), authored Jan 18, 2024 by Cody Yu, committed via GitHub on Jan 18, 2024

Support stream=True in v1/completions (#49)
parent 98a3e8ef

Showing 7 changed files with 233 additions and 39 deletions:

- README.md (+17, -1)
- python/pyproject.toml (+1, -1)
- python/sglang/backend/runtime_endpoint.py (+6, -3)
- python/sglang/srt/managers/openai_protocol.py (+63, -8)
- python/sglang/srt/server.py (+85, -22)
- test/srt/test_httpserver_decode_stream.py (+7, -4)
- test/srt/test_openai_server.py (+54, -0)
README.md (+17, -1) — view file @ 61d4c939

@@ -238,9 +238,25 @@ curl http://localhost:30000/generate \

After the existing `curl` example, a link to the sampling-parameter docs and a new "OpenAI Compatible API" section are added:

Learn more about the argument format [here](docs/sampling_params.md).

### OpenAI Compatible API
In addition, the server supports an experimental OpenAI-compatible API.

```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
)
print(response)
```

### Additional Arguments
- Add `--tp 2` to enable tensor parallelism.
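(Not part of the diff: since this commit adds `stream=True` support to the completions endpoint, a streaming variant of the README example above, mirroring the new `test_openai_server.py` added below, would look roughly like this.)

```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Stream the completion token by token instead of waiting for the full text.
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].text, end="", flush=True)
print()
```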
python/pyproject.toml (+1, -1) — view file @ 61d4c939

@@ -19,7 +19,7 @@ dependencies = [

`pydantic` is added to the `srt` extra:

```diff
 [project.optional-dependencies]
-srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", "interegular", "lark", "numba"]
+srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", "interegular", "lark", "numba", "pydantic"]
 openai = ["openai>=1.0"]
 anthropic = ["anthropic"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
```
python/sglang/backend/runtime_endpoint.py (+6, -3) — view file @ 61d4c939

The streaming loop now parses SSE-style `data: ...` lines instead of NUL-delimited JSON chunks:

```diff
@@ -116,9 +116,12 @@ class RuntimeEndpoint(BaseBackend):
         pos = 0
         incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-            if chunk:
-                data = json.loads(chunk.decode())
+        for chunk in response.iter_lines(decode_unicode=False):
+            chunk = chunk.decode("utf-8")
+            if chunk and chunk.startswith("data:"):
+                if chunk == "data: [DONE]":
+                    break
+                data = json.loads(chunk[5:].strip("\n"))
                 text = find_printable_text(data["text"][pos:])
                 meta_info = data["meta_info"]
                 pos += len(text)
```
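(Not part of the diff: the same `data:` parsing pattern appears again in `test_httpserver_decode_stream.py` below. Pulled out on its own it amounts to this small helper; `iter_sse_json` is a hypothetical name introduced here for illustration only.)

```python
import json


def iter_sse_json(response):
    """Yield the JSON payload of each `data: ...` event until `data: [DONE]`.

    Hypothetical helper mirroring the loop added to RuntimeEndpoint; `response`
    is a streaming `requests.Response`.
    """
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            yield json.loads(chunk[5:].strip("\n"))
```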
python/sglang/srt/managers/openai_protocol.py (+63, -8) — view file @ 61d4c939

The old dataclass-based request definition is removed (it also carried the `max_tokens` and `stop` fields used by the previous `/v1/completions` handler):

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Union


@dataclass
class CompletionRequest:
    prompt: Union[str, List[Any]]
    model: str = "default"
    temperature: Optional[float] = 0.7
```

It is replaced by pydantic models covering the OpenAI completions protocol, including the streaming response types:

```python
import time
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class CompletionRequest(BaseModel):
    model: str
    prompt: Union[str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None


class CompletionResponse(BaseModel):
    id: str
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None


class CompletionStreamResponse(BaseModel):
    id: str
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
```
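(Not part of the diff: a rough sketch of how the streaming models above turn into SSE payloads, using the same `.json(exclude_unset=True, ensure_ascii=False)` call the server uses; the field values here are made up.)

```python
from sglang.srt.managers.openai_protocol import (
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)

# Build one streaming chunk with illustrative values and render it as an SSE
# event, the same way server.py does below.
choice = CompletionResponseStreamChoice(index=0, text=" Paris", logprobs=None, finish_reason=None)
chunk = CompletionStreamResponse(id="cmpl-123", model="default", choices=[choice])
print(f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n")
# Emits something like: data: {"id": "cmpl-123", "choices": [...], "model": "default"}
```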
python/sglang/srt/server.py (+85, -22) — view file @ 61d4c939

Top of the file (context):

```python
"""SRT: SGLang Runtime"""
import argparse
import asyncio
import dataclasses
import json
import multiprocessing as mp
import sys
```

The imports gain `Request`, `StreamingResponse`, and the new protocol models:

```diff
@@ -16,12 +14,19 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.openai_protocol import CompletionRequest
+from sglang.srt.managers.openai_protocol import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    UsageInfo,
+)
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
```

`/generate` switches from NUL-delimited JSON chunks to SSE framing via a shared `stream_generator`, and `/v1/completions` is rewritten to return OpenAI-style responses, including streaming ones:

```diff
@@ -41,39 +46,97 @@ async def get_model_info():
     }
     return result


+async def stream_generator(obj):
+    async for out in tokenizer_manager.generate_request(obj):
+        yield out
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
-    result_generator = tokenizer_manager.generate_request(obj)

     if obj.stream:

         async def stream_results():
-            async for out in result_generator:
-                yield (json.dumps(out) + "\0").encode("utf-8")
+            async for out in stream_generator(obj):
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"

         return StreamingResponse(stream_results(), media_type="text/event-stream")
-    else:
-        ret = await result_generator.__anext__()
-        return ret
+
+    ret = await tokenizer_manager.generate_request(obj).__anext__()
+    return ret


 @app.post("/v1/completions")
-async def v1_completions(obj: CompletionRequest):
-    assert obj.n == 1
-    obj = GenerateReqInput(
-        text=obj.prompt,
+async def v1_completions(raw_request: Request):
+    request_json = await raw_request.json()
+    request = CompletionRequest(**request_json)
+
+    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
+    assert request.n == 1
+
+    adapted_request = GenerateReqInput(
+        text=request.prompt,
         sampling_params={
-            "temperature": obj.temperature,
-            "max_new_tokens": obj.max_tokens,
-            "stop": obj.stop,
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": request.stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
         },
+        stream=request.stream,
     )
-    ret = await generate_request(obj)
-    return {
-        "choices": [{"text": ret["text"]}],
-    }
+    adapted_request.post_init()
+
+    if adapted_request.stream:
+
+        async def gnerate_stream_resp():
+            stream_buffer = ""
+            async for content in stream_generator(adapted_request):
+                text = content["text"]
+                delta = text[len(stream_buffer):]
+                stream_buffer = text
+                choice_data = CompletionResponseStreamChoice(
+                    index=0,
+                    text=delta,
+                    logprobs=None,
+                    finish_reason=None,
+                )
+                chunk = CompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    object="text_completion",
+                    choices=[choice_data],
+                    model=request.model,
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
+
+    # Non-streaming response.
+    ret = await generate_request(adapted_request)
+    choice_data = CompletionResponseChoice(
+        index=0,
+        text=ret["text"],
+        logprobs=None,
+        finish_reason=None,  # TODO(comaniac): Add finish reason.
+    )
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    response = CompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response


 def launch_server(server_args, pipe_finish_writer):
```
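(Not part of the diff: the new streaming path can also be exercised directly over HTTP with `requests` instead of the OpenAI client. A sketch only, assuming the server is running on localhost:30000 as in the tests; the chunk layout follows the `CompletionStreamResponse` model above.)

```python
import json

import requests

response = requests.post(
    "http://127.0.0.1:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "temperature": 0,
        "max_tokens": 32,
        "stream": True,
    },
    stream=True,
)
for line in response.iter_lines(decode_unicode=False):
    line = line.decode("utf-8")
    if line and line.startswith("data:"):
        if line == "data: [DONE]":
            break
        chunk = json.loads(line[5:].strip("\n"))
        # Each SSE event carries one CompletionStreamResponse with a single choice.
        print(chunk["choices"][0]["text"], end="", flush=True)
print()
```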
test/srt/test_httpserver_decode_stream.py (+7, -4) — view file @ 61d4c939

```diff
@@ -25,7 +25,7 @@ if __name__ == "__main__":
         "text": "The capital of France is",
         "sampling_params": {
             "temperature": 0,
-            "max_new_tokens": 1024,
+            "max_new_tokens": 512,
         },
         "stream": True,
     },
```

```diff
@@ -33,9 +33,12 @@ if __name__ == "__main__":
     )

     prev = 0
-    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            data = json.loads(chunk.decode())
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
             output = data["text"].strip()
             print(output[prev:], end="", flush=True)
             prev = len(output)
```
test/srt/test_openai_server.py (new file, +54) — view file @ 61d4c939

```python
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000

Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse

import openai


def test_completion(args):
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
    )
    print(response.choices[0].text)
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0


def test_completion_stream(args):
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
        stream=True,
    )
    for r in response:
        print(r.choices[0].text, end="", flush=True)
        assert r.id
        assert r.created
        assert r.usage.prompt_tokens > 0
        assert r.usage.completion_tokens > 0
        assert r.usage.total_tokens > 0
    print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
    args = parser.parse_args()

    test_completion(args)
    test_completion_stream(args)
```