norm / vllm · Commits

Commit 057daef7 (unverified)
Authored May 23, 2023 by Zhuohan Li; committed by GitHub on May 23, 2023.
Parent: e8671783

OpenAI Compatible Frontend (#116)

Showing 20 changed files with 645 additions and 170 deletions (+645, -170).
cacheflow/core/block_manager.py                   +3    -3
cacheflow/core/scheduler.py                       +4    -2
cacheflow/entrypoints/openai/openai_frontend.py   +300  -0
cacheflow/entrypoints/openai/protocol.py          +126  -0
cacheflow/entrypoints/simple_fastapi_frontend.py  +51   -0
cacheflow/model_executor/layers/sampler.py        +3    -3
cacheflow/outputs.py                              +18   -11
cacheflow/sampling_params.py                      +2    -2
cacheflow/sequence.py                             +21   -5
cacheflow/server/async_llm_server.py              +21   -46
cacheflow/server/llm_server.py                    +8    -5
cacheflow/utils.py                                +5    -0
examples/gradio_webserver.py                      +7    -6
examples/openai_client.py                         +22   -0
examples/simple_fastapi_client.py                 +48   -0
examples/simple_server.py                         +5    -4
playground/http_client.py                         +0    -20
playground/streaming_fastapi_worker.py            +0    -40
requirements.txt                                  +1    -0
test_cli_client.py                                +0    -23
cacheflow/core/block_manager.py

@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
cacheflow/core/scheduler.py

@@ -292,10 +292,12 @@ class Scheduler:
             # Append a new token to the sequence.
             output = seq_outputs[seq.seq_id]
             seq.append_token_id(output.output_token, output.logprobs)
+        # Return a shallow copy of the running queue to prevent the queue
+        # from being modified by the caller.
         return self.running.copy()

-    def free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
+    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
+        seq.status = finish_status
         self.block_manager.free(seq)

     def free_finished_seq_groups(self) -> None:
cacheflow/entrypoints/openai/openai_frontend.py  (new file, mode 100644)

# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional

import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
    CompletionRequest, CompletionResponse, CompletionResponseChoice,
    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)

logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse(ErrorResponse(message=message,
                                      type="invalid_request_error").dict(),
                        status_code=status_code.value)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
    if request.model == served_model:
        return
    ret = create_error_response(
        HTTPStatus.NOT_FOUND,
        f"The model `{request.model}` does not exist.",
    )
    return ret


@app.get("/v1/models")
async def show_available_models():
    """Show available models. Right now we only have one model."""
    model_cards = [ModelCard(id=served_model, root=served_model,
                             permission=[ModelPermission()])]
    return ModelList(data=model_cards)


def create_logprobs(token_ids: List[int],
                    id_logprobs: List[Dict[int, float]],
                    initial_text_offset: int = 0) -> LogProbs:
    """Create OpenAI-style logprobs."""
    logprobs = LogProbs()
    last_token_len = 0
    for token_id, id_logprob in zip(token_ids, id_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(id_logprob[token_id])
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
        last_token_len = len(token)

        logprobs.top_logprobs.append(
            {tokenizer.convert_ids_to_tokens(i): p
             for i, p in id_logprob.items()})
    return logprobs


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    logger.info(f"Received completion request: {request}")

    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.echo:
        # We do not support echo since the cacheflow server does not
        # currently support getting the logprobs of prompt tokens.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "echo is not currently supported")

    if request.suffix is not None:
        # The language models we currently support do not support suffix.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "suffix is not currently supported")

    if request.logit_bias is not None:
        # TODO: support logit_bias in cacheflow server.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    prompt = request.prompt
    created_time = int(time.time())
    try:
        sampling_params = SamplingParams(
            n=request.n,
            best_of=request.best_of,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop=request.stop,
            ignore_eos=request.ignore_eos,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            use_beam_search=request.use_beam_search,
        )
    except ValueError as e:
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    result_generator = server.generate(prompt, sampling_params,
                                       request_id=request_id)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when use beam search.
    stream = (request.stream and
              (request.best_of is None or request.n == request.best_of) and
              not request.use_beam_search)

    def create_stream_response_json(index: int,
                                    text: str,
                                    logprobs: Optional[LogProbs] = None,
                                    finish_reason: Optional[str] = None) -> str:
        choice_data = CompletionResponseStreamChoice(
            index=index,
            text=text,
            logprobs=logprobs,
            finish_reason=finish_reason,
        )
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)
        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                delta_text = output.text[len(previous_texts[i]):]
                if request.logprobs is not None:
                    logprobs = create_logprobs(
                        output.token_ids[previous_num_tokens[i]:],
                        output.logprobs[previous_num_tokens[i]:],
                        len(previous_texts[i]))
                else:
                    logprobs = None
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                    logprobs=logprobs,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    logprobs = LogProbs() if request.logprobs is not None else None
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
        yield "data: [DONE]\n\n"

    # Streaming response
    if stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream")

    # Non-streaming response
    final_res: RequestOutput = None
    async for res in result_generator:
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        if request.logprobs is not None:
            logprobs = create_logprobs(output.token_ids, output.logprobs)
        else:
            logprobs = None
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output.text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(len(output.token_ids)
                               for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CacheFlow OpenAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--allow-credentials", action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins", type=json.loads, default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods", type=json.loads, default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers", type=json.loads, default=["*"],
                        help="allowed headers")
    parser.add_argument("--served-model-name", type=str, default=None,
                        help="The model name used in the API. If not specified, "
                             "the model name will be the same as the "
                             "huggingface name.")
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

    served_model = args.served_model_name or args.model

    server_args = ServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(args.model)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
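A minimal sketch (not part of this commit) of exercising the new endpoints once the frontend is running, e.g. launched as python -m cacheflow.entrypoints.openai.openai_frontend --model facebook/opt-125m (the launch command and model name are assumptions for illustration):

# Query the OpenAI-compatible completions endpoint with plain HTTP.
import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",    # must match the served model name
        "prompt": "San Francisco is a",
        "max_tokens": 16,
        "temperature": 0.0,
    },
)
print(json.dumps(resp.json(), indent=2))

# The model-listing route added above can be checked the same way.
print(requests.get("http://localhost:8000/v1/models").json())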
cacheflow/entrypoints/openai/protocol.py  (new file, mode 100644)

# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py

import time
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

from cacheflow.utils import random_uuid


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: str = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "cacheflow"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None


class CompletionRequest(BaseModel):
    model: str
    prompt: str
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by cacheflow
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
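A small usage sketch (not in the commit) of these request/response models; it assumes pydantic v1, whose .parse_raw()/.json() methods the frontend above relies on:

from cacheflow.entrypoints.openai.protocol import CompletionRequest

# Unset fields fall back to the defaults declared above, including the
# cacheflow-specific extras (top_k, ignore_eos, use_beam_search).
req = CompletionRequest.parse_raw(
    '{"model": "facebook/opt-125m", "prompt": "Hello", "max_tokens": 8}')
assert req.n == 1 and req.top_k == -1 and req.use_beam_search is False
print(req.json(exclude_unset=True))   # serialize only what the client sent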
cacheflow/entrypoints/simple_fastapi_frontend.py  (new file, mode 100644)

import argparse
import json
from typing import AsyncGenerator

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.ray_utils import initialize_cluster

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
app = FastAPI()


@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    sampling_params = SamplingParams(**request_dict)
    results_generator = server.generate(prompt, sampling_params)

    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [prompt + output.text
                            for output in request_output.outputs]
            ret = {
                "text": text_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")

    return StreamingResponse(stream_results())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
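For reference, a non-streaming sketch (not in the commit) of hitting the /generate route, assuming the default port 8001; the streaming variant is the new examples/simple_fastapi_client.py further down:

import json
import requests

resp = requests.post("http://localhost:8001/generate",
                     json={"prompt": "Hello, my name is", "n": 1, "max_tokens": 16})
# The route streams NUL-delimited JSON objects; when the body is read in one go,
# the last object carries the final text.
chunks = [c for c in resp.text.split("\0") if c]
print(json.loads(chunks[-1])["text"])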
cacheflow/model_executor/layers/sampler.py

 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional

 import numpy as np
 import torch

@@ -258,9 +258,9 @@ def _apply_top_p_top_k(
 def _get_topk_logprobs(
     logprobs: torch.Tensor,
-    num_logprobs: int,
+    num_logprobs: Optional[int],
 ) -> Dict[int, float]:
-    if num_logprobs == 0:
+    if num_logprobs is None or num_logprobs == 0:
         return {}

     topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
cacheflow/outputs.py

-from typing import Dict, List
+from typing import Dict, List, Optional

-from cacheflow.sequence import SequenceGroup
+from cacheflow.sequence import SequenceGroup, SequenceStatus


 class CompletionOutput:

@@ -12,19 +12,25 @@ class CompletionOutput:
         token_ids: List[int],
         cumulative_logprob: float,
         logprobs: List[Dict[int, float]],
+        finish_reason: Optional[str] = None,
     ) -> None:
         self.index = index
         self.text = text
         self.token_ids = token_ids
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
+        self.finish_reason = finish_reason
+
+    def finished(self) -> bool:
+        return self.finish_reason is not None

     def __repr__(self) -> str:
         return (f"CompletionOutput(index={self.index}, "
                 f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs},"
+                f"finish_reason={self.finish_reason})")


 class RequestOutput:

@@ -35,13 +41,11 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.outputs = outputs
-        self.done = done

     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":

@@ -57,25 +61,28 @@ class RequestOutput:
         outputs: List[CompletionOutput] = []
         for seq in top_n_seqs:
             logprobs = seq.output_logprobs
-            if seq_group.sampling_params.logprobs == 0:
+            if seq_group.sampling_params.logprobs is None:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
                 logprobs = {}
+            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
             output = CompletionOutput(seqs.index(seq), seq.output_text,
                                       seq.get_output_token_ids(),
-                                      seq.get_cumulative_logprob(), logprobs)
+                                      seq.get_cumulative_logprob(), logprobs,
+                                      finshed_reason)
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
-                   seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"outputs={self.outputs}, "
-                f"done={self.done})")
+                f"outputs={self.outputs})")
+
+    def finished(self) -> bool:
+        return all(output.finished() for output in self.outputs)
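A hand-rolled sketch (not in the commit) of the new finished()/finish_reason contract; real objects are produced by the server, and the token IDs and logprobs below are placeholders:

from cacheflow.outputs import CompletionOutput, RequestOutput

done = CompletionOutput(0, " Paris.", [2], -0.5, [], finish_reason="stop")
pending = CompletionOutput(1, " Lyon", [3], -0.9, [])          # still generating
req = RequestOutput("req-0", "The capital of France is", [1], [done, pending])

assert done.finished() and not pending.finished()
assert not req.finished()   # finished only once every candidate has a finish_reason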
cacheflow/sampling_params.py

@@ -53,7 +53,7 @@ class SamplingParams:
         stop: Union[str, List[str]] = [],
         ignore_eos: bool = False,
         max_tokens: int = 16,
-        logprobs: int = 0,
+        logprobs: Optional[int] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n

@@ -98,7 +98,7 @@ class SamplingParams:
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
-        if self.logprobs < 0:
+        if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
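The practical effect of the changed default, as a short sketch (not in the commit): logprobs is now Optional, so "not requested" (None) is distinct from "request zero logprobs":

from cacheflow.sampling_params import SamplingParams

params = SamplingParams(max_tokens=16)                 # logprobs defaults to None
assert params.logprobs is None
params_lp = SamplingParams(max_tokens=16, logprobs=3)  # ask for top-3 logprobs per token
assert params_lp.logprobs == 3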
cacheflow/sequence.py

 import copy
 import enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 from cacheflow.block import LogicalTokenBlock
 from cacheflow.sampling_params import SamplingParams

@@ -10,8 +10,25 @@ class SequenceStatus(enum.Enum):
     WAITING = enum.auto()
     RUNNING = enum.auto()
     SWAPPED = enum.auto()
-    FINISHED = enum.auto()
+    FINISHED_STOPPED = enum.auto()
+    FINISHED_LENGTH_CAPPED = enum.auto()
+
+    @staticmethod
+    def is_finished(status: "SequenceStatus") -> bool:
+        return status in [
+            SequenceStatus.FINISHED_STOPPED,
+            SequenceStatus.FINISHED_LENGTH_CAPPED,
+        ]
+
+    @staticmethod
+    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
+        if status == SequenceStatus.FINISHED_STOPPED:
+            finish_reason = "stop"
+        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
+            finish_reason = "length"
+        else:
+            finish_reason = None
+        return finish_reason


 class SequenceData:

@@ -20,7 +37,6 @@ class SequenceData:
         prompt_token_ids: List[int],
     ) -> None:
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0

@@ -166,7 +182,7 @@ class SequenceGroup:
             raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)
+        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
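The new helpers behave as follows (a direct illustration of the code above):

from cacheflow.sequence import SequenceStatus

assert SequenceStatus.is_finished(SequenceStatus.FINISHED_STOPPED)
assert not SequenceStatus.is_finished(SequenceStatus.RUNNING)
assert SequenceStatus.get_finished_reason(SequenceStatus.FINISHED_STOPPED) == "stop"
assert SequenceStatus.get_finished_reason(SequenceStatus.FINISHED_LENGTH_CAPPED) == "length"
assert SequenceStatus.get_finished_reason(SequenceStatus.SWAPPED) is None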
cacheflow/entrypoints/fastapi_server.py → cacheflow/server/async_llm_server.py

-import argparse
 import asyncio
-import json
 import time
-from typing import Any, Dict
+from typing import Dict, Optional
-import uuid

-from fastapi import FastAPI, Request
-from fastapi.responses import StreamingResponse
 import ray
-import uvicorn

 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds

-app = FastAPI()


-class FastAPIServer:
+class AsyncLLMServer:
     def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
         if server_use_ray:

@@ -45,15 +39,15 @@ class FastAPIServer:
         self.request_outputs[request_id] = request_output
         self.request_events[request_id].set()

-    async def generate(self, request_dict: Dict[str, Any]):
+    async def generate(self, prompt: str, sampling_params: SamplingParams,
+                       request_id: Optional[str] = None) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()
-        prompt = request_dict.pop("prompt")
-        sampling_params = SamplingParams(**request_dict)

         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        request_id = str(uuid.uuid4().hex[:8])
+        if request_id is None:
+            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

@@ -82,19 +76,10 @@ class FastAPIServer:
             # Decode and return new outputs.
             request_output = self.request_outputs[request_id]
-            prompt = request_output.prompt
-            text_outputs = [prompt + output.text
-                            for output in request_output.outputs]
-            ret = {
-                "text": text_outputs,
-                "error": 0,
-            }
-            yield (json.dumps(ret) + "\0").encode("utf-8")
+            yield request_output

             # Once finished, release the resources of the sequence group.
-            if request_output.done:
+            if request_output.finished():
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
                 # Kick the server if the server is not running. This is to

@@ -104,25 +89,15 @@ class FastAPIServer:
                 await self.server_step()
             break

-@app.post("/generate")
-async def generate_stream(request: Request):
-    request_dict = await request.json()
-    return StreamingResponse(server.generate(request_dict))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=10002)
-    parser = ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-
-    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
-    parallel_config = server_configs[2]
-    distributed_init_method, stage_devices = initialize_cluster(parallel_config)
-    server = FastAPIServer(args.use_ray, *server_configs,
-                           distributed_init_method, stage_devices,
-                           log_stats=not args.disable_log_stats)
-
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(server_args.use_ray, *server_configs,
+                     distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
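A sketch (not in the commit) of driving the renamed class programmatically; it builds ServerArgs through the same CLI helpers the entrypoints above use, and the --model flag and model name are taken from those entrypoints:

import argparse
import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer

parser = ServerArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m"])
server = AsyncLLMServer.from_server_args(ServerArgs.from_cli_args(args))

async def main():
    params = SamplingParams(temperature=0.0, max_tokens=16)
    # generate() now yields RequestOutput objects; callers do their own formatting.
    async for request_output in server.generate("Hello, my name is", params):
        if request_output.finished():
            print(request_output.outputs[0].text)

asyncio.run(main())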
cacheflow/server/llm_server.py

@@ -210,7 +210,8 @@ class LLMServer:
                     # Truncate the output text so that the stop string is
                     # not included in the output.
                     seq.output_text = seq.output_text[:-len(stop_str)]
-                    self.scheduler.free_seq(seq)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     stopped = True
                     break
             if stopped:

@@ -218,12 +219,14 @@ class LLMServer:
             # Check if the sequence has reached max_tokens.
             if seq.get_output_len() == sampling_params.max_tokens:
-                self.scheduler.free_seq(seq)
+                self.scheduler.free_seq(
+                    seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue

             # Check if the sequence has generated the EOS token.
             if not sampling_params.ignore_eos:
                 if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                    self.scheduler.free_seq(seq)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     continue

     def _run_workers(

@@ -238,10 +241,10 @@ class LLMServer:
             executor = getattr(worker, method)
             if self.parallel_config.use_ray:
                 executor = executor.remote

             output = executor(*args, **kwargs)
             all_outputs.append(output)

         if self.parallel_config.use_ray:
             all_outputs = ray.get(all_outputs)
cacheflow/utils.py

 import enum
+import uuid

 import psutil
 import torch

@@ -31,3 +32,7 @@ def get_gpu_memory(gpu: int = 0) -> int:
 def get_cpu_memory() -> int:
     """Returns the total CPU memory of the node in bytes."""
     return psutil.virtual_memory().total
+
+
+def random_uuid() -> str:
+    return str(uuid.uuid4().hex)
gradio_webserver.py → examples/gradio_webserver.py

 import argparse
 import json
-import time

 import gradio as gr
 import requests

@@ -24,9 +23,9 @@ def http_bot(prompt):
 def build_demo():
     with gr.Blocks() as demo:
         gr.Markdown(
-            "# Cacheflow demo\n"
+            "# Cacheflow text completion demo\n"
         )
-        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")  # .style(container=False)
+        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
         outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
         inputbox.submit(http_bot, [inputbox], [outputbox])
         return demo

@@ -35,9 +34,11 @@ def build_demo():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=10003)
-    parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate")
+    parser.add_argument("--port", type=int, default=8002)
+    parser.add_argument("--model-url", type=str, default="http://localhost:8001/generate")
     args = parser.parse_args()

     demo = build_demo()
-    demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port)
\ No newline at end of file
+    demo.queue(concurrency_count=100).launch(server_name=args.host,
+                                             server_port=args.port,
+                                             share=True)
examples/openai_client.py  (new file, mode 100644)

import openai

openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m"

# list models
models = openai.Model.list()
print(models)

# create a completion
stream = True
completion = openai.Completion.create(
    model=model, prompt="A robot may not injure a human being", echo=False,
    n=2, best_of=3, stream=stream, logprobs=3)

# print the completion
if stream:
    for c in completion:
        print(c)
else:
    print("completion:", completion)
examples/simple_fastapi_client.py  (new file, mode 100644)

import argparse
import requests
import json


def clear_line(n=1):
    LINE_UP = '\033[1A'
    LINE_CLEAR = '\x1b[2K'
    for i in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def http_request(prompt: str, api_url: str, n: int = 1):
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    args = parser.parse_args()
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n

    print(f"Prompt: {prompt}\n", flush=True)
    num_printed_lines = 0
    for h in http_request(prompt, api_url, n):
        clear_line(num_printed_lines)
        num_printed_lines = 0
        for i, line in enumerate(h):
            num_printed_lines += 1
            print(f"Beam candidate {i}: {line}", flush=True)
examples/simple_server.py

 import argparse
+import uuid

 from cacheflow import ServerArgs, LLMServer, SamplingParams

@@ -20,17 +19,19 @@ def main(args: argparse.Namespace):
                        SamplingParams(n=3, best_of=3, use_beam_search=True,
                                       temperature=0.0)),
     ]

-    request_id = 0
-
     # Run the server.
     while True:
         # To test iteration-level scheduling, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
-            server.add_request(str(request_id), prompt, sampling_params)
-            request_id += 1
+            request_id = str(uuid.uuid4().hex[:8])
+            server.add_request(request_id, prompt, sampling_params)

         request_outputs = server.step()
         for request_output in request_outputs:
-            if request_output.done:
+            if request_output.finished():
                 print(request_output)

         if not (server.has_unfinished_requests() or test_prompts):
playground/http_client.py  (deleted, was mode 100644)

import requests
import json


def http_bot():
    prompt = "How are you? I'm fine."

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
    }
    response = requests.post("http://localhost:10002", headers=headers,
                             json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


for h in http_bot():
    print(h, end="", flush=True)
\ No newline at end of file
playground/streaming_fastapi_worker.py  (deleted, was mode 100644)

import argparse
import asyncio
import time
from typing import Union
import json

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

app = FastAPI()


async def text_streamer(args):
    context = args["prompt"]
    words = context.split(" ")
    for word in words:
        await asyncio.sleep(1)
        print("word:", word)
        ret = {
            "text": word + " ",
            "error": 0,
        }
        yield (json.dumps(ret) + "\0").encode("utf-8")


@app.post("/")
async def read_root(request: Request):
    args = await request.json()
    return StreamingResponse(text_streamer(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
requirements.txt

@@ -8,3 +8,4 @@ transformers >= 4.28.0  # Required for LLaMA.
 xformers >= 0.0.19
 fastapi
 uvicorn
+pydantic  # Required for OpenAI server.
test_cli_client.py  (deleted, was mode 100644)

import requests
import json


def http_request():
    prompt = "Ion Stoica is a"

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": 4,
        "use_beam_search": True,
        "temperature": 0.0,
    }
    response = requests.post("http://localhost:10002/generate", headers=headers,
                             json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


for h in http_request():
    print(h, flush=True)