xdb4_94051 / vllm — Commits

Commit 057daef7 (unverified)
OpenAI Compatible Frontend (#116)

Authored May 23, 2023 by Zhuohan Li; committed by GitHub on May 23, 2023.
Parent: e8671783

Showing 20 changed files with 645 additions and 170 deletions (+645, -170)
  cacheflow/core/block_manager.py                     +3    -3
  cacheflow/core/scheduler.py                         +4    -2
  cacheflow/entrypoints/openai/openai_frontend.py     +300  -0
  cacheflow/entrypoints/openai/protocol.py            +126  -0
  cacheflow/entrypoints/simple_fastapi_frontend.py    +51   -0
  cacheflow/model_executor/layers/sampler.py          +3    -3
  cacheflow/outputs.py                                +18   -11
  cacheflow/sampling_params.py                        +2    -2
  cacheflow/sequence.py                               +21   -5
  cacheflow/server/async_llm_server.py                +21   -46
  cacheflow/server/llm_server.py                      +8    -5
  cacheflow/utils.py                                  +5    -0
  examples/gradio_webserver.py                        +7    -6
  examples/openai_client.py                           +22   -0
  examples/simple_fastapi_client.py                   +48   -0
  examples/simple_server.py                           +5    -4
  playground/http_client.py                           +0    -20
  playground/streaming_fastapi_worker.py              +0    -40
  requirements.txt                                    +1    -0
  test_cli_client.py                                  +0    -23
cacheflow/core/block_manager.py

@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
cacheflow/core/scheduler.py

@@ -292,10 +292,12 @@ class Scheduler:
             # Append a new token to the sequence.
             output = seq_outputs[seq.seq_id]
             seq.append_token_id(output.output_token, output.logprobs)
         # Return a shallow copy of the running queue to prevent the queue
         # from being modified by the caller.
         return self.running.copy()

-    def free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
+    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
+        seq.status = finish_status
         self.block_manager.free(seq)

     def free_finished_seq_groups(self) -> None:
cacheflow/entrypoints/openai/openai_frontend.py (new file, mode 100644)

# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional

import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
    CompletionRequest, CompletionResponse, CompletionResponseChoice,
    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)

logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse(ErrorResponse(message=message,
                                      type="invalid_request_error").dict(),
                        status_code=status_code.value)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
    if request.model == served_model:
        return
    ret = create_error_response(
        HTTPStatus.NOT_FOUND,
        f"The model `{request.model}` does not exist.",
    )
    return ret


@app.get("/v1/models")
async def show_available_models():
    """Show available models. Right now we only have one model."""
    model_cards = [ModelCard(id=served_model, root=served_model,
                             permission=[ModelPermission()])]
    return ModelList(data=model_cards)


def create_logprobs(token_ids: List[int],
                    id_logprobs: List[Dict[int, float]],
                    initial_text_offset: int = 0) -> LogProbs:
    """Create OpenAI-style logprobs."""
    logprobs = LogProbs()
    last_token_len = 0
    for token_id, id_logprob in zip(token_ids, id_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(id_logprob[token_id])
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1]
                                        + last_token_len)
        last_token_len = len(token)

        logprobs.top_logprobs.append(
            {tokenizer.convert_ids_to_tokens(i): p
             for i, p in id_logprob.items()})
    return logprobs


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    logger.info(f"Received completion request: {request}")

    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.echo:
        # We do not support echo since the cacheflow server does not
        # currently support getting the logprobs of prompt tokens.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "echo is not currently supported")

    if request.suffix is not None:
        # The language models we currently support do not support suffix.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "suffix is not currently supported")

    if request.logit_bias is not None:
        # TODO: support logit_bias in cacheflow server.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    prompt = request.prompt
    created_time = int(time.time())
    try:
        sampling_params = SamplingParams(
            n=request.n,
            best_of=request.best_of,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop=request.stop,
            ignore_eos=request.ignore_eos,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            use_beam_search=request.use_beam_search,
        )
    except ValueError as e:
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    result_generator = server.generate(prompt, sampling_params,
                                       request_id=request_id)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when use beam search.
    stream = (request.stream and
              (request.best_of is None or request.n == request.best_of) and
              not request.use_beam_search)

    def create_stream_response_json(index: int,
                                    text: str,
                                    logprobs: Optional[LogProbs] = None,
                                    finish_reason: Optional[str] = None) -> str:
        choice_data = CompletionResponseStreamChoice(
            index=index,
            text=text,
            logprobs=logprobs,
            finish_reason=finish_reason,
        )
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)
        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                delta_text = output.text[len(previous_texts[i]):]
                if request.logprobs is not None:
                    logprobs = create_logprobs(
                        output.token_ids[previous_num_tokens[i]:],
                        output.logprobs[previous_num_tokens[i]:],
                        len(previous_texts[i]))
                else:
                    logprobs = None
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                    logprobs=logprobs,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    logprobs = LogProbs() if request.logprobs is not None else None
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
        yield "data: [DONE]\n\n"

    # Streaming response
    if stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream")

    # Non-streaming response
    final_res: RequestOutput = None
    async for res in result_generator:
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        if request.logprobs is not None:
            logprobs = create_logprobs(output.token_ids, output.logprobs)
        else:
            logprobs = None
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output.text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(len(output.token_ids)
                               for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CacheFlow OpenAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default="localhost",
                        help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--allow-credentials", action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins", type=json.loads, default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods", type=json.loads, default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers", type=json.loads, default=["*"],
                        help="allowed headers")
    parser.add_argument("--served-model-name", type=str, default=None,
                        help="The model name used in the API. If not specified, "
                             "the model name will be the same as the "
                             "huggingface name.")
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

    served_model = args.served_model_name or args.model

    server_args = ServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(args.model)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
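
Note: the endpoint above follows the OpenAI completions wire format, so it can be exercised with any HTTP client. A minimal sketch (not part of the commit) using plain requests; the host, port, and model name are assumptions taken from the defaults above and from examples/openai_client.py further down.

import requests

# Hypothetical smoke test against the new /v1/completions endpoint.
payload = {
    "model": "facebook/opt-125m",
    "prompt": "A robot may not injure a human being",
    "max_tokens": 16,
    "temperature": 0.0,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])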
cacheflow/entrypoints/openai/protocol.py (new file, mode 100644)

# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

from cacheflow.utils import random_uuid


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: str = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "cacheflow"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None


class CompletionRequest(BaseModel):
    model: str
    prompt: str
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by cacheflow
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
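
Note: these are plain pydantic models, so request validation and default filling happen at construction time. A small sketch (not part of the commit) of how a raw JSON body maps onto CompletionRequest; the payload values here are only illustrative.

from cacheflow.entrypoints.openai.protocol import CompletionRequest

# Fields missing from the payload fall back to the defaults defined above.
raw = {"model": "facebook/opt-125m", "prompt": "Hello", "n": 2, "best_of": 3}
req = CompletionRequest(**raw)
print(req.max_tokens, req.top_k, req.use_beam_search)  # 16 -1 False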
cacheflow/entrypoints/simple_fastapi_frontend.py (new file, mode 100644)

import argparse
import json
from typing import AsyncGenerator

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.ray_utils import initialize_cluster

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
app = FastAPI()


@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    sampling_params = SamplingParams(**request_dict)
    results_generator = server.generate(prompt, sampling_params)

    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [prompt + output.text
                            for output in request_output.outputs]
            ret = {
                "text": text_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")

    return StreamingResponse(stream_results())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
cacheflow/model_executor/layers/sampler.py

 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional

 import numpy as np
 import torch
@@ -258,9 +258,9 @@ def _apply_top_p_top_k(
 def _get_topk_logprobs(
     logprobs: torch.Tensor,
-    num_logprobs: int,
+    num_logprobs: Optional[int],
 ) -> Dict[int, float]:
-    if num_logprobs == 0:
+    if num_logprobs is None or num_logprobs == 0:
         return {}

     topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
cacheflow/outputs.py

-from typing import Dict, List
+from typing import Dict, List, Optional

-from cacheflow.sequence import SequenceGroup
+from cacheflow.sequence import SequenceGroup, SequenceStatus


 class CompletionOutput:
@@ -12,19 +12,25 @@ class CompletionOutput:
         token_ids: List[int],
         cumulative_logprob: float,
         logprobs: List[Dict[int, float]],
+        finish_reason: Optional[str] = None,
     ) -> None:
         self.index = index
         self.text = text
         self.token_ids = token_ids
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
+        self.finish_reason = finish_reason
+
+    def finished(self) -> bool:
+        return self.finish_reason is not None

     def __repr__(self) -> str:
         return (f"CompletionOutput(index={self.index}, "
                 f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs},"
+                f"finish_reason={self.finish_reason})")


 class RequestOutput:
@@ -35,13 +41,11 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.outputs = outputs
-        self.done = done

     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@@ -57,25 +61,28 @@ class RequestOutput:
         outputs: List[CompletionOutput] = []
         for seq in top_n_seqs:
             logprobs = seq.output_logprobs
-            if seq_group.sampling_params.logprobs == 0:
+            if seq_group.sampling_params.logprobs is None:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
                 logprobs = {}
+            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
             output = CompletionOutput(seqs.index(seq), seq.output_text,
                                       seq.get_output_token_ids(),
-                                      seq.get_cumulative_logprob(), logprobs)
+                                      seq.get_cumulative_logprob(), logprobs,
+                                      finshed_reason)
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
-                   seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"outputs={self.outputs}, "
-                f"done={self.done})")
+                f"outputs={self.outputs})")
+
+    def finished(self) -> bool:
+        return all(output.finished() for output in self.outputs)
cacheflow/sampling_params.py

@@ -53,7 +53,7 @@ class SamplingParams:
         stop: Union[str, List[str]] = [],
         ignore_eos: bool = False,
         max_tokens: int = 16,
-        logprobs: int = 0,
+        logprobs: Optional[int] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -98,7 +98,7 @@ class SamplingParams:
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
-        if self.logprobs < 0:
+        if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
cacheflow/sequence.py

 import copy
 import enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 from cacheflow.block import LogicalTokenBlock
 from cacheflow.sampling_params import SamplingParams
@@ -10,8 +10,25 @@ class SequenceStatus(enum.Enum):
     WAITING = enum.auto()
     RUNNING = enum.auto()
     SWAPPED = enum.auto()
-    FINISHED = enum.auto()
+    FINISHED_STOPPED = enum.auto()
+    FINISHED_LENGTH_CAPPED = enum.auto()
+
+    @staticmethod
+    def is_finished(status: "SequenceStatus") -> bool:
+        return status in [
+            SequenceStatus.FINISHED_STOPPED,
+            SequenceStatus.FINISHED_LENGTH_CAPPED,
+        ]
+
+    @staticmethod
+    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
+        if status == SequenceStatus.FINISHED_STOPPED:
+            finish_reason = "stop"
+        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
+            finish_reason = "length"
+        else:
+            finish_reason = None
+        return finish_reason


 class SequenceData:
@@ -20,7 +37,6 @@ class SequenceData:
         prompt_token_ids: List[int],
     ) -> None:
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0
@@ -166,7 +182,7 @@ class SequenceGroup:
             raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)
+        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
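
Note: splitting FINISHED into FINISHED_STOPPED and FINISHED_LENGTH_CAPPED is what lets the new frontend report an OpenAI-style finish_reason. A small sketch (not part of the diff) of the new helpers:

from cacheflow.sequence import SequenceStatus

# Finished states carry the reason, which maps onto the OpenAI finish_reason field.
assert SequenceStatus.is_finished(SequenceStatus.FINISHED_STOPPED)
assert not SequenceStatus.is_finished(SequenceStatus.RUNNING)
assert SequenceStatus.get_finished_reason(SequenceStatus.FINISHED_STOPPED) == "stop"
assert SequenceStatus.get_finished_reason(SequenceStatus.FINISHED_LENGTH_CAPPED) == "length"
assert SequenceStatus.get_finished_reason(SequenceStatus.RUNNING) is None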
cacheflow/entrypoints/fastapi_server.py → cacheflow/server/async_llm_server.py (renamed)

-import argparse
 import asyncio
-import json
 import time
-from typing import Any, Dict
-import uuid
+from typing import Dict, Optional

-from fastapi import FastAPI, Request
-from fastapi.responses import StreamingResponse
 import ray
-import uvicorn

+from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
-app = FastAPI()


-class FastAPIServer:
+class AsyncLLMServer:
     def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
         if server_use_ray:
@@ -45,15 +39,15 @@ class FastAPIServer:
         self.request_outputs[request_id] = request_output
         self.request_events[request_id].set()

-    async def generate(self, request_dict: Dict[str, Any]):
+    async def generate(self, prompt: str, sampling_params: SamplingParams,
+                       request_id: Optional[str] = None) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()
-        prompt = request_dict.pop("prompt")
-        sampling_params = SamplingParams(**request_dict)

         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        request_id = str(uuid.uuid4().hex[:8])
+        if request_id is None:
+            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event
@@ -82,19 +76,10 @@ class FastAPIServer:
             # Decode and return new outputs.
             request_output = self.request_outputs[request_id]
-            prompt = request_output.prompt
-            text_outputs = [
-                prompt + output.text for output in request_output.outputs
-            ]
-            ret = {
-                "text": text_outputs,
-                "error": 0,
-            }
-            yield (json.dumps(ret) + "\0").encode("utf-8")
+            yield request_output

             # Once finished, release the resources of the sequence group.
-            if request_output.done:
+            if request_output.finished():
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]

                 # Kick the server if the server is not running. This is to
@@ -104,25 +89,15 @@ class FastAPIServer:
                 await self.server_step()
                 break

-
-@app.post("/generate")
-async def generate_stream(request: Request):
-    request_dict = await request.json()
-    return StreamingResponse(server.generate(request_dict))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=10002)
-    parser = ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-
-    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
-    parallel_config = server_configs[2]
-    distributed_init_method, stage_devices = initialize_cluster(parallel_config)
-
-    server = FastAPIServer(
-        args.use_ray, *server_configs, distributed_init_method, stage_devices,
-        log_stats=not args.disable_log_stats)
-
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(server_args.use_ray, *server_configs,
+                     distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
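
Note: after this rename the class is a reusable async engine rather than a FastAPI app; generate() now yields RequestOutput objects instead of pre-serialized JSON chunks, and both new frontends build on it. A minimal consumer sketch (not part of the commit), reusing the same ServerArgs CLI pattern the frontends use; the prompt and sampling values are only illustrative.

import argparse
import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer

parser = argparse.ArgumentParser()
parser = ServerArgs.add_cli_args(parser)   # e.g. run with: --model facebook/opt-125m
args = parser.parse_args()
server = AsyncLLMServer.from_server_args(ServerArgs.from_cli_args(args))

async def main():
    params = SamplingParams(temperature=0.0, max_tokens=16)
    # generate() is an async generator of RequestOutput; finished() marks the last one.
    async for request_output in server.generate("Hello, my name is", params):
        if request_output.finished():
            print(request_output.outputs[0].text)

asyncio.run(main())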
cacheflow/server/llm_server.py

@@ -210,7 +210,8 @@ class LLMServer:
                     # Truncate the output text so that the stop string is
                     # not included in the output.
                     seq.output_text = seq.output_text[:-len(stop_str)]
-                    self.scheduler.free_seq(seq)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     stopped = True
                     break
             if stopped:
@@ -218,12 +219,14 @@ class LLMServer:
             # Check if the sequence has reached max_tokens.
             if seq.get_output_len() == sampling_params.max_tokens:
-                self.scheduler.free_seq(seq)
+                self.scheduler.free_seq(
+                    seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue

             # Check if the sequence has generated the EOS token.
             if not sampling_params.ignore_eos:
                 if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                    self.scheduler.free_seq(seq)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     continue

     def _run_workers(
@@ -238,10 +241,10 @@ class LLMServer:
             executor = getattr(worker, method)
             if self.parallel_config.use_ray:
                 executor = executor.remote

             output = executor(*args, **kwargs)
             all_outputs.append(output)

         if self.parallel_config.use_ray:
             all_outputs = ray.get(all_outputs)
cacheflow/utils.py

 import enum
+import uuid

 import psutil
 import torch
@@ -31,3 +32,7 @@ def get_gpu_memory(gpu: int = 0) -> int:
 def get_cpu_memory() -> int:
     """Returns the total CPU memory of the node in bytes."""
     return psutil.virtual_memory().total
+
+
+def random_uuid() -> str:
+    return str(uuid.uuid4().hex)
gradio_webserver.py → examples/gradio_webserver.py (renamed)

 import argparse
 import json
-import time

 import gradio as gr
 import requests
@@ -24,9 +23,9 @@ def http_bot(prompt):
 def build_demo():
     with gr.Blocks() as demo:
-        gr.Markdown("# Cacheflow demo\n")
-        inputbox = gr.Textbox(label="Input",
-                              placeholder="Enter text and press ENTER")  # .style(container=False)
+        gr.Markdown("# Cacheflow text completion demo\n")
+        inputbox = gr.Textbox(label="Input",
+                              placeholder="Enter text and press ENTER")
         outputbox = gr.Textbox(label="Output",
                                placeholder="Generated result from the model")
         inputbox.submit(http_bot, [inputbox], [outputbox])
     return demo
@@ -35,9 +34,11 @@ def build_demo():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=10003)
-    parser.add_argument("--model-url", type=str,
-                        default="http://localhost:10002/generate")
+    parser.add_argument("--port", type=int, default=8002)
+    parser.add_argument("--model-url", type=str,
+                        default="http://localhost:8001/generate")
     args = parser.parse_args()

     demo = build_demo()
-    demo.queue(concurrency_count=100).launch(server_name=args.host,
-                                             server_port=args.port)
\ No newline at end of file
+    demo.queue(concurrency_count=100).launch(server_name=args.host,
+                                             server_port=args.port,
+                                             share=True)
examples/openai_client.py (new file, mode 100644)

import openai

openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m"

# list models
models = openai.Model.list()
print(models)

# create a completion
stream = True
completion = openai.Completion.create(
    model=model, prompt="A robot may not injure a human being", echo=False,
    n=2, best_of=3, stream=stream, logprobs=3)

# print the completion
if stream:
    for c in completion:
        print(c)
else:
    print("completion:", completion)
examples/simple_fastapi_client.py (new file, mode 100644)

import argparse
import requests
import json


def clear_line(n=1):
    LINE_UP = '\033[1A'
    LINE_CLEAR = '\x1b[2K'
    for i in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def http_request(prompt: str, api_url: str, n: int = 1):
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    args = parser.parse_args()
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n

    print(f"Prompt: {prompt}\n", flush=True)
    num_printed_lines = 0
    for h in http_request(prompt, api_url, n):
        clear_line(num_printed_lines)
        num_printed_lines = 0
        for i, line in enumerate(h):
            num_printed_lines += 1
            print(f"Beam candidate {i}: {line}", flush=True)
examples/simple_server.py

 import argparse
-import uuid

 from cacheflow import ServerArgs, LLMServer, SamplingParams
@@ -20,17 +19,19 @@ def main(args: argparse.Namespace):
         SamplingParams(n=3, best_of=3, use_beam_search=True,
                        temperature=0.0)),
     ]

+    request_id = 0
     # Run the server.
     while True:
         # To test iteration-level scheduling, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
-            request_id = str(uuid.uuid4().hex[:8])
-            server.add_request(request_id, prompt, sampling_params)
+            server.add_request(str(request_id), prompt, sampling_params)
+            request_id += 1

         request_outputs = server.step()
         for request_output in request_outputs:
-            if request_output.done:
+            if request_output.finished():
                 print(request_output)

         if not (server.has_unfinished_requests() or test_prompts):
playground/http_client.py (deleted, was mode 100644)

import requests
import json


def http_bot():
    prompt = "How are you? I'm fine."

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
    }
    response = requests.post("http://localhost:10002",
                             headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


for h in http_bot():
    print(h, end="", flush=True)
playground/streaming_fastapi_worker.py (deleted, was mode 100644)

import argparse
import asyncio
import time
from typing import Union
import json

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

app = FastAPI()


async def text_streamer(args):
    context = args["prompt"]
    words = context.split(" ")
    for word in words:
        await asyncio.sleep(1)
        print("word:", word)
        ret = {
            "text": word + " ",
            "error": 0,
        }
        yield (json.dumps(ret) + "\0").encode("utf-8")


@app.post("/")
async def read_root(request: Request):
    args = await request.json()
    return StreamingResponse(text_streamer(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
requirements.txt

@@ -8,3 +8,4 @@ transformers >= 4.28.0  # Required for LLaMA.
 xformers >= 0.0.19
 fastapi
 uvicorn
+pydantic  # Required for OpenAI server.
test_cli_client.py (deleted, was mode 100644)

import requests
import json


def http_request():
    prompt = "Ion Stoica is a"

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": 4,
        "use_beam_search": True,
        "temperature": 0.0,
    }
    response = requests.post("http://localhost:10002/generate",
                             headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


for h in http_request():
    print(h, flush=True)