norm/vllm, commit e5464ee4 (unverified)
Authored Jun 17, 2023 by Zhuohan Li; committed by GitHub on Jun 17, 2023

Rename servers to engines (#152)
Parent: bab8f3dd
Showing 15 changed files with 165 additions and 174 deletions.
benchmarks/benchmark_latency.py              +1   -1
benchmarks/benchmark_serving.py              +1   -1
benchmarks/benchmark_throughput.py           +2   -2
cacheflow/__init__.py                        +5   -5
cacheflow/core/scheduler.py                  +1   -1
cacheflow/engine/__init__.py                 +0   -0
cacheflow/engine/arg_utils.py                +24  -24
cacheflow/engine/async_llm_engine.py         +53  -53
cacheflow/engine/llm_engine.py               +21  -21
cacheflow/engine/ray_utils.py                +6   -6
cacheflow/engine/tokenizer_utils.py          +0   -0
cacheflow/entrypoints/api_server.py          +8   -8
cacheflow/entrypoints/llm.py                 +15  -15
cacheflow/entrypoints/openai/api_server.py   +18  -27
examples/llm_engine_example.py               +10  -10
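A condensed view of the public renames in this commit (each pair is taken directly from the diffs below), followed by a minimal sketch of an updated import site; the model name is illustrative only:

    # Renamed symbols introduced by this commit (old -> new):
    #   cacheflow.server.*            -> cacheflow.engine.*
    #   ServerArgs / AsyncServerArgs  -> EngineArgs / AsyncEngineArgs
    #   LLMEngine.from_server_args()  -> LLMEngine.from_engine_args()
    #   LLM._run_server()             -> LLM._run_engine()
    #   --server-use-ray              -> --engine-use-ray
    from cacheflow import EngineArgs, LLMEngine, SamplingParams  # was: ServerArgs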
benchmarks/benchmark_latency.py
@@ -14,7 +14,7 @@ def main(args: argparse.Namespace):
     # Process all the requests in a single batch if possible.
     # NOTE(woosuk): If the request cannot be processed in a single batch,
-    # the server will automatically process the request in multiple batches.
+    # the engine will automatically process the request in multiple batches.
     llm = LLM(
         model=args.model,
         tensor_parallel_size=args.tensor_parallel_size,
benchmarks/benchmark_serving.py
@@ -2,7 +2,7 @@
 On the server side, run one of the following commands:
     (CacheFlow backend)
-    python -m cacheflow.entrypoints.simple_fastapi_frontend \
+    python -m cacheflow.entrypoints.api_server \
         --disable-log-requests --model <your_model>

     (TGI backend)
benchmarks/benchmark_throughput.py
@@ -84,7 +84,7 @@ def run_cacheflow(
         seed=seed,
     )

-    # Add the requests to the server.
+    # Add the requests to the engine.
     for prompt, _, output_len in requests:
         sampling_params = SamplingParams(
             n=n,
@@ -103,7 +103,7 @@ def run_cacheflow(
     start = time.time()
     # FIXME(woosuk): Do use internal method.
-    llm._run_server(use_tqdm=True)
+    llm._run_engine(use_tqdm=True)
     end = time.time()
     return end - start
cacheflow/__init__.py
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster
 from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput, CompletionOutput
+from cacheflow.outputs import CompletionOutput, RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import initialize_cluster

 __version__ = "0.1.0"
@@ -13,6 +13,6 @@ __all__ = [
     "RequestOutput",
     "CompletionOutput",
     "LLMEngine",
-    "ServerArgs",
+    "EngineArgs",
     "initialize_cluster",
 ]
cacheflow/core/scheduler.py
@@ -216,7 +216,7 @@ class Scheduler:
         if not self.log_stats:
             return scheduler_outputs, prompt_group_ids

-        # TODO(woosuk): Move the below code to server.
+        # TODO(woosuk): Move the below code to the engine.
         now = time.time()
         if num_batched_tokens > 0:
             self.num_input_tokens.append((now, num_batched_tokens))
cacheflow/server/__init__.py → cacheflow/engine/__init__.py
File moved
cacheflow/server/arg_utils.py → cacheflow/engine/arg_utils.py
@@ -8,8 +8,8 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
 @dataclass
-class ServerArgs:
-    """Arguments for CacheFlow servers."""
+class EngineArgs:
+    """Arguments for CacheFlow engine."""
     model: str
     download_dir: Optional[str] = None
     use_np_weights: bool = False
@@ -33,12 +33,12 @@ class ServerArgs:
     def add_cli_args(
             parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        """Shared CLI arguments for CacheFlow servers."""
+        """Shared CLI arguments for CacheFlow engine."""
         # Model arguments
         parser.add_argument('--model', type=str, default='facebook/opt-125m',
                             help='name or path of the huggingface model to use')
         parser.add_argument('--download-dir', type=str,
-                            default=ServerArgs.download_dir,
+                            default=EngineArgs.download_dir,
                             help='directory to download and load the weights, '
                                  'default to the default cache dir of '
                                  'huggingface')
@@ -49,7 +49,7 @@ class ServerArgs:
         parser.add_argument('--use-dummy-weights', action='store_true',
                             help='use dummy values for model weights')
         # TODO(woosuk): Support FP32.
-        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+        parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
                             choices=['auto', 'half', 'bfloat16', 'float'],
                             help='data type for model weights and activations. '
                                  'The "auto" option will use FP16 precision '
@@ -60,46 +60,46 @@ class ServerArgs:
                             help='use Ray for distributed serving, will be '
                                  'automatically set when using more than 1 GPU')
         parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-                            default=ServerArgs.pipeline_parallel_size,
+                            default=EngineArgs.pipeline_parallel_size,
                             help='number of pipeline stages')
         parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-                            default=ServerArgs.tensor_parallel_size,
+                            default=EngineArgs.tensor_parallel_size,
                             help='number of tensor parallel replicas')
         # KV cache arguments
         parser.add_argument('--block-size', type=int,
-                            default=ServerArgs.block_size,
+                            default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
         # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+        parser.add_argument('--seed', type=int, default=EngineArgs.seed,
                             help='random seed')
         parser.add_argument('--swap-space', type=int,
-                            default=ServerArgs.swap_space,
+                            default=EngineArgs.swap_space,
                             help='CPU swap space size (GiB) per GPU')
         parser.add_argument('--gpu-memory-utilization', type=float,
-                            default=ServerArgs.gpu_memory_utilization,
+                            default=EngineArgs.gpu_memory_utilization,
                             help='the percentage of GPU memory to be used for'
                                  'the model executor')
         parser.add_argument('--max-num-batched-tokens', type=int,
-                            default=ServerArgs.max_num_batched_tokens,
+                            default=EngineArgs.max_num_batched_tokens,
                             help='maximum number of batched tokens per '
                                  'iteration')
         parser.add_argument('--max-num-seqs', type=int,
-                            default=ServerArgs.max_num_seqs,
+                            default=EngineArgs.max_num_seqs,
                             help='maximum number of sequences per iteration')
         parser.add_argument('--disable-log-stats', action='store_true',
                             help='disable logging statistics')
         return parser

     @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
-        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
-        return server_args
+        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return engine_args

-    def create_server_configs(
+    def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
@@ -117,19 +117,19 @@ class ServerArgs:
 @dataclass
-class AsyncServerArgs(ServerArgs):
-    """Arguments for asynchronous CacheFlow servers."""
-    server_use_ray: bool = False
+class AsyncEngineArgs(EngineArgs):
+    """Arguments for asynchronous CacheFlow engine."""
+    engine_use_ray: bool = False
     disable_log_requests: bool = False

     @staticmethod
     def add_cli_args(
             parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        parser = ServerArgs.add_cli_args(parser)
-        parser.add_argument('--server-use-ray', action='store_true',
-                            help='use Ray to start the LLM server in a '
-                                 'separate process as the web server process.')
+        parser = EngineArgs.add_cli_args(parser)
+        parser.add_argument('--engine-use-ray', action='store_true',
+                            help='use Ray to start the LLM engine in a '
+                                 'separate process as the server process.')
         parser.add_argument('--disable-log-requests', action='store_true',
                             help='disable logging requests')
         return parser
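As a quick illustration of the renamed argument classes, a sketch based only on the methods visible in this diff (add_cli_args, from_cli_args, and create_engine_configs); the CLI values passed below are illustrative:

    import argparse

    from cacheflow.engine.arg_utils import EngineArgs

    parser = argparse.ArgumentParser()
    parser = EngineArgs.add_cli_args(parser)      # registers --model, --dtype, --seed, ...
    args = parser.parse_args(["--model", "facebook/opt-125m", "--dtype", "half"])

    engine_args = EngineArgs.from_cli_args(args)  # dataclass built from the parsed namespace
    model_cfg, cache_cfg, parallel_cfg, sched_cfg = engine_args.create_engine_configs()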
cacheflow/server/async_llm_server.py → cacheflow/engine/async_llm_engine.py
@@ -2,12 +2,12 @@ import asyncio
 import time
 from typing import Dict, List, Optional

+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster, ray
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import ray, initialize_cluster

 logger = init_logger(__name__)
@@ -29,44 +29,44 @@ class AsyncLLMEngine:
         worker_use_ray: Whether to use Ray for model workers. Required for
             distributed execution. Should be the same as
             `parallel_config.worker_use_ray`.
-        server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
+        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
         *args, *kwargs: Arguments for LLMEngine.
     """

-    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+    def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
                  log_requests: bool = True, *args, **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
-        self.server_use_ray = server_use_ray
+        self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        if not self.server_use_ray:
-            server_class = LLMEngine
+        if not self.engine_use_ray:
+            engine_class = LLMEngine
         elif self.worker_use_ray:
-            server_class = ray.remote(num_cpus=0)(LLMEngine).remote
+            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
         else:
-            server_class = ray.remote(num_gpus=1)(LLMEngine).remote
-        self.server = server_class(*args, **kwargs)
+            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
+        self.engine = engine_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
         # Request id -> event to notify that there is new output.
         self.request_events: Dict[str, asyncio.Event] = {}
-        self.is_server_running = False
+        self.is_engine_running = False
         self.kicking_request_id: Optional[str] = None

-    async def server_step(self, kicking_request_id: Optional[str] = None):
-        """Kick the server to process the waiting requests."""
-        self.is_server_running = True
+    async def engine_step(self, kicking_request_id: Optional[str] = None):
+        """Kick the engine to process the waiting requests."""
+        self.is_engine_running = True
         self.kicking_request_id = kicking_request_id
-        if self.server_use_ray:
-            request_outputs = await self.server.step.remote()
+        if self.engine_use_ray:
+            request_outputs = await self.engine.step.remote()
         else:
             # Yield to the event loop to allow other coroutines to run
-            # while is_server_running is True. This let the server to add new
+            # while is_engine_running is True. This let the engine to add new
             # requests into the queue.
             await asyncio.sleep(0)
-            request_outputs = self.server.step()
-        self.is_server_running = False
+            request_outputs = self.engine.step()
+        self.is_engine_running = False
         self.kicking_request_id = None

         # Notify the waiting coroutines that there are new outputs ready.
@@ -104,7 +104,7 @@ class AsyncLLMEngine:
         arrival_time = time.time()

         # Create an event to notify us that there is new output from the
-        # cacheflow server.
+        # cacheflow engine.
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event
@@ -114,31 +114,31 @@ class AsyncLLMEngine:
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")

-        # Add the request into the cacheflow server's waiting queue.
-        if self.server_use_ray:
-            await self.server.add_request.remote(
+        # Add the request into the cacheflow engine's waiting queue.
+        if self.engine_use_ray:
+            await self.engine.add_request.remote(
                 request_id, prompt, sampling_params,
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time)
         else:
-            self.server.add_request(
+            self.engine.add_request(
                 request_id, prompt, sampling_params,
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time)

-        # The cacheflow server does not have a background loop that keeps
+        # The cacheflow engine does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
-        # the server to process the requests.
+        # the engine to process the requests.
         while True:
             if request_id not in self.request_events:
                 # The request has been aborted.
                 return

-            # Kick the server if the server is not running.
-            if not self.is_server_running:
-                await self.server_step(request_id)
+            # Kick the engine if the engine is not running.
+            if not self.is_engine_running:
+                await self.engine_step(request_id)

-            # Wait for new output. The group_event will be set in server_step
+            # Wait for new output. The group_event will be set in engine_step
             # when there is new output available for the sequence group.
             # Added a timeout to prevent deadlock.
             try:
@@ -160,11 +160,11 @@ class AsyncLLMEngine:
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
-                # Kick the server if the server is not running. This is to
-                # prevent that there are still requests in server's waiting
+                # Kick the engine if the engine is not running. This is to
+                # prevent that there are still requests in engine's waiting
                 # queue to be executed.
-                if not self.is_server_running:
-                    await self.server_step()
+                if not self.is_engine_running:
+                    await self.engine_step()
                 break

     async def abort(self, request_id: str) -> None:
@@ -183,36 +183,36 @@ class AsyncLLMEngine:
         if self.log_requests:
             logger.info(f"Aborted request {request_id}.")

-        if self.server_use_ray:
-            await self.server.abort_request.remote(request_id)
+        if self.engine_use_ray:
+            await self.engine.abort_request.remote(request_id)
         else:
-            self.server.abort_request(request_id)
+            self.engine.abort_request(request_id)

         if request_id in self.request_events:
             del self.request_events[request_id]
         if request_id in self.request_outputs:
             del self.request_outputs[request_id]

-        # To prevent deadlock when a request is aborted while the server is
+        # To prevent deadlock when a request is aborted while the engine is
         # running.
         if self.kicking_request_id == request_id:
-            self.is_server_running = False
+            self.is_engine_running = False
             self.kicking_request_id = None

     @classmethod
-    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
-        """Creates an async LLM server from the server arguments."""
-        # Create the server configs.
-        server_configs = server_args.create_server_configs()
-        parallel_config = server_configs[2]
+    def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(
-            parallel_config, server_args.server_use_ray)
-        # Create the LLM server.
-        server = cls(server_args.worker_use_ray,
-                     server_args.server_use_ray,
-                     not server_args.disable_log_requests,
-                     *server_configs,
-                     distributed_init_method, devices,
-                     log_stats=not server_args.disable_log_stats)
-        return server
+            parallel_config, engine_args.engine_use_ray)
+        # Create the async LLM engine.
+        engine = cls(engine_args.worker_use_ray,
+                     engine_args.engine_use_ray,
+                     not engine_args.disable_log_requests,
+                     *engine_configs,
+                     distributed_init_method, devices,
+                     log_stats=not engine_args.disable_log_stats)
+        return engine
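A hedged sketch of how the renamed async entry point can be driven directly (from_engine_args and the generate async generator are taken from this diff; the model, prompt, and request id are illustrative):

    import asyncio

    from cacheflow.engine.arg_utils import AsyncEngineArgs
    from cacheflow.engine.async_llm_engine import AsyncLLMEngine
    from cacheflow.sampling_params import SamplingParams

    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def run() -> None:
        # generate() yields RequestOutput objects as decoding progresses;
        # "req-0" is an arbitrary request id chosen by the caller.
        async for request_output in engine.generate("Hello, my name is",
                                                    SamplingParams(), "req-0"):
            print(request_output)

    asyncio.run(run())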
cacheflow/server/llm_server.py → cacheflow/engine/llm_engine.py
@@ -4,13 +4,13 @@ from typing import Any, List, Optional
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray
+from cacheflow.engine.tokenizer_utils import (detokenize_incrementally,
+                                              get_tokenizer)
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
-from cacheflow.server.tokenizer_utils import (get_tokenizer,
-                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -19,9 +19,9 @@ logger = init_logger(__name__)
 class LLMEngine:
-    """An LLM server that receives requests and generates texts.
+    """An LLM engine that receives requests and generates texts.

-    This is the main class for the CacheFlow LLM server. It receives requests
+    This is the main class for the CacheFlow LLM engine. It receives requests
     from clients and generates texts from the LLM. It includes a tokenizer, a
     language model (possibly distributed across multiple GPUs), and GPU memory
     space allocated for intermediate states (aka KV cache). This class utilizes
@@ -31,8 +31,8 @@ class LLMEngine:
     The `LLM` class wraps this class for offline batched inference and the
     `AsyncLLMEngine` class wraps this class for online serving.

-    NOTE: The config arguments are derived from the `ServerArgs` class. For the
-    comprehensive list of arguments, see `ServerArgs`.
+    NOTE: The config arguments are derived from the `EngineArgs` class. For the
+    comprehensive list of arguments, see `EngineArgs`.

     Args:
         model_config: The configuration related to the LLM model.
@@ -58,7 +58,7 @@ class LLMEngine:
         log_stats: bool,
     ) -> None:
         logger.info(
-            "Initializing an LLM server with config: "
+            "Initializing an LLM engine with config: "
             f"model={model_config.model!r}, "
             f"dtype={model_config.dtype}, "
             f"use_dummy_weights={model_config.use_dummy_weights}, "
@@ -135,17 +135,17 @@ class LLMEngine:
         self._run_workers("init_cache_engine", cache_config=self.cache_config)

     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine":
-        """Creates an LLM server from the server arguments."""
-        # Create the server configs.
-        server_configs = server_args.create_server_configs()
-        parallel_config = server_configs[2]
+    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(parallel_config)
-        # Create the LLM server.
-        server = cls(*server_configs, distributed_init_method, devices,
-                     log_stats=not server_args.disable_log_stats)
-        return server
+        # Create the LLM engine.
+        engine = cls(*engine_configs, distributed_init_method, devices,
+                     log_stats=not engine_args.disable_log_stats)
+        return engine

     def add_request(
         self,
@@ -155,10 +155,10 @@ class LLMEngine:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
     ) -> None:
-        """Add a request to the server's request pool.
+        """Add a request to the engine's request pool.

         The request is added to the request pool and will be processed by the
-        scheduler as `server.step()` is called. The exact scheduling policy is
+        scheduler as `engine.step()` is called. The exact scheduling policy is
         determined by the scheduler.

         Args:
@@ -211,7 +211,7 @@ class LLMEngine:
     def step(self) -> List[RequestOutput]:
         """Performs one decoding iteration and returns newly generated results.

-        This function performs one decoding iteration for the server. It first
+        This function performs one decoding iteration of the engine. It first
         schedules the sequences to be executed in the next iteration and the
         token blocks to be swapped in/out/copy. Then, it executes the model
         and updates the scheduler with the model outputs. Finally, it decodes
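The step() contract described in the docstring above can be exercised directly; a minimal sketch mirroring examples/llm_engine_example.py further down (the model name and prompt are illustrative):

    from cacheflow import EngineArgs, LLMEngine, SamplingParams

    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
    engine.add_request("0", "The capital of France is",
                       SamplingParams(temperature=0.0))

    # Each step() call performs one decoding iteration and returns the requests
    # that produced new output; loop until nothing is left to schedule.
    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished():
                print(request_output)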
cacheflow/server/ray_utils.py → cacheflow/engine/ray_utils.py
@@ -13,15 +13,15 @@ DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
 def initialize_cluster(
     parallel_config: ParallelConfig,
-    server_use_ray: bool = False,
-    ray_server_address: Optional[str] = None,
+    engine_use_ray: bool = False,
+    ray_address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
     """Initialize the distributed cluster probably with Ray.

     Args:
         parallel_config: The configurations for parallel execution.
-        server_use_ray: Whether to use Ray for async server.
-        ray_server_address: The address of the Ray cluster. If None, uses
+        engine_use_ray: Whether to use Ray for async engine.
+        ray_address: The address of the Ray cluster. If None, uses
             the default Ray cluster address.

     Returns:
@@ -31,13 +31,13 @@ def initialize_cluster(
         each worker in each pipeline stage. Each device ID is a tuple of
         (rank, node resource, device id).
     """
-    if parallel_config.worker_use_ray or server_use_ray:
+    if parallel_config.worker_use_ray or engine_use_ray:
         if ray is None:
             raise ImportError(
                 "Ray is not installed. Please install Ray to use distributed "
                 "serving.")
         # Connect to a ray cluster.
-        ray.init(address=ray_server_address)
+        ray.init(address=ray_address)

     if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
cacheflow/server/tokenizer_utils.py → cacheflow/engine/tokenizer_utils.py
File moved
cacheflow/entrypoints/api_server.py
@@ -6,9 +6,9 @@ from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import Response, StreamingResponse
 import uvicorn

+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -30,7 +30,7 @@ async def generate(request: Request) -> Response:
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
-    results_generator = server.generate(prompt, sampling_params, request_id)
+    results_generator = engine.generate(prompt, sampling_params, request_id)

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -44,7 +44,7 @@ async def generate(request: Request) -> Response:
             yield (json.dumps(ret) + "\0").encode("utf-8")

     async def abort_request() -> None:
-        await server.abort(request_id)
+        await engine.abort(request_id)

     if stream:
         background_tasks = BackgroundTasks()
@@ -57,7 +57,7 @@ async def generate(request: Request) -> Response:
     async for request_output in results_generator:
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
-            await server.abort(request_id)
+            await engine.abort(request_id)
             return Response(status_code=499)
         final_output = request_output
@@ -75,11 +75,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8000)
-    parser = AsyncServerArgs.add_cli_args(parser)
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMEngine.from_server_args(server_args)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)

     uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
cacheflow/entrypoints/llm.py
 from typing import List, Optional, Union

-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from tqdm import tqdm
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter
@@ -21,7 +21,7 @@ class LLM:
     NOTE: This class is intended to be used for offline inference. For online
     serving, use the `AsyncLLMEngine` class instead.

-    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

     Args:
         model: The name or path of a HuggingFace Transformers model.
@@ -45,20 +45,20 @@ class LLM:
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        server_args = ServerArgs(
+        engine_args = EngineArgs(
             model=model,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
             **kwargs,
         )
-        self.llm_server = LLMEngine.from_server_args(server_args)
+        self.llm_engine = LLMEngine.from_engine_args(engine_args)
         self.request_counter = Counter()

     def get_tokenizer(
         self,
     ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_server.tokenizer
+        return self.llm_engine.tokenizer

     def generate(
         self,
@@ -99,7 +99,7 @@ class LLM:
             # Use default sampling params.
             sampling_params = SamplingParams()

-        # Add requests to the server.
+        # Add requests to the engine.
         if prompts is not None:
             num_requests = len(prompts)
         else:
@@ -111,7 +111,7 @@ class LLM:
             else:
                 token_ids = prompt_token_ids[i]
             self._add_request(prompt, sampling_params, token_ids)
-        return self._run_server(use_tqdm)
+        return self._run_engine(use_tqdm)

     def _add_request(
         self,
@@ -120,18 +120,18 @@ class LLM:
         prompt_token_ids: Optional[List[int]],
     ) -> None:
         request_id = str(next(self.request_counter))
-        self.llm_server.add_request(request_id, prompt, sampling_params,
-                                    prompt_token_ids)
+        self.llm_engine.add_request(request_id, prompt, sampling_params,
+                                    prompt_token_ids)

-    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.
         if use_tqdm:
-            num_requests = self.llm_server.get_num_unfinished_requests()
+            num_requests = self.llm_engine.get_num_unfinished_requests()
             pbar = tqdm(total=num_requests, desc="Processed prompts")
-        # Run the server.
+        # Run the engine.
         outputs: List[RequestOutput] = []
-        while self.llm_server.has_unfinished_requests():
-            step_outputs = self.llm_server.step()
+        while self.llm_engine.has_unfinished_requests():
+            step_outputs = self.llm_engine.step()
             for output in step_outputs:
                 if output.finished():
                     outputs.append(output)
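For completeness, a sketch of the offline entry point wrapped around the renamed llm_engine attribute; the prompts and sampling values are illustrative:

    from cacheflow import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")    # builds EngineArgs and an LLMEngine internally
    sampling_params = SamplingParams(temperature=0.8)
    outputs = llm.generate(["Hello, my name is", "The future of AI is"],
                           sampling_params)  # returns a list of RequestOutput
    for output in outputs:
        print(output)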
cacheflow/entrypoints/openai/api_server.py
@@ -10,29 +10,20 @@ import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import uvicorn

-from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
+from cacheflow.engine.tokenizer_utils import get_tokenizer
+from cacheflow.entrypoints.openai.protocol import (
+    CompletionRequest, CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from cacheflow.logger import init_logger
+from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import random_uuid
-from cacheflow.entrypoints.openai.protocol import (
-    CompletionRequest, CompletionResponse, CompletionResponseChoice,
-    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
-    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo,
-)

 TIMEOUT_KEEP_ALIVE = 5  # seconds
@@ -102,11 +93,11 @@ async def create_completion(raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.

     NOTE: Currently we do not support the following features:
-        - echo (since the cacheflow server does not currently support
+        - echo (since the cacheflow engine does not currently support
           getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
           suffix)
-        - logit_bias (to be supported in cacheflow server)
+        - logit_bias (to be supported in cacheflow engine)
     """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
@@ -116,7 +107,7 @@ async def create_completion(raw_request: Request):
         return error_check_ret

     if request.echo:
-        # We do not support echo since the cacheflow server does not
+        # We do not support echo since the cacheflow engine does not
         # currently support getting the logprobs of prompt tokens.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "echo is not currently supported")
@@ -127,7 +118,7 @@ async def create_completion(raw_request: Request):
                                      "suffix is not currently supported")

     if request.logit_bias is not None:
-        # TODO: support logit_bias in cacheflow server.
+        # TODO: support logit_bias in cacheflow engine.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "logit_bias is not currently supported")
@@ -153,7 +144,7 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = server.generate(prompt, sampling_params,
-                                       request_id)
+    result_generator = engine.generate(prompt, sampling_params,
+                                       request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -163,7 +154,7 @@ async def create_completion(raw_request: Request):
             not request.use_beam_search)

     async def abort_request() -> None:
-        await server.abort(request_id)
+        await engine.abort(request_id)

     def create_stream_response_json(index: int,
                                     text: str,
@@ -303,7 +294,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = AsyncServerArgs.add_cli_args(parser)
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(
@@ -318,8 +309,8 @@ if __name__ == "__main__":
     served_model = args.served_model_name or args.model

-    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMEngine.from_server_args(server_args)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)
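Since this entrypoint mimics the OpenAI Completion API, an existing OpenAI client can usually be pointed at it. A sketch only: the host, port, /v1 path, and key handling are assumptions not shown in this diff, and the model name must match the served model (args.model when --served-model-name is not given). The legacy pre-1.0 openai client interface is assumed here:

    import openai

    openai.api_key = "EMPTY"                      # placeholder; assumed not validated
    openai.api_base = "http://localhost:8000/v1"  # assumed host, port, and path

    completion = openai.Completion.create(
        model="facebook/opt-125m",                # must match the served model name
        prompt="Hello, my name is",
        max_tokens=16,
    )
    print(completion)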
examples/llm_server_example.py → examples/llm_engine_example.py
 import argparse

-from cacheflow import ServerArgs, LLMEngine, SamplingParams
+from cacheflow import EngineArgs, LLMEngine, SamplingParams


 def main(args: argparse.Namespace):
-    # Parse the CLI argument and initialize the server.
-    server_args = ServerArgs.from_cli_args(args)
-    server = LLMEngine.from_server_args(server_args)
+    # Parse the CLI argument and initialize the engine.
+    engine_args = EngineArgs.from_cli_args(args)
+    engine = LLMEngine.from_engine_args(engine_args)

     # Test the following prompts.
     test_prompts = [
@@ -19,27 +19,27 @@ def main(args: argparse.Namespace):
         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]

-    # Run the server by calling `server.step()` manually.
+    # Run the engine by calling `engine.step()` manually.
     request_id = 0
     while True:
         # To test iteration-level scheduling, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
-            server.add_request(str(request_id), prompt, sampling_params)
+            engine.add_request(str(request_id), prompt, sampling_params)
             request_id += 1

-        request_outputs = server.step()
+        request_outputs = engine.step()
         for request_output in request_outputs:
             if request_output.finished():
                 print(request_output)

-        if not (server.has_unfinished_requests() or test_prompts):
+        if not (engine.has_unfinished_requests() or test_prompts):
             break


 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Demo on using the LLMEngine class synchronously')
-    parser = ServerArgs.add_cli_args(parser)
+        description='Demo on using the LLMEngine class directly')
+    parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)