xdb4_94051 / vllm · Commit 655a5e48 (unverified)

Introduce LLM class for offline inference (#115)

Authored May 21, 2023 by Woosuk Kwon; committed by GitHub on May 21, 2023.
Parent: f746ced0

Showing 9 changed files with 221 additions and 80 deletions (+221 -80).
cacheflow/__init__.py                     +5   -9
cacheflow/config.py                       +5   -1
cacheflow/entrypoints/fastapi_server.py   +3   -4
cacheflow/entrypoints/llm.py              +62  -0
cacheflow/outputs.py                      +5   -5
cacheflow/server/arg_utils.py             +97  -54
cacheflow/server/llm_server.py            +16  -2
examples/offline_inference.py             +23  -0
examples/simple_server.py                 +5   -5
cacheflow/__init__.py

+from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments,
-    create_server_configs_from_args,
-    initialize_server_from_args,
-)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster

 __all__ = [
-    "RequestOutput",
+    "LLM",
     "SamplingParams",
+    "RequestOutput",
     "LLMServer",
-    "add_server_arguments",
-    "create_server_configs_from_args",
-    "initialize_server_from_args",
+    "ServerArgs",
     "initialize_cluster",
 ]
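For downstream code, the practical effect of this file is the new top-level import surface. A minimal sketch (not part of the commit) of what callers import after this change, assuming the package is importable as cacheflow:

# Offline entry point and request-level types now come from the package root.
from cacheflow import LLM, RequestOutput, SamplingParams

# Server-oriented code imports the dataclass-based args instead of the old helpers.
from cacheflow import LLMServer, ServerArgs

# No longer exported from the package root:
# from cacheflow import add_server_arguments, create_server_configs_from_args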
cacheflow/config.py

@@ -3,6 +3,8 @@ from typing import Optional
 import torch
 from transformers import AutoConfig, PretrainedConfig

+_GiB = 1 << 30
+

 class ModelConfig:
@@ -70,7 +72,7 @@ class CacheConfig:
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space = swap_space
+        self.swap_space_bytes = swap_space * _GiB

         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -138,6 +140,8 @@ def _get_and_verify_dtype(
         else:
             torch_dtype = config_dtype
     else:
+        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {dtype}")
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

     # Verify the dtype.
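Since CacheConfig now owns the GiB conversion, callers pass swap_space in GiB and read swap_space_bytes back. A rough sketch of that contract, under the assumption that CacheConfig's constructor takes (block_size, gpu_memory_utilization, swap_space), as the calls in arg_utils.py suggest:

from cacheflow.config import CacheConfig

# swap_space is given in GiB; the config converts it once, at construction time.
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.95,
                           swap_space=4)
assert cache_config.swap_space_bytes == 4 * (1 << 30)

# num_gpu_blocks stays None until profiling fills it in.
assert cache_config.num_gpu_blocks is None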
cacheflow/entrypoints/fastapi_server.py

@@ -12,8 +12,7 @@ import uvicorn
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments, create_server_configs_from_args)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster
@@ -116,10 +115,10 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=10002)
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_configs = create_server_configs_from_args(args)
+    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)
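The pattern used here, keeping the entrypoint-specific flags on the local parser and letting ServerArgs append the shared ones, generalizes to any entrypoint. A hedged sketch built only from the APIs in this commit (flag values are illustrative):

import argparse

from cacheflow.server.arg_utils import ServerArgs

parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=10002)
parser = ServerArgs.add_cli_args(parser)  # adds --model, --dtype, --block-size, ...
args = parser.parse_args(["--model", "facebook/opt-125m", "--dtype", "half"])

# One object carries the parsed values; the tuple order is
# (model, cache, parallel, scheduler) per create_server_configs().
model_config, cache_config, parallel_config, scheduler_config = (
    ServerArgs.from_cli_args(args).create_server_configs())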
cacheflow/entrypoints/llm.py (new file, mode 100644)

from typing import List, Optional

from tqdm import tqdm

from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.utils import Counter


class LLM:

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        dtype: str = "default",
        seed: int = 0,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        server_args = ServerArgs(
            model=model,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            seed=seed,
            **kwargs,
        )
        self.llm_server = LLMServer.from_server_args(server_args)
        self.request_counter = Counter()

    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[SamplingParams] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        # Initialize tqdm.
        if use_tqdm:
            pbar = tqdm(total=len(prompts), desc="Processed prompts")
        # Add requests to the server.
        for prompt in prompts:
            request_id = str(next(self.request_counter))
            self.llm_server.add_request(request_id, prompt, sampling_params)
        # Run the server.
        outputs: List[RequestOutput] = []
        while self.llm_server.has_unfinished_requests():
            step_outputs = self.llm_server.step()
            for output in step_outputs:
                if output.done:
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        return outputs
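Because __init__ forwards **kwargs into ServerArgs, any field of that dataclass can be set directly on the LLM constructor. A small usage sketch built only from the APIs added in this commit (model name and sampling values are illustrative):

from cacheflow import LLM, SamplingParams

# disable_log_stats defaults to True here unless explicitly overridden;
# block_size and use_dummy_weights are ServerArgs fields passed through **kwargs.
llm = LLM(model="facebook/opt-125m", block_size=16, use_dummy_weights=True)

outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    SamplingParams(temperature=0.8, top_p=0.95),
)
for output in outputs:
    print(output.prompt, output.outputs[0].text)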
cacheflow/outputs.py

@@ -35,7 +35,7 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool = False,
+        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -43,8 +43,8 @@ class RequestOutput:
         self.outputs = outputs
         self.done = done

-    @staticmethod
-    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+    @classmethod
+    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
@@ -70,8 +70,8 @@ class RequestOutput:
         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
-                             outputs, seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids,
+                   outputs, seq_group.is_finished())

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
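The @staticmethod to @classmethod change is more than style: a classmethod factory constructs cls(...), so any future subclass of RequestOutput would get instances of itself from from_seq_group without overriding it. A standalone illustration of the pattern (plain Python, not cacheflow code):

class Output:
    def __init__(self, text: str) -> None:
        self.text = text

    @classmethod
    def from_words(cls, words: list) -> "Output":
        # cls is whatever class the factory was called on, not a hard-coded name.
        return cls(" ".join(words))


class VerboseOutput(Output):
    pass


assert type(Output.from_words(["a", "b"])) is Output
assert type(VerboseOutput.from_words(["a", "b"])) is VerboseOutput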
cacheflow/server/arg_utils.py

 import argparse
-from typing import Tuple
+import dataclasses
+from dataclasses import dataclass
+from typing import Optional, Tuple

 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
-from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-
-_GiB = 1 << 30


-def add_server_arguments(parser: argparse.ArgumentParser):
-    """Shared arguments for CacheFlow servers."""
+@dataclass
+class ServerArgs:
+    model: str
+    download_dir: Optional[str] = None
+    use_np_weights: bool = False
+    use_dummy_weights: bool = False
+    dtype: str = "default"
+    seed: int = 0
+    use_ray: bool = False
+    pipeline_parallel_size: int = 1
+    tensor_parallel_size: int = 1
+    block_size: int = 16
+    swap_space: int = 4  # GiB
+    gpu_memory_utilization: float = 0.95
+    max_num_batched_tokens: int = 2560
+    max_num_seqs: int = 256
+    disable_log_stats: bool = False
+
+    def __post_init__(self):
+        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        return _add_server_arguments(parser)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return server_args
+
+    def create_server_configs(
+        self,
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+        # Initialize the configs.
+        model_config = ModelConfig(self.model, self.download_dir,
+                                   self.use_np_weights, self.use_dummy_weights,
+                                   self.dtype, self.seed)
+        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
+                                   self.swap_space)
+        parallel_config = ParallelConfig(self.pipeline_parallel_size,
+                                         self.tensor_parallel_size,
+                                         self.use_ray)
+        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
+                                           self.max_num_seqs)
+        return model_config, cache_config, parallel_config, scheduler_config
+
+
+def _add_server_arguments(
+    parser: argparse.ArgumentParser,
+) -> argparse.ArgumentParser:
+    """Shared CLI arguments for CacheFlow servers."""
     # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
-    parser.add_argument('--download-dir', type=str, default=None,
+    parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                        help='name or path of the huggingface model to use')
+    parser.add_argument('--download-dir', type=str,
+                        default=ServerArgs.download_dir,
                         help='directory to download and load the weights, '
                              'default to the default cache dir of huggingface')
     parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster loading')
+                        help='save a numpy copy of model weights for faster '
+                             'loading. This can increase the disk usage by up '
+                             'to 2x.')
     parser.add_argument('--use-dummy-weights', action='store_true',
                         help='use dummy values for model weights')
     # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default='default',
+    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
                         choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
                               'for FP32 and FP16 models, and BF16 precision '
                               'for BF16 models.'))
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true',
-                        help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1,
+                        help='use Ray for distributed serving, will be '
+                             'automatically set when using more than 1 GPU')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                        default=ServerArgs.pipeline_parallel_size,
                         help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1,
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                        default=ServerArgs.tensor_parallel_size,
                         help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=16,
+    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
                         choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                         help='token block size')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=0,
+    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
                         help='random seed')
-    parser.add_argument('--swap-space', type=int, default=4,
+    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
                         help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95,
-                        help='the percentage of GPU memory to be used for the model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=2560,
+    parser.add_argument('--gpu-memory-utilization', type=float,
+                        default=ServerArgs.gpu_memory_utilization,
+                        help='the percentage of GPU memory to be used for the '
+                             'model executor')
+    parser.add_argument('--max-num-batched-tokens', type=int,
+                        default=ServerArgs.max_num_batched_tokens,
                         help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int, default=256,
+    parser.add_argument('--max-num-seqs', type=int,
+                        default=ServerArgs.max_num_seqs,
                         help='maximum number of sequences per iteration')
     parser.add_argument('--disable-log-stats', action='store_true',
                         help='disable logging statistics')
     return parser
-
-
-def create_server_configs_from_args(
-    args: argparse.Namespace,
-) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-    # Post-process the parsed arguments.
-    args.swap_space = args.swap_space * _GiB
-    args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
-
-    # Initialize the configs.
-    model_config = ModelConfig(args.model, args.download_dir, args.use_np_weights,
-                               args.use_dummy_weights, args.dtype, args.seed)
-    cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
-                               args.swap_space)
-    parallel_config = ParallelConfig(args.pipeline_parallel_size,
-                                     args.tensor_parallel_size, args.use_ray)
-    scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
-                                       args.max_num_seqs)
-    return model_config, cache_config, parallel_config, scheduler_config
-
-
-def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
-    server_configs = create_server_configs_from_args(args)
-    parallel_config = server_configs[2]
-
-    # Initialize the cluster.
-    distributed_init_method, devices = initialize_cluster(parallel_config)
-
-    # Create the LLM server.
-    server = LLMServer(*server_configs, distributed_init_method, devices,
-                       log_stats=not args.disable_log_stats)
-    return server
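The dataclass keeps defaults, CLI registration, and config construction in one place, and from_cli_args simply mirrors dataclasses.fields back out of the parsed namespace. A hedged end-to-end sketch of that flow (flag values are illustrative):

import argparse

from cacheflow.server.arg_utils import ServerArgs

# Defaults are now readable as class attributes.
assert ServerArgs.block_size == 16 and ServerArgs.swap_space == 4

parser = ServerArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m", "--max-num-seqs", "8"])

server_args = ServerArgs.from_cli_args(args)
# __post_init__ has already capped max_num_seqs at max_num_batched_tokens.
model_config, cache_config, parallel_config, scheduler_config = (
    server_args.create_server_configs())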
cacheflow/server/llm_server.py

@@ -12,6 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.ray_utils import initialize_cluster
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
@@ -30,7 +32,7 @@ class LLMServer:
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
         stage_devices: List[List[Any]],
-        log_stats: bool = True,
+        log_stats: bool,
     ) -> None:
         logger.info(
             "Initializing an LLM server with config: "
@@ -90,7 +92,7 @@ class LLMServer:
             get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
         )
         # Since we use a shared centralized controller, we take the minimum
@@ -107,6 +109,18 @@ class LLMServer:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)

+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(*server_configs, distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
+
     def add_request(
         self,
         request_id: str,
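from_server_args gives programmatic callers the same construction path the CLI uses, and the LLM class introduced above is essentially a thin loop over it. A minimal sketch of driving LLMServer directly, using only methods referenced in this commit (model name and prompt are illustrative):

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer

server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))
server.add_request("0", "Hello, my name is", SamplingParams())

# step() advances all in-flight requests by one iteration and returns
# RequestOutput objects; done marks requests that have finished.
finished = []
while server.has_unfinished_requests():
    for output in server.step():
        if output.done:
            finished.append(output)
print(finished)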
examples/offline_inference.py (new file, mode 100644)

from cacheflow import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
examples/simple_server.py

 import argparse
 import uuid

-from cacheflow import (add_server_arguments, initialize_server_from_args,
-                       SamplingParams)
+from cacheflow import ServerArgs, LLMServer, SamplingParams


 def main(args: argparse.Namespace):
-    # Initialize the server.
-    server = initialize_server_from_args(args)
+    # Parse the CLI argument and initialize the server.
+    server_args = ServerArgs.from_cli_args(args)
+    server = LLMServer.from_server_args(server_args)

     # Test the following prompts.
     test_prompts = [
@@ -39,6 +39,6 @@ def main(args: argparse.Namespace):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)