norm / vllm · Commits · 42983742

Add docstrings for LLMServer and related classes and examples (#142)

Unverified commit 42983742, authored Jun 07, 2023 by Zhuohan Li, committed by GitHub on Jun 07, 2023.
Parent: e38074b1

Showing 10 changed files with 216 additions and 22 deletions (+216 -22).
cacheflow/config.py                                  +39  -3
cacheflow/entrypoints/openai/openai_frontend.py      +12  -0
cacheflow/entrypoints/simple_fastapi_frontend.py      +6  -0
cacheflow/server/arg_utils.py                         +2  -0
cacheflow/server/async_llm_server.py                 +67  -7
cacheflow/server/llm_server.py                       +65  -4
cacheflow/server/ray_utils.py                        +17  -2
cacheflow/server/tokenizer_utils.py                   +1  -0
examples/openai_client.py                             +6  -5
examples/simple_server.py                             +1  -1
cacheflow/config.py

@@ -12,6 +12,20 @@ _GiB = 1 << 30
 class ModelConfig:
+    """Configuration for the model.
+
+    Args:
+        model: Name or path of the huggingface model to use.
+        download_dir: Directory to download and load the weights, default to the
+            default cache directory of huggingface.
+        use_np_weights: Save a numpy copy of model weights for faster loading.
+            This can increase the disk usage by up to 2x.
+        use_dummy_weights: Use dummy values for model weights (for profiling).
+        dtype: Data type for model weights and activations. The "auto" option
+            will use FP16 precision for FP32 and FP16 models, and BF16 precision
+            for BF16 models.
+        seed: Random seed for reproducibility.
+    """
+
     def __init__(
         self,

@@ -68,7 +82,14 @@ class ModelConfig:
 class CacheConfig:
+    """Configuration for the KV cache.
+
+    Args:
+        block_size: Size of a cache block in number of tokens.
+        gpu_memory_utilization: Fraction of GPU memory to use for the
+            CacheFlow execution.
+        swap_space: Size of the CPU swap space per GPU (in GiB).
+    """
+
     def __init__(
         self,
         block_size: int,

@@ -111,7 +132,15 @@ class CacheConfig:
 class ParallelConfig:
+    """Configuration for the distributed execution.
+
+    Args:
+        pipeline_parallel_size: Number of pipeline parallel groups.
+        tensor_parallel_size: Number of tensor parallel groups.
+        worker_use_ray: Whether to use Ray for model workers. Will be set to
+            True if either pipeline_parallel_size or tensor_parallel_size is
+            greater than 1.
+    """
+
     def __init__(
         self,
         pipeline_parallel_size: int,

@@ -134,7 +163,14 @@ class ParallelConfig:
 class SchedulerConfig:
+    """Scheduler configuration.
+
+    Args:
+        max_num_batched_tokens: Maximum number of tokens to be processed in
+            a single iteration.
+        max_num_seqs: Maximum number of sequences to be processed in a single
+            iteration.
+    """
+
     def __init__(
         self,
         max_num_batched_tokens: int,
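The "auto" dtype rule documented in ModelConfig (FP16 for FP32 and FP16 checkpoints, BF16 for BF16 checkpoints) can be made concrete with a small helper. This is an illustrative sketch only; the helper name resolve_auto_dtype is hypothetical and not part of this commit.

# Illustrative sketch (not from this commit): the documented dtype="auto" rule.
import torch

def resolve_auto_dtype(checkpoint_dtype: torch.dtype) -> torch.dtype:
    # Hypothetical helper covering the three documented cases:
    # FP32 and FP16 checkpoints run in FP16, BF16 checkpoints keep BF16.
    if checkpoint_dtype == torch.bfloat16:
        return torch.bfloat16
    return torch.float16

print(resolve_auto_dtype(torch.float32))   # torch.float16
print(resolve_auto_dtype(torch.bfloat16))  # torch.bfloat16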
cacheflow/entrypoints/openai/openai_frontend.py

@@ -96,6 +96,18 @@ def create_logprobs(token_ids: List[int],
 @app.post("/v1/completions")
 async def create_completion(raw_request: Request):
+    """Completion API similar to OpenAI's API.
+
+    See https://platform.openai.com/docs/api-reference/completions/create
+    for the API specification. This API mimics the OpenAI Completion API.
+
+    NOTE: Currently we do not support the following features:
+        - echo (since the cacheflow server does not currently support
+          getting the logprobs of prompt tokens)
+        - suffix (the language models we currently support do not support
+          suffix)
+        - logit_bias (to be supported in cacheflow server)
+    """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
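Because the endpoint mirrors the OpenAI Completions API, it can also be exercised with a plain HTTP client. The sketch below is hedged: it assumes the frontend runs locally on port 8000 (as in examples/openai_client.py further down), and the exact set of accepted fields is defined by CompletionRequest, which this diff does not show.

# Hedged example: POST an OpenAI-style completion request to the local frontend.
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "A robot may not injure a human being",
        "max_tokens": 32,        # standard OpenAI field; assumed to be accepted
        "temperature": 0.0,
    },
)
print(response.json())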
cacheflow/entrypoints/simple_fastapi_frontend.py

@@ -18,6 +18,12 @@ app = FastAPI()
 @app.post("/generate")
 async def generate_stream(request: Request) -> StreamingResponse:
+    """Stream the results of the generation request.
+
+    The request should be a JSON object with the following fields:
+    - prompt: the prompt to use for the generation.
+    - other fields: the sampling parameters (See `SamplingParams` for details).
+    """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
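A hedged client sketch for the /generate endpoint: it assumes the frontend runs on localhost:8000 and, per the docstring, forwards every non-prompt JSON field to SamplingParams. How the streamed body is framed is not shown in this diff, so the chunk handling below is a guess.

# Hedged example: stream results from the simple FastAPI frontend.
# n and temperature are SamplingParams fields used elsewhere in this repo's examples.
import requests

payload = {"prompt": "The capital of France is", "n": 1, "temperature": 0.0}
with requests.post("http://localhost:8000/generate", json=payload, stream=True) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)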
cacheflow/server/arg_utils.py

@@ -9,6 +9,7 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
 @dataclass
 class ServerArgs:
+    """Arguments for CacheFlow servers."""
     model: str
     download_dir: Optional[str] = None
     use_np_weights: bool = False

@@ -117,6 +118,7 @@ class ServerArgs:
 @dataclass
 class AsyncServerArgs(ServerArgs):
+    """Arguments for asynchronous CacheFlow servers."""
     server_use_ray: bool = False

     @staticmethod
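For orientation, a hedged sketch of how these dataclasses are used by the server code in this commit: ServerArgs needs only model, and from_server_args calls create_server_configs(), whose element at index 2 is the ParallelConfig (the other tuple positions are not shown in this diff).

# Hedged sketch based on the usage visible in this commit.
from cacheflow.server.arg_utils import ServerArgs

server_args = ServerArgs(model="facebook/opt-125m")   # only `model` is required
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]                   # index 2 per from_server_args
print(parallel_config)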
cacheflow/server/async_llm_server.py

 import asyncio
 import time
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput

@@ -15,7 +15,25 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
 class AsyncLLMServer:
+    """An asynchronous wrapper for LLMServer.
+
+    This class is used to wrap the LLMServer class to make it asynchronous. It
+    uses asyncio to create a background loop that keeps processing incoming
+    requests. The LLMServer is kicked by the generate method when there
+    are requests in the waiting queue. The generate method yields the outputs
+    from the LLMServer to the caller.
+
+    NOTE: For the comprehensive list of arguments, see `LLMServer`.
+
+    Args:
+        worker_use_ray: Whether to use Ray for model workers. Required for
+            distributed execution. Should be the same as
+            `parallel_config.worker_use_ray`.
+        server_use_ray: Whether to make LLMServer a Ray actor. If so, the
+            async frontend will be executed in a separate process as the
+            model workers.
+        *args, **kwargs: Arguments for LLMServer.
+    """
+
     def __init__(self, worker_use_ray: bool, server_use_ray: bool,
                  *args, **kwargs) -> None:
         self.worker_use_ray = worker_use_ray

@@ -35,6 +53,7 @@ class AsyncLLMServer:
         self.kicking_request_id: Optional[str] = None

     async def server_step(self, kicking_request_id: Optional[str] = None):
+        """Kick the server to process the waiting requests."""
         self.is_server_running = True
         self.kicking_request_id = kicking_request_id
         if self.server_use_ray:

@@ -54,8 +73,31 @@ class AsyncLLMServer:
         self.request_outputs[request_id] = request_output
         self.request_events[request_id].set()

-    async def generate(self, prompt: str, sampling_params: SamplingParams,
-                       request_id: str) -> RequestOutput:
+    async def generate(
+            self,
+            prompt: Optional[str],
+            sampling_params: SamplingParams,
+            request_id: str,
+            prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
+        """Generate outputs for a request.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMServer and streams the outputs
+        from the LLMServer to the caller.
+
+        Args:
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            sampling_params: The sampling parameters of the request.
+            request_id: The unique id of the request.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+
+        Yields:
+            The output `RequestOutput` objects from the LLMServer for the
+            request.
+        """
         # Preprocess the request.
         arrival_time = time.time()

@@ -66,20 +108,29 @@ class AsyncLLMServer:
         logger.info(f"Received request {request_id}: "
                     f"prompt: {prompt!r}, "
-                    f"sampling params: {sampling_params}.")
+                    f"sampling params: {sampling_params}, "
+                    f"prompt token ids: {prompt_token_ids}.")

         # Add the request into the cacheflow server's waiting queue.
         if self.server_use_ray:
             await self.server.add_request.remote(
-                request_id, prompt, sampling_params, arrival_time=arrival_time)
+                request_id, prompt, sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time)
         else:
             self.server.add_request(
-                request_id, prompt, sampling_params, arrival_time=arrival_time)
+                request_id, prompt, sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time)

         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
         # the server to process the requests.
         while True:
+            if request_id not in self.request_events:
+                # The request has been aborted.
+                return
+
             # Kick the server if the server is not running.
             if not self.is_server_running:
                 await self.server_step(request_id)

@@ -113,6 +164,14 @@ class AsyncLLMServer:
                 break

     async def abort(self, request_id: str) -> None:
+        """Abort a request.
+
+        Abort a submitted request. If the request is finished or not found,
+        this method will be a no-op.
+
+        Args:
+            request_id: The unique id of the request.
+        """
         if request_id not in self.request_events:
             # The request has already finished or been aborted.
             return

@@ -137,6 +196,7 @@ class AsyncLLMServer:
     @classmethod
     def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
+        """Creates an async LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
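Putting the documented pieces together, a hedged end-to-end sketch of the async path; the model name, prompt, and sampling values are illustrative, and the construction follows from_server_args and the generate docstring above.

# Hedged sketch: drive AsyncLLMServer directly, per the docstrings in this commit.
import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer

async def main():
    server = AsyncLLMServer.from_server_args(
        AsyncServerArgs(model="facebook/opt-125m"))
    sampling_params = SamplingParams(temperature=0.0)
    # generate() yields RequestOutput objects for the request as they arrive.
    async for request_output in server.generate(
            "Hello, my name is", sampling_params, request_id="0"):
        print(request_output)

asyncio.run(main())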
cacheflow/server/llm_server.py

@@ -8,7 +8,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import ray, initialize_cluster
+from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus

@@ -19,6 +19,33 @@ logger = init_logger(__name__)
 class LLMServer:
+    """An LLM server that receives requests and generates texts.
+
+    This is the main class for the CacheFlow LLM server. It receives requests
+    from clients and generates texts from the LLM. It includes a tokenizer, a
+    language model (possibly distributed across multiple GPUs), and GPU memory
+    space allocated for intermediate states (aka KV cache). This class utilizes
+    iteration-level scheduling and efficient memory management to maximize the
+    serving throughput.
+
+    The `LLM` class wraps this class for offline batched inference and the
+    `AsyncLLMServer` class wraps this class for online serving.
+
+    NOTE: The config arguments are derived from the `ServerArgs` class. For the
+    comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model_config: The configuration related to the LLM model.
+        cache_config: The configuration related to the KV cache memory
+            management.
+        parallel_config: The configuration related to distributed execution.
+        scheduler_config: The configuration related to the request scheduler.
+        distributed_init_method: The initialization method for distributed
+            execution. See `torch.distributed.init_process_group` for details.
+        stage_devices: The list of devices for each stage. Each stage is a list
+            of (rank, node_resource, device) tuples.
+        log_stats: Whether to log statistics.
+    """
+
     def __init__(
         self,

@@ -27,7 +54,7 @@ class LLMServer:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[Any]],
+        stage_devices: List[List[DeviceID]],
         log_stats: bool,
     ) -> None:
         logger.info(

@@ -83,6 +110,7 @@ class LLMServer:
         self.cache_config.verify_with_parallel_config(self.parallel_config)

     def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache."""
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",

@@ -108,6 +136,7 @@ class LLMServer:
     @classmethod
     def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        """Creates an LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]

@@ -126,6 +155,22 @@ class LLMServer:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
     ) -> None:
+        """Add a request to the server's request pool.
+
+        The request is added to the request pool and will be processed by the
+        scheduler as `server.step()` is called. The exact scheduling policy is
+        determined by the scheduler.
+
+        Args:
+            request_id: The unique ID of the request.
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            sampling_params: The sampling parameters for text generation.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            arrival_time: The arrival time of the request. If None, we use
+                the current time.
+        """
         if arrival_time is None:
             arrival_time = time.time()
         if prompt_token_ids is None:

@@ -148,15 +193,30 @@ class LLMServer:
         self.scheduler.add_seq_group(seq_group)

     def abort_request(self, request_id: str) -> None:
+        """Aborts a request with the given ID.
+
+        Args:
+            request_id: The ID of the request to abort.
+        """
         self.scheduler.abort_seq_group(request_id)

     def get_num_unfinished_requests(self) -> int:
+        """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()

     def has_unfinished_requests(self) -> bool:
+        """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()

     def step(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        This function performs one decoding iteration for the server. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copy. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
         if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
             # Nothing to do.

@@ -188,7 +248,7 @@ class LLMServer:
         return request_outputs

     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Decode the sequence outputs.
+        """Decodes the sequence outputs."""
         for seq_group in seq_groups:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 new_token, new_output_text = detokenize_incrementally(

@@ -201,7 +261,7 @@ class LLMServer:
                 seq.output_text = new_output_text

     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Stop the sequences.
+        """Stop the finished sequences."""
         for seq_group in seq_groups:
             sampling_params = seq_group.sampling_params
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):

@@ -238,6 +298,7 @@ class LLMServer:
         *args,
         **kwargs,
     ) -> Any:
+        """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
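A hedged sketch of the synchronous flow these docstrings describe, in the spirit of examples/simple_server.py below; the model name, prompt, and sampling values are illustrative.

# Hedged sketch: offline use of LLMServer via add_request()/step(),
# following the docstrings added in this commit.
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer

server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))
server.add_request("0", "Hello, my name is", SamplingParams(temperature=0.0))

# step() performs one decoding iteration and returns newly generated results.
while server.has_unfinished_requests():
    for request_output in server.step():
        print(request_output)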
cacheflow/server/ray_utils.py

@@ -14,15 +14,30 @@ DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), devi
 def initialize_cluster(
     parallel_config: ParallelConfig,
     server_use_ray: bool = False,
-    address: Optional[str] = None,
+    ray_server_address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
+    """Initialize the distributed cluster probably with Ray.
+
+    Args:
+        parallel_config: The configurations for parallel execution.
+        server_use_ray: Whether to use Ray for async server.
+        ray_server_address: The address of the Ray cluster. If None, uses
+            the default Ray cluster address.
+
+    Returns:
+        A tuple of (`distributed_init_method`, `all_stage_devices`). The
+        `distributed_init_method` is the address for initializing the
+        distributed backend. `all_stage_devices` includes device IDs for
+        each worker in each pipeline stage. Each device ID is a tuple of
+        (rank, node resource, device id).
+    """
     if parallel_config.worker_use_ray or server_use_ray:
         if ray is None:
             raise ImportError(
                 "Ray is not installed. Please install Ray to use distributed "
                 "serving.")
         # Connect to a ray cluster.
-        ray.init(address=address)
+        ray.init(address=ray_server_address)

     if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
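A hedged sketch of how the return value is consumed, mirroring LLMServer.from_server_args in this commit; the ParallelConfig is taken from create_server_configs() rather than built by hand, since its full constructor is not shown here.

# Hedged sketch: initialize the cluster and inspect the documented return values.
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import initialize_cluster

parallel_config = ServerArgs(model="facebook/opt-125m").create_server_configs()[2]
distributed_init_method, all_stage_devices = initialize_cluster(parallel_config)
print(distributed_init_method)   # address for initializing the distributed backend
print(all_stage_devices[0])      # (rank, node resource, device id) tuples for stage 0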
cacheflow/server/tokenizer_utils.py

@@ -15,6 +15,7 @@ def get_tokenizer(
     *args,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    """Gets a tokenizer for the given model name via Huggingface."""
     config = AutoConfig.from_pretrained(model_name)
     if config.model_type == "llama" and getattr(kwargs, "use_fast", True):
         # LLaMA fast tokenizer causes protobuf errors in some environments.
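A hedged usage sketch: the returned object is a regular Hugging Face tokenizer, so the usual encode/decode methods apply; the model name is illustrative.

# Hedged example: fetch a tokenizer through get_tokenizer and round-trip a prompt.
from cacheflow.server.tokenizer_utils import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")
token_ids = tokenizer.encode("Hello, my name is")
print(token_ids)
print(tokenizer.decode(token_ids))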
examples/openai_client.py

 import openai

 # Modify OpenAI's API key and API base to use CacheFlow's API server.
 openai.api_key = "EMPTY"
 openai.api_base = "http://localhost:8000/v1"
 model = "facebook/opt-125m"

-# list models
+# Test list models API
 models = openai.Model.list()
-print(models)
+print("Models:", models)

-# create a completion
+# Test completion API
 stream = True
 completion = openai.Completion.create(
     model=model, prompt="A robot may not injure a human being", echo=False, n=2,

@@ -19,4 +20,4 @@ if stream:
     for c in completion:
         print(c)
 else:
-    print("completion:", completion)
+    print("Completion result:", completion)
examples/simple_server.py

@@ -19,7 +19,7 @@ def main(args: argparse.Namespace):
          SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]

-    # Run the server.
+    # Run the server by calling `server.step()` manually.
     request_id = 0
     while True:
         # To test iteration-level scheduling, we add one request at each step.