norm / vllm · Commits · 655a5e48

Unverified commit 655a5e48, authored May 21, 2023 by Woosuk Kwon, committed by GitHub on May 21, 2023.

Introduce LLM class for offline inference (#115)
Parent: f746ced0
Showing 9 changed files with 221 additions and 80 deletions (+221, -80):
cacheflow/__init__.py                    +5   -9
cacheflow/config.py                      +5   -1
cacheflow/entrypoints/fastapi_server.py  +3   -4
cacheflow/entrypoints/llm.py             +62  -0
cacheflow/outputs.py                     +5   -5
cacheflow/server/arg_utils.py            +97  -54
cacheflow/server/llm_server.py           +16  -2
examples/offline_inference.py            +23  -0
examples/simple_server.py                +5   -5
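
Taken together, the changes replace the module-level helpers (add_server_arguments, create_server_configs_from_args, initialize_server_from_args) with a ServerArgs dataclass and add a new LLM class that wraps LLMServer for offline, non-server use. A minimal sketch of the resulting API, assuming the cacheflow package at this commit is installed; it mirrors the examples/offline_inference.py file added below:

    from cacheflow import LLM, SamplingParams

    # Build the offline engine; extra kwargs are forwarded to ServerArgs.
    llm = LLM(model="facebook/opt-125m", seed=0)

    # Generate with explicit sampling settings (defaults are used if omitted).
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, top_p=0.95))
    print(outputs[0].outputs[0].text)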
cacheflow/__init__.py (+5, -9)

+from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments,
-    create_server_configs_from_args,
-    initialize_server_from_args,
-)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster

 __all__ = [
-    "RequestOutput",
+    "LLM",
     "SamplingParams",
+    "RequestOutput",
     "LLMServer",
-    "add_server_arguments",
-    "create_server_configs_from_args",
-    "initialize_server_from_args",
+    "ServerArgs",
     "initialize_cluster",
 ]
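
With this change, downstream code imports LLM and ServerArgs from the package root, while the removed helper functions are now reached through ServerArgs methods. A short sketch of the new import surface, under the same assumption that this commit is installed:

    import argparse

    from cacheflow import LLM, LLMServer, SamplingParams, ServerArgs

    # The former add_server_arguments(parser) is now a static method:
    parser = ServerArgs.add_cli_args(argparse.ArgumentParser())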
cacheflow/config.py (+5, -1)

@@ -3,6 +3,8 @@ from typing import Optional
 import torch
 from transformers import AutoConfig, PretrainedConfig

+_GiB = 1 << 30
+

 class ModelConfig:

@@ -70,7 +72,7 @@ class CacheConfig:
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space = swap_space
+        self.swap_space_bytes = swap_space * _GiB

         # Will be set after profiling.
         self.num_gpu_blocks = None

@@ -138,6 +140,8 @@ def _get_and_verify_dtype(
         else:
             torch_dtype = config_dtype
     else:
+        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {dtype}")
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

     # Verify the dtype.
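
CacheConfig now stores the swap space in bytes rather than GiB, converting with the new _GiB constant. A quick illustration of the arithmetic (plain Python, no cacheflow import needed):

    _GiB = 1 << 30  # 1 GiB in bytes, as defined in cacheflow/config.py

    swap_space = 4                         # the CLI/ServerArgs value, in GiB
    swap_space_bytes = swap_space * _GiB
    print(swap_space_bytes)                # 4294967296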
cacheflow/entrypoints/fastapi_server.py (+3, -4)

@@ -12,8 +12,7 @@ import uvicorn
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments, create_server_configs_from_args)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster

@@ -116,10 +115,10 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=10002)
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_configs = create_server_configs_from_args(args)
+    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)
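
The FastAPI entrypoint now funnels CLI parsing through ServerArgs instead of the free functions. A sketch of the same wiring in isolation, assuming this commit's package; the --host/--port flags stay with the caller exactly as in the diff:

    import argparse

    from cacheflow import ServerArgs

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    parser = ServerArgs.add_cli_args(parser)   # adds the shared server flags
    args = parser.parse_args([])               # empty argv -> defaults

    # ServerArgs -> (ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig)
    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
    parallel_config = server_configs[2]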
cacheflow/entrypoints/llm.py (new file, +62)

from typing import List, Optional

from tqdm import tqdm

from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.utils import Counter


class LLM:

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        dtype: str = "default",
        seed: int = 0,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        server_args = ServerArgs(
            model=model,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            seed=seed,
            **kwargs,
        )
        self.llm_server = LLMServer.from_server_args(server_args)
        self.request_counter = Counter()

    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[SamplingParams] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        # Initialize tqdm.
        if use_tqdm:
            pbar = tqdm(total=len(prompts), desc="Processed prompts")
        # Add requests to the server.
        for prompt in prompts:
            request_id = str(next(self.request_counter))
            self.llm_server.add_request(request_id, prompt, sampling_params)
        # Run the server.
        outputs: List[RequestOutput] = []
        while self.llm_server.has_unfinished_requests():
            step_outputs = self.llm_server.step()
            for output in step_outputs:
                if output.done:
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        return outputs
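
LLM.generate blocks until every request has finished and returns one RequestOutput per prompt, in completion order; sampling_params defaults to SamplingParams() and the progress bar can be disabled. A short usage sketch, again assuming this commit's package is installed:

    from cacheflow import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1, dtype="default")

    params = SamplingParams(temperature=0.8, top_p=0.95)  # explicit sampling settings
    outputs = llm.generate(
        ["The capital of France is"],
        sampling_params=params,
        use_tqdm=False,                                   # suppress the progress bar
    )
    for out in outputs:
        print(out.request_id, out.outputs[0].text)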
cacheflow/outputs.py (+5, -5)

@@ -35,7 +35,7 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool = False,
+        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt

@@ -43,8 +43,8 @@ class RequestOutput:
         self.outputs = outputs
         self.done = done

-    @staticmethod
-    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+    @classmethod
+    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()

@@ -70,8 +70,8 @@ class RequestOutput:
         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
-                             outputs, seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids,
+                   outputs, seq_group.is_finished())

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
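
The switch from @staticmethod to @classmethod means from_seq_group constructs cls(...) instead of a hard-coded RequestOutput(...), so a subclass inherits a factory that yields instances of the subclass. A self-contained sketch of that pattern (generic Python, not the cacheflow classes):

    class Base:
        def __init__(self, value: int) -> None:
            self.value = value

        @classmethod
        def from_double(cls, value: int) -> "Base":
            # cls is the class the method was called on, not necessarily Base.
            return cls(value * 2)


    class Child(Base):
        pass


    print(type(Base.from_double(3)))   # <class '__main__.Base'>
    print(type(Child.from_double(3)))  # <class '__main__.Child'>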
cacheflow/server/arg_utils.py (+97, -54)

 import argparse
-from typing import Tuple
+import dataclasses
+from dataclasses import dataclass
+from typing import Optional, Tuple

 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
-from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-
-_GiB = 1 << 30
-
-
-def add_server_arguments(parser: argparse.ArgumentParser):
-    """Shared arguments for CacheFlow servers."""
+
+
+@dataclass
+class ServerArgs:
+    model: str
+    download_dir: Optional[str] = None
+    use_np_weights: bool = False
+    use_dummy_weights: bool = False
+    dtype: str = "default"
+    seed: int = 0
+    use_ray: bool = False
+    pipeline_parallel_size: int = 1
+    tensor_parallel_size: int = 1
+    block_size: int = 16
+    swap_space: int = 4  # GiB
+    gpu_memory_utilization: float = 0.95
+    max_num_batched_tokens: int = 2560
+    max_num_seqs: int = 256
+    disable_log_stats: bool = False
+
+    def __post_init__(self):
+        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        return _add_server_arguments(parser)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return server_args
+
+    def create_server_configs(
+        self,
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+        # Initialize the configs.
+        model_config = ModelConfig(self.model, self.download_dir,
+                                   self.use_np_weights, self.use_dummy_weights,
+                                   self.dtype, self.seed)
+        cache_config = CacheConfig(self.block_size,
+                                   self.gpu_memory_utilization,
+                                   self.swap_space)
+        parallel_config = ParallelConfig(self.pipeline_parallel_size,
+                                         self.tensor_parallel_size,
+                                         self.use_ray)
+        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
+                                           self.max_num_seqs)
+        return model_config, cache_config, parallel_config, scheduler_config
+
+
+def _add_server_arguments(
+    parser: argparse.ArgumentParser,
+) -> argparse.ArgumentParser:
+    """Shared CLI arguments for CacheFlow servers."""
     # Model arguments
     parser.add_argument('--model', type=str, default='facebook/opt-125m',
-                        help='model name')
-    parser.add_argument('--download-dir', type=str, default=None,
+                        help='name or path of the huggingface model to use')
+    parser.add_argument('--download-dir', type=str,
+                        default=ServerArgs.download_dir,
                         help='directory to download and load the weights, '
                              'default to the default cache dir of huggingface')
     parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster loading')
+                        help='save a numpy copy of model weights for faster '
+                             'loading. This can increase the disk usage by up '
+                             'to 2x.')
     parser.add_argument('--use-dummy-weights', action='store_true',
                         help='use dummy values for model weights')
     # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default='default',
+    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
                         choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
                               'for FP32 and FP16 models, and BF16 precision '
                               'for BF16 models.'))
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true',
-                        help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1,
+                        help='use Ray for distributed serving, will be '
+                             'automatically set when using more than 1 GPU')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                        default=ServerArgs.pipeline_parallel_size,
                         help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1,
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                        default=ServerArgs.tensor_parallel_size,
                         help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=16,
+    parser.add_argument('--block-size', type=int,
+                        default=ServerArgs.block_size,
                         choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                         help='token block size')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=0,
+    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
                         help='random seed')
-    parser.add_argument('--swap-space', type=int, default=4,
+    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
                         help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95,
-                        help='the percentage of GPU memory to be used for the model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=2560,
+    parser.add_argument('--gpu-memory-utilization', type=float,
+                        default=ServerArgs.gpu_memory_utilization,
+                        help='the percentage of GPU memory to be used for the '
+                             'model executor')
+    parser.add_argument('--max-num-batched-tokens', type=int,
+                        default=ServerArgs.max_num_batched_tokens,
                         help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int, default=256,
+    parser.add_argument('--max-num-seqs', type=int,
+                        default=ServerArgs.max_num_seqs,
                         help='maximum number of sequences per iteration')
     parser.add_argument('--disable-log-stats', action='store_true',
                         help='disable logging statistics')
     return parser
-
-
-def create_server_configs_from_args(
-    args: argparse.Namespace,
-) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-    # Post-process the parsed arguments.
-    args.swap_space = args.swap_space * _GiB
-    args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
-
-    # Initialize the configs.
-    model_config = ModelConfig(
-        args.model, args.download_dir, args.use_np_weights,
-        args.use_dummy_weights, args.dtype, args.seed)
-    cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
-                               args.swap_space)
-    parallel_config = ParallelConfig(args.pipeline_parallel_size,
-                                     args.tensor_parallel_size, args.use_ray)
-    scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
-                                       args.max_num_seqs)
-    return model_config, cache_config, parallel_config, scheduler_config
-
-
-def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
-    server_configs = create_server_configs_from_args(args)
-    parallel_config = server_configs[2]
-
-    # Initialize the cluster.
-    distributed_init_method, devices = initialize_cluster(parallel_config)
-
-    # Create the LLM server.
-    server = LLMServer(*server_configs, distributed_init_method, devices,
-                       log_stats=not args.disable_log_stats)
-    return server
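
Because ServerArgs is a plain dataclass, the configs can now be built programmatically as well as from the CLI, and __post_init__ clamps max_num_seqs to max_num_batched_tokens. A brief sketch, assuming this commit's package; note that creating the ModelConfig may fetch the Hugging Face config for the named model:

    from cacheflow import ServerArgs

    server_args = ServerArgs(
        model="facebook/opt-125m",
        tensor_parallel_size=1,
        swap_space=4,              # GiB; converted to bytes inside CacheConfig
        max_num_seqs=4096,         # clamped to max_num_batched_tokens
    )
    print(server_args.max_num_seqs)  # 2560 after __post_init__

    model_cfg, cache_cfg, parallel_cfg, sched_cfg = server_args.create_server_configs()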
cacheflow/server/llm_server.py (+16, -2)

@@ -12,6 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.ray_utils import initialize_cluster
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter

@@ -30,7 +32,7 @@ class LLMServer:
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
         stage_devices: List[List[Any]],
-        log_stats: bool = True,
+        log_stats: bool,
     ) -> None:
         logger.info(
             "Initializing an LLM server with config: "

@@ -90,7 +92,7 @@ class LLMServer:
             get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
         )
         # Since we use a shared centralized controller, we take the minimum

@@ -107,6 +109,18 @@ class LLMServer:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)

+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(*server_configs, distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
+
     def add_request(
         self,
         request_id: str,
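
LLMServer.from_server_args packages the config creation and cluster initialization that callers previously did by hand through initialize_server_from_args. A sketch of driving the server directly, matching how the new LLM class uses it and assuming this commit's package:

    from cacheflow import LLMServer, SamplingParams, ServerArgs

    server = LLMServer.from_server_args(
        ServerArgs(model="facebook/opt-125m", disable_log_stats=True))

    server.add_request("0", "Hello, my name is", SamplingParams())
    while server.has_unfinished_requests():
        for output in server.step():
            if output.done:
                print(output.outputs[0].text)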
examples/offline_inference.py (new file, +23)

from cacheflow import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
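
Each RequestOutput carries one CompletionOutput per requested sample, so the example's output.outputs[0].text generalizes when more than one completion is requested. A hedged variation, assuming SamplingParams exposes n as a constructor argument (the top-n handling in outputs.py above suggests it does):

    from cacheflow import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(["The future of AI is"],
                           SamplingParams(n=2, temperature=0.8, top_p=0.95))
    for output in outputs:
        for i, completion in enumerate(output.outputs):
            print(f"candidate {i}: {completion.text!r}")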
examples/simple_server.py (+5, -5)

 import argparse
 import uuid

-from cacheflow import (add_server_arguments, initialize_server_from_args,
-                       SamplingParams)
+from cacheflow import ServerArgs, LLMServer, SamplingParams


 def main(args: argparse.Namespace):
-    # Initialize the server.
-    server = initialize_server_from_args(args)
+    # Parse the CLI argument and initialize the server.
+    server_args = ServerArgs.from_cli_args(args)
+    server = LLMServer.from_server_args(server_args)

     # Test the following prompts.
     test_prompts = [

@@ -39,6 +39,6 @@ def main(args: argparse.Namespace):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)