xdb4_94051 / vllm · Commits · c3442c1f

Refactor system architecture (#109)

Unverified commit c3442c1f, authored May 20, 2023 by Woosuk Kwon; committed by GitHub on May 20, 2023.
Parent: 7297fa6f

Showing 20 changed files with 902 additions and 704 deletions (+902 −704).
README.md                                        +8    −5
cacheflow/__init__.py                            +19   −0
cacheflow/config.py                              +165  −0
cacheflow/core/scheduler.py                      +80   −84
cacheflow/core/server.py                         +0    −302
cacheflow/entrypoints/fastapi_server.py          +128  −0
cacheflow/frontend/simple_frontend.py            +0    −72
cacheflow/model_executor/__init__.py             +1    −3
cacheflow/model_executor/layers/attention.py     +1    −1
cacheflow/model_executor/model_loader.py         +12   −43
cacheflow/model_executor/utils.py                +0    −35
cacheflow/outputs.py                             +79   −0
cacheflow/sampling_params.py                     +1    −1
cacheflow/sequence.py                            +9    −6
cacheflow/server/arg_utils.py                    +74   −0
cacheflow/server/llm_server.py                   +198  −0
cacheflow/server/ray_utils.py                    +90   −0
cacheflow/server/tokenizer_utils.py              +0    −1
cacheflow/worker/cache_engine.py                 +37   −21
cacheflow/worker/controller.py                   +0    −130
README.md

````diff
@@ -10,13 +10,17 @@ pip install -e .  # This may take several minutes.

 ## Test simple server

 ```bash
 # Single-GPU inference.
 python examples/simple_server.py # --model <your_model>

 # Multi-GPU inference (e.g., 2 GPUs).
 ray start --head
-python simple_server.py -tp 2 # --model <your_model>
+python examples/simple_server.py -tp 2 # --model <your_model>
 ```

 The detailed arguments for `simple_server.py` can be found by:
 ```bash
-python simple_server.py --help
+python examples/simple_server.py --help
 ```

 ## FastAPI server

@@ -24,12 +28,12 @@ python simple_server.py --help

 To start the server:
 ```bash
 ray start --head
-python -m cacheflow.http_frontend.fastapi_frontend # --model <your_model>
+python -m cacheflow.entrypoints.fastapi_server # --model <your_model>
 ```

 To test the server:
 ```bash
-python -m cacheflow.http_frontend.test_cli_client
+python test_cli_client.py
 ```

 ## Gradio web server

@@ -55,7 +59,6 @@ Since LLaMA weight is not fully public, we cannot directly download the LLaMA we
 python src/transformers/models/llama/convert_llama_weights_to_hf.py \
     --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path/llama-7b
 ```
 Please make sure that `llama` is included in the output directory name.
 2. For all the commands above, specify the model with `--model /output/path/llama-7b` to load the model. For example:
 ```bash
 python simple_server.py --model /output/path/llama-7b
 ...
````
cacheflow/__init__.py (new file, 0 → 100644)

```python
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import (
    add_server_arguments,
    create_server_configs_from_args,
    initialize_server_from_args,
)
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster

__all__ = [
    "RequestOutput",
    "SamplingParams",
    "LLMServer",
    "add_server_arguments",
    "create_server_configs_from_args",
    "initialize_server_from_args",
    "initialize_cluster",
]
```
cacheflow/config.py
0 → 100644
View file @
c3442c1f
from
typing
import
Optional
import
torch
from
transformers
import
AutoConfig
,
PretrainedConfig
class
ModelConfig
:
def
__init__
(
self
,
model
:
str
,
download_dir
:
Optional
[
str
],
use_np_weights
:
bool
,
use_dummy_weights
:
bool
,
dtype
:
str
,
seed
:
int
,
)
->
None
:
self
.
model
=
model
self
.
download_dir
=
download_dir
self
.
use_np_weights
=
use_np_weights
self
.
use_dummy_weights
=
use_dummy_weights
self
.
seed
=
seed
self
.
hf_config
:
PretrainedConfig
=
AutoConfig
.
from_pretrained
(
model
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_config
,
dtype
)
def
verify_with_parallel_config
(
self
,
parallel_config
:
"ParallelConfig"
,
)
->
None
:
total_num_attention_heads
=
self
.
hf_config
.
num_attention_heads
tensor_parallel_size
=
parallel_config
.
tensor_parallel_size
if
total_num_attention_heads
%
tensor_parallel_size
!=
0
:
raise
ValueError
(
f
"Total number of attention heads (
{
total_num_attention_heads
}
)"
" must be divisible by tensor parallel size "
f
"(
{
tensor_parallel_size
}
)."
)
total_num_hidden_layers
=
self
.
hf_config
.
num_hidden_layers
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
if
total_num_hidden_layers
%
pipeline_parallel_size
!=
0
:
raise
ValueError
(
f
"Total number of hidden layers (
{
total_num_hidden_layers
}
) "
"must be divisible by pipeline parallel size "
f
"(
{
pipeline_parallel_size
}
)."
)
def
get_hidden_size
(
self
)
->
int
:
return
self
.
hf_config
.
hidden_size
def
get_head_size
(
self
)
->
int
:
# FIXME(woosuk): This may not be true for all models.
return
self
.
hf_config
.
hidden_size
//
self
.
hf_config
.
num_attention_heads
def
get_num_heads
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
total_num_attention_heads
=
self
.
hf_config
.
num_attention_heads
return
total_num_attention_heads
//
parallel_config
.
tensor_parallel_size
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
total_num_hidden_layers
=
self
.
hf_config
.
num_hidden_layers
return
total_num_hidden_layers
//
parallel_config
.
pipeline_parallel_size
class
CacheConfig
:
def
__init__
(
self
,
block_size
:
int
,
gpu_memory_utilization
:
float
,
swap_space
:
int
,
)
->
None
:
self
.
block_size
=
block_size
self
.
gpu_memory_utilization
=
gpu_memory_utilization
self
.
swap_space
=
swap_space
# Will be set after profiling.
self
.
num_gpu_blocks
=
None
self
.
num_cpu_blocks
=
None
class
ParallelConfig
:
def
__init__
(
self
,
pipeline_parallel_size
:
int
,
tensor_parallel_size
:
int
,
use_ray
:
bool
,
)
->
None
:
self
.
pipeline_parallel_size
=
pipeline_parallel_size
self
.
tensor_parallel_size
=
tensor_parallel_size
self
.
use_ray
=
use_ray
self
.
world_size
=
pipeline_parallel_size
*
tensor_parallel_size
if
self
.
world_size
>
1
:
self
.
use_ray
=
True
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
if
self
.
pipeline_parallel_size
>
1
:
raise
NotImplementedError
(
"Pipeline parallelism is not supported yet."
)
class
SchedulerConfig
:
def
__init__
(
self
,
max_num_batched_tokens
:
int
,
max_num_seqs
:
int
,
)
->
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_seqs
=
max_num_seqs
_STR_DTYPE_TO_TORCH_DTYPE
=
{
"half"
:
torch
.
float16
,
"float16"
:
torch
.
float16
,
"float"
:
torch
.
float32
,
"float32"
:
torch
.
float32
,
"bfloat16"
:
torch
.
bfloat16
,
}
def
_get_and_verify_dtype
(
config
:
PretrainedConfig
,
dtype
:
str
,
)
->
torch
.
dtype
:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype
=
getattr
(
config
,
"torch_dtype"
,
None
)
if
config_dtype
is
None
:
config_dtype
=
torch
.
float32
dtype
=
dtype
.
lower
()
if
dtype
==
"default"
:
if
config_dtype
==
torch
.
float32
:
# Following the common practice, we use float16 for float32 models.
torch_dtype
=
torch
.
float16
else
:
torch_dtype
=
config_dtype
else
:
torch_dtype
=
_STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
# Verify the dtype.
if
torch_dtype
!=
config_dtype
:
if
torch_dtype
==
torch
.
float32
:
# Upcasting to float32 is allowed.
pass
elif
config_dtype
==
torch
.
float32
:
# Downcasting from float32 to float16 or bfloat16 is allowed.
pass
else
:
# Casting between float16 and bfloat16 is not allowed.
raise
ValueError
(
f
"Cannot use
{
torch_dtype
}
for
{
config_dtype
}
model."
)
# Check if the GPU supports the dtype.
if
torch_dtype
==
torch
.
bfloat16
:
compute_capability
=
torch
.
cuda
.
get_device_capability
()
if
compute_capability
[
0
]
<
8
:
gpu_name
=
torch
.
cuda
.
get_device_name
()
raise
ValueError
(
"Bfloat16 is only supported on GPUs with compute capability "
f
"of at least 8.0. Your
{
gpu_name
}
GPU has compute capability "
f
"
{
compute_capability
[
0
]
}
.
{
compute_capability
[
1
]
}
."
)
return
torch_dtype
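For orientation, here is a minimal sketch of how the four new config objects fit together, mirroring create_server_configs_from_args in cacheflow/server/arg_utils.py later in this diff; the model name and numeric values are simply the CLI defaults used for illustration, not anything the commit prescribes.

```python
# Sketch only: build the config objects introduced by this refactor and run
# the cross-checks, assuming the defaults from add_server_arguments.
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)

model_config = ModelConfig(
    model="facebook/opt-125m", download_dir=None, use_np_weights=False,
    use_dummy_weights=False, dtype="default", seed=0)
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.95,
                           swap_space=4 << 30)  # 4 GiB, already in bytes
parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=1, use_ray=False)
scheduler_config = SchedulerConfig(max_num_batched_tokens=2560,
                                   max_num_seqs=256)

# Raises if heads/layers do not divide evenly across the parallel sizes.
model_config.verify_with_parallel_config(parallel_config)
```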
cacheflow/core/scheduler.py

```diff
@@ -2,10 +2,10 @@ import enum
 import time
 from typing import Dict, List, Optional, Tuple
 
+from cacheflow.config import CacheConfig, SchedulerConfig
 from cacheflow.core.block_manager import BlockSpaceManager
 from cacheflow.core.policy import PolicyFactory
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
                                 SequenceGroupMetadata, SequenceOutputs,
                                 SequenceStatus)
@@ -28,43 +28,53 @@ class PreemptionMode(enum.Enum):
     RECOMPUTE = enum.auto()
 
 
+class SchedulerOutputs:
+
+    def __init__(
+        self,
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> None:
+        self.blocks_to_swap_in = blocks_to_swap_in
+        self.blocks_to_swap_out = blocks_to_swap_out
+        self.blocks_to_copy = blocks_to_copy
+        # Swap in and swap out should never happen at the same time.
+        assert not (blocks_to_swap_in and blocks_to_swap_out)
+
+    def is_empty(self) -> bool:
+        return (not self.blocks_to_swap_in and
+                not self.blocks_to_swap_out and
+                not self.blocks_to_copy)
+
+
 class Scheduler:
 
     def __init__(
         self,
-        controllers: List,
-        block_size: int,
-        num_gpu_blocks: int,
-        num_cpu_blocks: int,
-        max_num_batched_tokens: int,
-        max_num_sequences: int,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
         log_stats: bool,
     ) -> None:
-        self.controllers = controllers
-        self.block_size = block_size
-        self.num_gpu_blocks = num_gpu_blocks
-        self.num_cpu_blocks = num_cpu_blocks
-        self.max_num_batched_tokens = max_num_batched_tokens
-        self.max_num_sequences = max_num_sequences
+        self.scheduler_config = scheduler_config
+        self.cache_config = cache_config
         self.log_stats = log_stats
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name='fcfs')
         # Create the block space manager.
         self.block_manager = BlockSpaceManager(
-            block_size=block_size,
-            num_gpu_blocks=num_gpu_blocks,
-            num_cpu_blocks=num_cpu_blocks,
+            block_size=self.cache_config.block_size,
+            num_gpu_blocks=self.cache_config.num_gpu_blocks,
+            num_cpu_blocks=self.cache_config.num_cpu_blocks,
         )
 
         # Sequence groups in the WAITING state.
         self.waiting: List[SequenceGroup] = []
         # Sequence groups in the RUNNING state.
         self.running: List[SequenceGroup] = []
-        # Mapping: group_id -> num_steps.
-        self.num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> sampling params.
-        self.sampling_params: Dict[int, SamplingParams] = {}
+        # Mapping: request_id -> num_steps.
+        self.num_steps: Dict[str, int] = {}
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []
@@ -72,18 +82,15 @@ class Scheduler:
         # List[timestamp, num_tokens]
         self.num_input_tokens: List[Tuple[float, int]] = []
 
-    def add_sequence_groups(
-        self,
-        seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
-    ) -> None:
+    def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
-        for seq_group, sampling_params in seq_groups:
-            self.waiting.append(seq_group)
-            self.sampling_params[seq_group.group_id] = sampling_params
+        assert seq_group.request_id not in self.num_steps
+        self.waiting.append(seq_group)
 
-    def _schedule(
-        self,
-    ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
+    def has_unfinished_seqs(self) -> bool:
+        return self.waiting or self.running or self.swapped
+
+    def _schedule(self) -> Tuple[SchedulerOutputs, List[int]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
@@ -136,8 +143,9 @@ class Scheduler:
             # The total number of sequences in the RUNNING state should not
             # exceed the maximum number of sequences.
-            num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
-            if len(self.running) + num_seqs > self.max_num_sequences:
+            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
+            num_curr_seqs = len(self.running)
+            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
                 break
 
             seq_group = self.swapped.pop(0)
@@ -151,7 +159,7 @@ class Scheduler:
         )
 
         # Join waiting sequences if possible.
-        prompt_group_ids: List[int] = []
+        prompt_group_ids: List[str] = []
         # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
         # prioritized over the sequence groups in the WAITING state.
         # This is because we want to bound the amount of CPU memory taken by
@@ -172,25 +180,31 @@ class Scheduler:
             # If the number of batched tokens exceeds the limit, stop.
             num_prompt_tokens = seq_group.seqs[0].get_len()
             if (num_batched_tokens + num_prompt_tokens
-                    > self.max_num_batched_tokens):
+                    > self.scheduler_config.max_num_batched_tokens):
                 break
 
             # The total number of sequences in the RUNNING state should not
             # exceed the maximum number of sequences.
-            num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
-            if len(self.running) + num_seqs > self.max_num_sequences:
+            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
+            num_curr_seqs = len(self.running)
+            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
                 break
 
             seq_group = self.waiting.pop(0)
             self._allocate(seq_group)
             self.running.append(seq_group)
             num_batched_tokens += num_prompt_tokens
-            prompt_group_ids.append(seq_group.group_id)
+            prompt_group_ids.append(seq_group.request_id)
 
+        scheduler_outputs = SchedulerOutputs(
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )
         if not self.log_stats:
-            return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
-                    prompt_group_ids)
+            return scheduler_outputs, prompt_group_ids
 
         # TODO(woosuk): Move the below code to server.
         now = time.time()
         if num_batched_tokens > 0:
             self.num_input_tokens.append((now, num_batched_tokens))
@@ -208,13 +222,16 @@ class Scheduler:
         else:
             avg_throughput = 0.0
 
+        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
         num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-        num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
-        gpu_cache_usage = num_used_gpu_blocks / self.num_gpu_blocks
-        if self.num_cpu_blocks > 0:
+        num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
+        gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
+
+        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
+        if total_num_cpu_blocks > 0:
             num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
-            num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
-            cpu_cache_usage = num_used_cpu_blocks / self.num_cpu_blocks
+            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
+            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
         else:
             cpu_cache_usage = 0.0
@@ -225,27 +242,18 @@ class Scheduler:
             f"Pending: {len(self.waiting)} reqs, "
             f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
             f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
 
-        return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
-                prompt_group_ids)
+        return scheduler_outputs, prompt_group_ids
 
-    def step(self) -> List[SequenceGroup]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        scheduler_output = self._schedule()
-        blocks_to_swap_in = scheduler_output[0]
-        blocks_to_swap_out = scheduler_output[1]
-        blocks_to_copy = scheduler_output[2]
-        prompt_group_ids = scheduler_output[3]
+        scheduler_outputs, prompt_group_ids = self._schedule()
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        updated_seq_groups: List[SequenceGroup] = self.running.copy()
-
         for seq_group in self.running:
-            group_id = seq_group.group_id
-            is_prompt = group_id in prompt_group_ids
+            is_prompt = seq_group.request_id in prompt_group_ids
 
             seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
@@ -255,36 +263,24 @@ class Scheduler:
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
 
             seq_group_metadata = SequenceGroupMetadata(
-                group_id=group_id,
+                request_id=seq_group.request_id,
                 is_prompt=is_prompt,
                 seq_data=seq_data,
-                sampling_params=self.sampling_params[group_id],
+                sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
+        return seq_group_metadata_list, scheduler_outputs
 
-        # Execute the first stage of the pipeline.
-        if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
-            # Swap in and swap out should never happen at the same time.
-            assert not (blocks_to_swap_in and blocks_to_swap_out)
-            self.controllers[0].execute_stage(
-                seq_group_metadata_list,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
-            )
-        return updated_seq_groups
-
-    def post_step(
+    def update(
         self,
         seq_outputs: Dict[int, SequenceOutputs],
-    ) -> None:
+    ) -> List[SequenceGroup]:
         # Update the running sequences and free blocks.
         for seq_group in self.running:
-            group_id = seq_group.group_id
-            self.num_steps[group_id] += 1
-            stop_token_ids = self.sampling_params[group_id].stop_token_ids
+            request_id = seq_group.request_id
+            self.num_steps[request_id] += 1
+            stop_token_ids = seq_group.sampling_params.stop_token_ids
 
             # Process beam search results before processing the next tokens.
             for seq in seq_group.seqs:
@@ -316,12 +312,13 @@ class Scheduler:
                     continue
 
                 # Check if the sequence has reached the maximum number of steps.
-                max_num_steps = self.sampling_params[group_id].max_tokens
-                if self.num_steps[group_id] == max_num_steps:
+                max_num_steps = seq_group.sampling_params.max_tokens
+                if self.num_steps[request_id] == max_num_steps:
                     self._free_seq(seq)
                     continue
 
         # Update the running sequences.
+        updated = self.running.copy()
         running: List[SequenceGroup] = []
         for seq_group in self.running:
             if seq_group.is_finished():
@@ -329,13 +326,14 @@ class Scheduler:
             else:
                 running.append(seq_group)
         self.running = running
+        return updated
 
     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
         for seq in seq_group.seqs:
             seq.status = SequenceStatus.RUNNING
-        if seq_group.group_id not in self.num_steps:
-            self.num_steps[seq_group.group_id] = 0
+        if seq_group.request_id not in self.num_steps:
+            self.num_steps[seq_group.request_id] = 0
 
     def _append_slot(
         self,
@@ -410,9 +408,7 @@ class Scheduler:
         self.block_manager.free(seq)
 
     def _free_seq_group(self, seq_group: SequenceGroup) -> None:
-        group_id = seq_group.group_id
-        del self.num_steps[group_id]
-        del self.sampling_params[group_id]
+        del self.num_steps[seq_group.request_id]
 
     def _swap_in(
         self,
```
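The net effect of these hunks is that the scheduler no longer drives the workers itself: callers now pull (seq_group_metadata_list, scheduler_outputs) from schedule(), run the model, and feed the results back through update(). Below is a rough sketch of that driving loop, assuming a server object exposing the _run_workers helper from cacheflow/server/llm_server.py later in this diff; the blocks_to_copy keyword is an assumption carried over from SchedulerOutputs, not copied verbatim from the file.

```python
# Sketch of the new schedule/execute/update cycle (names follow the diff; the
# `server` object and its _run_workers helper are assumed from llm_server.py).
def run_one_step(server):
    seq_group_metadata_list, scheduler_outputs = server.scheduler.schedule()
    if not seq_group_metadata_list and scheduler_outputs.is_empty():
        return []  # nothing scheduled this iteration

    # Execute the model on the workers with the block operations chosen above.
    seq_outputs = server._run_workers(
        "execute_model",
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
        blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
        blocks_to_copy=scheduler_outputs.blocks_to_copy,  # assumed kwarg
    )
    # Feed sampled tokens back; update() returns the sequence groups it touched.
    return server.scheduler.update(seq_outputs)
```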
cacheflow/core/server.py (deleted, 100644 → 0)

```python
import argparse
import random
from typing import List, Optional, Tuple

try:
    import ray
except ImportError:
    ray = None

import numpy as np
import torch

from cacheflow.core.scheduler import Scheduler
from cacheflow.frontend.simple_frontend import SimpleFrontend
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroup
from cacheflow.worker.controller import Controller, DeviceID

logger = init_logger(__name__)


class Server:

    def __init__(
        self,
        model: str,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        gpu_memory_utilization: float,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        use_ray: bool,
        log_stats: bool,
    ):
        logger.info(
            "Initializing a server with config: "
            f"model={model!r}, "
            f"dtype={dtype}, "
            f"use_dummy_weights={use_dummy_weights}, "
            f"cache_dir={cache_dir!r}, "
            f"use_np_cache={use_np_cache}, "
            f"tensor_parallel_size={tensor_parallel_size}, "
            f"seed={seed})")
        self.num_nodes = num_nodes
        self.num_devices_per_node = num_devices_per_node
        self.world_size = pipeline_parallel_size * tensor_parallel_size
        if not use_ray:
            assert self.world_size == 1, (
                "Only support single GPU without Ray.")

        # Create a controller for each pipeline stage.
        self.controllers: List[Controller] = []
        for i in range(pipeline_parallel_size):
            controller = Controller(
                stage_id=i,
                stage_devices=all_stage_devices[i],
                world_size=self.world_size,
                pipeline_parallel_size=pipeline_parallel_size,
                tensor_parallel_size=tensor_parallel_size,
                distributed_init_method=distributed_init_method,
                model_name=model,
                dtype=dtype,
                seed=seed,
                cache_dir=cache_dir,
                use_dummy_weights=use_dummy_weights,
                use_np_cache=use_np_cache,
                max_num_batched_tokens=max_num_batched_tokens,
                max_num_sequences=max_num_sequences,
                use_ray=use_ray,
            )
            self.controllers.append(controller)

        # Initialize cache engine.
        all_worker_num_available_blocks = []
        for controller in self.controllers:
            all_worker_num_available_blocks.extend(
                controller.get_num_available_blocks(
                    block_size, swap_space, gpu_memory_utilization))
        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        self.num_gpu_blocks = np.min(
            [b[0] for b in all_worker_num_available_blocks])
        self.num_cpu_blocks = np.min(
            [b[1] for b in all_worker_num_available_blocks])
        logger.info(f'# GPU blocks: {self.num_gpu_blocks}, '
                    f'# CPU blocks: {self.num_cpu_blocks}')
        for controller in self.controllers:
            controller.init_cache_engine(block_size, self.num_gpu_blocks,
                                         self.num_cpu_blocks)

        # Create a scheduler.
        self.scheduler = Scheduler(
            controllers=self.controllers,
            block_size=block_size,
            num_gpu_blocks=self.num_gpu_blocks,
            num_cpu_blocks=self.num_cpu_blocks,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_sequences=max_num_sequences,
            log_stats=log_stats,
        )
        # Connect the controllers.
        for i in range(len(self.controllers) - 1):
            self.controllers[i].set_next(self.controllers[i + 1])
        self.controllers[-1].set_next(self.scheduler)

    def add_sequence_groups(
        self,
        sequence_groups: List[Tuple[SequenceGroup, SamplingParams]]
    ):
        self.scheduler.add_sequence_groups(sequence_groups)

    def step(self):
        return self.scheduler.step()

    def has_unfinished_requests(self):
        return (self.scheduler.waiting or self.scheduler.running or
                self.scheduler.swapped)


def initialize_cluster(
    use_ray: bool = False,
    address: Optional[str] = None,
    pipeline_parallel_size: int = 1,
    tensor_parallel_size: int = 1,
) -> Tuple[int, int, str, List[List[DeviceID]]]:
    # Initialize cluster locally.
    if not use_ray:
        assert pipeline_parallel_size * tensor_parallel_size == 1, (
            "Only support single GPU without Ray.")
        num_nodes = 1
        num_devices_per_node = torch.cuda.device_count()
        port = random.randint(10000, 20000)
        # We need to setup the distributed init method to make sure
        # the distributed megatron code (e.g., get world size) works correctly.
        distributed_init_method = f"tcp://localhost:{port}"
        all_stage_devices = [[(0, None, 0)]]
        return (num_nodes, num_devices_per_node, distributed_init_method,
                all_stage_devices)

    assert ray is not None, (
        "Ray is not installed. Please install Ray to use distributed "
        "serving.")

    # Connect to a ray cluster.
    ray.init(address=address)

    # Assume we have a uniform cluster that each node has the same number of
    # GPUs for now.
    valid_node_resources = []
    num_devices_per_node = None
    for node in ray.nodes():
        if (not node['Alive']) or node['Resources']['GPU'] <= 0:
            continue
        if num_devices_per_node is None:
            num_devices_per_node = node['Resources']['GPU']
        else:
            assert num_devices_per_node == node['Resources']['GPU'], (
                "The number of GPUs per node is not uniform.")
        for key in node['Resources']:
            if key.startswith('node:'):
                valid_node_resources.append(key)

    num_nodes = len(valid_node_resources)
    assert (pipeline_parallel_size * tensor_parallel_size
            <= num_nodes * num_devices_per_node), (
        "The number of required GPUs exceeds the total number of "
        "available GPUs.")
    if tensor_parallel_size >= num_devices_per_node:
        assert tensor_parallel_size % num_devices_per_node == 0, (
            "The number of tensor parallelism is not divisible by the "
            "number of GPUs per node.")
    else:
        assert num_devices_per_node % tensor_parallel_size == 0, (
            "The number of GPUs per node is not divisible by the number "
            "of tensor parallelism.")

    # Assign GPUs to pipeline stages.
    rank = 0
    current_node_id = 0
    current_device_id = 0
    distributed_init_method = None
    all_stage_devices = []

    for i in range(pipeline_parallel_size):
        stage_devices = []
        for j in range(tensor_parallel_size):
            node_resource = valid_node_resources[current_node_id]
            stage_devices.append((rank, node_resource, current_device_id))
            if distributed_init_method is None:
                ip = node_resource.split("node:")[-1]
                port = random.randint(10000, 20000)
                distributed_init_method = f"tcp://{ip}:{port}"
            rank += 1
            current_device_id += 1
            if current_device_id >= num_devices_per_node:
                current_node_id += 1
                current_device_id = 0
        all_stage_devices.append(stage_devices)

    return (num_nodes, num_devices_per_node, distributed_init_method,
            all_stage_devices)


_GiB = 1 << 30


def add_server_arguments(parser: argparse.ArgumentParser):
    """Shared arguments for CacheFlow servers."""
    # Model arguments
    parser.add_argument('--model', type=str, default='facebook/opt-125m',
                        help='model name')
    parser.add_argument('--cache-dir', type=str, default=None,
                        help='cache dir to download and load the weights, '
                             'default to the default cache dir of huggingface')
    parser.add_argument('--use-np-cache', action='store_true',
                        help='save a numpy copy of model weights for faster loading')
    parser.add_argument('--use-dummy-weights', action='store_true',
                        help='use dummy values for model weights')
    # TODO(woosuk): Support FP32 for debugging.
    parser.add_argument('--dtype', type=str, default='default',
                        choices=['default', 'half', 'bfloat16'],
                        help=('data type for model weights and activations. '
                              'The "default" option will use FP16 precision '
                              'for FP32 and FP16 models, and BF16 precision '
                              'for BF16 models.'))
    # Parallel arguments
    parser.add_argument('--use-ray', action='store_true',
                        help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1,
                        help='number of pipeline stages')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1,
                        help='number of tensor parallel replicas')
    # KV cache arguments
    parser.add_argument('--block-size', type=int, default=16,
                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                        help='token block size')
    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--swap-space', type=int, default=20,
                        help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95,
                        help='the percentage of GPU memory to be used for the model executor')
    parser.add_argument('--max-num-batched-tokens', type=int, default=2560,
                        help='maximum number of batched tokens per iteration')
    parser.add_argument('--max-num-sequences', type=int, default=256,
                        help='maximum number of sequences per iteration')
    parser.add_argument('--log-stats', action='store_true',
                        help='log system statistics')
    return parser


def process_server_arguments(args: argparse.Namespace):
    """Post process the parsed arguments."""
    if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
        args.use_ray = True
    args.swap_space = args.swap_space * _GiB
    args.max_num_sequences = min(args.max_num_sequences,
                                 args.max_num_batched_tokens)
    return args


def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_cluster(
            use_ray=args.use_ray,
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    # Create a server.
    server = Server(
        model=args.model,
        cache_dir=args.cache_dir,
        use_dummy_weights=args.use_dummy_weights,
        use_np_cache=args.use_np_cache,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_sequences=args.max_num_sequences,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        use_ray=args.use_ray,
        log_stats=args.log_stats,
    )

    # Create a frontend.
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )
    return server, frontend
```
cacheflow/frontend/fastapi_frontend.py → cacheflow/entrypoints/fastapi_server.py (renamed)

```diff
@@ -2,115 +2,66 @@ import argparse
 import asyncio
 import json
 import time
-from typing import List, Dict, Optional
+from typing import Any, Dict
+import uuid
 
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import ray
 import uvicorn
 
-from cacheflow.core.server import (Server, add_server_arguments,
-                                   process_server_arguments,
-                                   initialize_cluster)
-from cacheflow.frontend.utils import get_tokenizer
+from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import Sequence, SequenceGroup
-from cacheflow.utils import Counter
-from cacheflow.worker.controller import DeviceID
+from cacheflow.server.arg_utils import (add_server_arguments,
+                                        create_server_configs_from_args)
+from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.ray_utils import initialize_cluster
 
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()
 
 
 class FastAPIServer:
-    def __init__(
-        self,
-        model: str,
-        cache_dir: Optional[str],
-        use_np_cache: bool,
-        pipeline_parallel_size: int,
-        tensor_parallel_size: int,
-        block_size: int,
-        dtype: str,
-        seed: int,
-        swap_space: int,
-        gpu_memory_utilization: float,
-        max_num_batched_tokens: int,
-        max_num_sequences: int,
-        num_nodes: int,
-        num_devices_per_node: int,
-        distributed_init_method: str,
-        all_stage_devices: List[List[DeviceID]],
-        server_use_ray: bool,
-        log_stats: bool,
-    ):
-        self.block_size = block_size
-
-        self.tokenizer = get_tokenizer(model)
-        self.seq_group_counter = Counter()
-        self.seq_counter = Counter()
+
+    def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
         if server_use_ray:
-            remote_server_class = ray.remote(num_cpus=0)(Server)
+            remote_server_class = ray.remote(num_cpus=0)(LLMServer)
         else:
-            remote_server_class = ray.remote(num_gpus=1)(Server)
-        self.server = remote_server_class.remote(
-            model=model,
-            cache_dir=cache_dir,
-            use_dummy_weights=False,
-            use_np_cache=use_np_cache,
-            pipeline_parallel_size=pipeline_parallel_size,
-            tensor_parallel_size=tensor_parallel_size,
-            block_size=block_size,
-            dtype=dtype,
-            seed=seed,
-            swap_space=swap_space,
-            gpu_memory_utilization=gpu_memory_utilization,
-            max_num_batched_tokens=max_num_batched_tokens,
-            max_num_sequences=max_num_sequences,
-            num_nodes=num_nodes,
-            num_devices_per_node=num_devices_per_node,
-            distributed_init_method=distributed_init_method,
-            all_stage_devices=all_stage_devices,
-            use_ray=server_use_ray,
-            log_stats=log_stats,
-        )
-
-        self.running_seq_groups: Dict[int, SequenceGroup] = {}
-        self.sequence_group_events: Dict[int, asyncio.Event] = {}
+            remote_server_class = ray.remote(num_gpus=1)(LLMServer)
+        self.server = remote_server_class.remote(*args, **kwargs)
+
+        # Request id -> request output.
+        self.request_outputs: Dict[str, RequestOutput] = {}
+        # Request id -> event to notify that there is new output.
+        self.request_events: Dict[str, asyncio.Event] = {}
         self.is_server_running = False
 
     async def server_step(self):
         self.is_server_running = True
-        updated_seq_groups = await self.server.step.remote()
+        request_outputs = await self.server.step.remote()
         self.is_server_running = False
         # Notify the waiting coroutines that there are new outputs ready.
-        for seq_group in updated_seq_groups:
-            group_id = seq_group.group_id
-            self.running_seq_groups[group_id] = seq_group
-            self.sequence_group_events[group_id].set()
+        for request_output in request_outputs:
+            request_id = request_output.request_id
+            self.request_outputs[request_id] = request_output
+            self.request_events[request_id].set()
 
-    async def generate(self, request_dict: Dict):
+    async def generate(self, request_dict: Dict[str, Any]):
         # Preprocess the request.
+        arrival_time = time.time()
         prompt = request_dict.pop("prompt")
         sampling_params = SamplingParams(**request_dict)
-        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
-        token_ids = self.tokenizer.encode(prompt)
-        seqs: List[Sequence] = []
-        for _ in range(sampling_params.n):
-            seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, prompt, token_ids,
-                           block_size=self.block_size)
-            seqs.append(seq)
 
-        arrival_time = time.time()
-        group_id = next(self.seq_group_counter)
-        seq_group = SequenceGroup(group_id, seqs, arrival_time)
         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        group_event = asyncio.Event()
-        self.running_seq_groups[group_id] = seq_group
-        self.sequence_group_events[group_id] = group_event
+        request_id = str(uuid.uuid4().hex[:8])
+        request_event = asyncio.Event()
+        self.request_events[request_id] = request_event
+
         # Add the request into the cacheflow server's waiting queue.
-        await self.server.add_sequence_groups.remote(
-            [(seq_group, sampling_params)])
+        await self.server.add_request.remote(
+            request_id, prompt, sampling_params, arrival_time=arrival_time)
+
         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
         # the server to process the requests.
@@ -118,32 +69,35 @@ class FastAPIServer:
             # Kick the server if the server is not running.
             if not self.is_server_running:
                 await self.server_step()
             # Wait for new output. The group_event will be set in server_step
             # when there is new output available for the sequence group.
             # Added a timeout to prevent deadlock.
             try:
-                await asyncio.wait_for(group_event.wait(),
+                await asyncio.wait_for(request_event.wait(),
                                        timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
             except asyncio.TimeoutError:
                 continue
             # Reset the event to wait for the next output.
-            group_event.clear()
-            # Decode and return new outputs
-            seq_group = self.running_seq_groups[group_id]
-            all_outputs = []
-            for seq in seq_group.seqs:
-                token_ids = seq.get_token_ids()
-                output = self.tokenizer.decode(token_ids,
-                                               skip_special_tokens=True)
-                all_outputs.append(output)
+            request_event.clear()
+
+            # Decode and return new outputs.
+            request_output = self.request_outputs[request_id]
+            prompt = request_output.prompt
+            text_outputs = [
+                prompt + output.text
+                for output in request_output.outputs
+            ]
             ret = {
-                "text": all_outputs,
+                "text": text_outputs,
                 "error": 0,
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
             # Once finished, release the resources of the sequence group.
-            if seq_group.is_finished():
-                del self.running_seq_groups[group_id]
-                del self.sequence_group_events[group_id]
+            if request_output.done:
+                del self.request_outputs[request_id]
+                del self.request_events[request_id]
                 # Kick the server if the server is not running. This is to
                 # prevent that there are still requests in server's waiting
                 # queue to be executed.
@@ -164,38 +118,11 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=10002)
     parser = add_server_arguments(parser)
     args = parser.parse_args()
-    args = process_server_arguments(args)
 
-    # TODO(zhuohan): Support pipeline parallelism.
-    assert args.pipeline_parallel_size == 1, (
-        'Pipeline parallelism is not supported yet.')
-
-    (num_nodes, num_devices_per_node, distributed_init_method,
-     all_stage_devices) = (
-        initialize_cluster(
-            use_ray=True,
-            pipeline_parallel_size=args.pipeline_parallel_size,
-            tensor_parallel_size=args.tensor_parallel_size))
+    server_configs = create_server_configs_from_args(args)
+    parallel_config = server_configs[2]
+    distributed_init_method, stage_devices = initialize_cluster(parallel_config)
 
     server = FastAPIServer(
-        model=args.model,
-        cache_dir=args.cache_dir,
-        use_np_cache=args.use_np_cache,
-        pipeline_parallel_size=args.pipeline_parallel_size,
-        tensor_parallel_size=args.tensor_parallel_size,
-        block_size=args.block_size,
-        dtype=args.dtype,
-        seed=args.seed,
-        swap_space=args.swap_space,
-        gpu_memory_utilization=args.gpu_memory_utilization,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_num_sequences=args.max_num_sequences,
-        num_nodes=num_nodes,
-        num_devices_per_node=num_devices_per_node,
-        distributed_init_method=distributed_init_method,
-        all_stage_devices=all_stage_devices,
-        server_use_ray=args.use_ray,
-        log_stats=args.log_stats,
-    )
+        args.use_ray, *server_configs, distributed_init_method, stage_devices)
 
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
```
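The endpoint keeps the same "kick the server" pattern as before: because LLMServer has no background loop, the handler alternates between stepping the server and waiting on a per-request asyncio.Event. A stripped-down sketch of that pattern, independent of FastAPI; the function and variable names here are illustrative stand-ins rather than names taken from the file.

```python
import asyncio

# Sketch only: `server` is assumed to be a Ray-wrapped LLMServer, and
# request_events / request_outputs mirror the FastAPIServer dictionaries.
async def drive_request(server, request_id, request_events, request_outputs,
                        timeout=1.0):
    while True:
        # Kick the server so queued requests make progress.
        outputs = await server.step.remote()
        for out in outputs:
            request_outputs[out.request_id] = out
            request_events[out.request_id].set()

        # Wait (with a timeout to avoid deadlock) until this request has news.
        try:
            await asyncio.wait_for(request_events[request_id].wait(), timeout)
        except asyncio.TimeoutError:
            continue
        request_events[request_id].clear()

        result = request_outputs[request_id]
        yield result
        if result.done:
            break
```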
cacheflow/frontend/simple_frontend.py (deleted, 100644 → 0)

```python
import time
from typing import List, Optional, Tuple

from cacheflow.frontend.utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter

logger = init_logger(__name__)


class SimpleFrontend:

    def __init__(
        self,
        model_name: str,
        block_size: int,
    ) -> None:
        self.block_size = block_size

        self.tokenizer = get_tokenizer(model_name)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def add_eos_token(self, sampling_params: SamplingParams) -> SamplingParams:
        # Stop generation when we see an EOS token.
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        return sampling_params

    def query(
        self,
        prompt: str,
        sampling_params: SamplingParams,
    ) -> None:
        token_ids = self.tokenizer.encode(prompt)
        self._add_query(prompt, token_ids, sampling_params)

    def _add_query(
        self,
        prompt: str,
        token_ids: List[int],
        sampling_params: SamplingParams,
        arrival_time: Optional[float] = None,
    ) -> None:
        if arrival_time is None:
            arrival_time = time.time()
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, token_ids,
                           block_size=self.block_size)
            seqs.append(seq)

        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs, arrival_time)
        self.inputs.append((seq_group, sampling_params))

    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
        inputs = self.inputs
        self.inputs = []
        return inputs

    def print_response(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        for seq in seq_group.seqs:
            token_ids = seq.get_token_ids()
            output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
            output = output.strip()
            logger.info(f"Seq {seq.seq_id}: {output!r}")
```
cacheflow/model_executor/__init__.py

```diff
 from cacheflow.model_executor.input_metadata import InputMetadata
 from cacheflow.model_executor.model_loader import get_model
-from cacheflow.model_executor.utils import (set_random_seed,
-                                            get_cache_block_size)
+from cacheflow.model_executor.utils import set_random_seed
 
 __all__ = [
     "InputMetadata",
-    "get_cache_block_size",
     "get_model",
     "set_random_seed",
 ]
```
cacheflow/model_executor/layers/attention.py

```diff
@@ -10,9 +10,9 @@ from cacheflow import cache_ops
 from cacheflow import pos_encoding_ops
 from cacheflow.model_executor.input_metadata import InputMetadata
 
+_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
 
 class GPTCacheFlowAttention(nn.Module):
     """GPT-style multi-head attention.
```
cacheflow/model_executor/model_loader.py

```diff
 """Utilities for selecting and loading models."""
-from typing import Optional
-
 import torch
 import torch.nn as nn
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig
 
+from cacheflow.config import ModelConfig
 from cacheflow.model_executor.models import (
     GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
-from cacheflow.model_executor.utils import get_torch_dtype
 from cacheflow.model_executor.weight_utils import initialize_dummy_weights
 
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
     "GPT2LMHeadModel": GPT2LMHeadModel,
@@ -19,6 +16,7 @@ _MODEL_REGISTRY = {
     "OPTForCausalLM": OPTForCausalLM,
 }
 
+
 def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
@@ -30,51 +28,22 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
     )
 
 
-def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
-    # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "torch_dtype", None)
-    if config_dtype is None:
-        config_dtype = torch.float32
-    if dtype == "default":
-        if config_dtype == torch.float32:
-            # Following the common practice, we use float16 for float32 models.
-            torch_dtype = torch.float16
-        else:
-            torch_dtype = config_dtype
-    else:
-        torch_dtype = get_torch_dtype(dtype)
-        if torch_dtype != config_dtype and config_dtype != torch.float32:
-            # TODO(woosuk): Allow using float16 for bfloat16 models and
-            # vice versa. Print a warning message and continue.
-            raise ValueError(
-                f"Cannot use {torch_dtype} for {config_dtype} model.")
-    return torch_dtype
-
-
-def get_model(
-    model_name: str,
-    dtype: str,
-    cache_dir: Optional[str],
-    use_dummy_weights: bool,
-    use_np_cache: bool,
-) -> nn.Module:
-    config = AutoConfig.from_pretrained(model_name)
-    torch_dtype = _get_dtype(config, dtype)
-    torch.set_default_dtype(torch_dtype)
-    model_class = _get_model_architecture(config)
+def get_model(model_config: ModelConfig) -> nn.Module:
+    model_class = _get_model_architecture(model_config.hf_config)
+    torch.set_default_dtype(model_config.dtype)
 
     # Create a model instance.
     # The weights will be initialized as empty tensors.
-    model = model_class(config)
-    if use_dummy_weights:
+    model = model_class(model_config.hf_config)
+    if model_config.use_dummy_weights:
         model = model.cuda()
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
         initialize_dummy_weights(model)
     else:
         # Load the weights from the cached or downloaded files.
-        model.load_weights(model_name, cache_dir, use_np_cache)
+        model.load_weights(model_config.model, model_config.download_dir,
+                           model_config.use_np_weights)
         model = model.cuda()
-    return model.eval(), torch_dtype
+    return model.eval()
```
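get_model now takes a single ModelConfig instead of loose arguments, and the resolved dtype lives on the config rather than coming back as a second return value. A small sketch of the new call, assuming the ModelConfig from cacheflow/config.py above; the model name is just an example, and a CUDA device is required since the loader moves the model to GPU.

```python
from cacheflow.config import ModelConfig
from cacheflow.model_executor import get_model

# Sketch: dtype resolution happens inside ModelConfig, so get_model no longer
# needs dtype/cache_dir arguments.
model_config = ModelConfig(model="facebook/opt-125m", download_dir=None,
                           use_np_weights=False, use_dummy_weights=True,
                           dtype="default", seed=0)
model = get_model(model_config)   # returns model.eval(), already on the GPU
print(model_config.dtype)         # e.g. torch.float16 for an FP32/FP16 model
```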
cacheflow/model_executor/utils.py

```diff
 """Utils for model executor."""
 import random
-from typing import Union
 
 import numpy as np
 import torch
@@ -9,28 +8,6 @@ from cacheflow.model_executor.parallel_utils.parallel_state import model_paralle
 from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
 
-_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "float": torch.float,
-    "float16": torch.float16,
-    "float32": torch.float32,
-    "bfloat16": torch.bfloat16,
-}
-
-
-def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
-    if isinstance(dtype, str):
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
-    else:
-        torch_dtype = dtype
-    return torch_dtype
-
-
-def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
-    torch_dtype = get_torch_dtype(dtype)
-    return torch.tensor([], dtype=torch_dtype).element_size()
-
 
 def set_random_seed(seed: int) -> None:
     random.seed(seed)
     np.random.seed(seed)
@@ -40,15 +17,3 @@ def set_random_seed(seed: int) -> None:
     if model_parallel_is_initialized():
         model_parallel_cuda_manual_seed(seed)
-
-
-def get_cache_block_size(block_size: int,
-                         num_heads: int,
-                         head_size: int,
-                         num_layers: int,
-                         dtype: str) -> int:
-    key_cache_block = block_size * num_heads * head_size
-    value_cache_block = key_cache_block
-    total = num_layers * (key_cache_block + value_cache_block)
-    dtype_size = get_dtype_size(dtype)
-    return dtype_size * total
```
cacheflow/outputs.py (new file, 0 → 100644)

```python
from typing import Dict, List, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from cacheflow.sequence import SequenceGroup


class CompletionOutput:

    def __init__(
        self,
        text: str,
        token_ids: List[int],
        cumulative_logprobs: float,
        logprobs: List[Dict[int, float]],
    ) -> None:
        self.text = text
        self.token_ids = token_ids
        self.cumulative_logprobs = cumulative_logprobs
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"CompletionOutput(output={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprobs={self.cumulative_logprobs}, "
                f"logprobs={self.logprobs})")


class RequestOutput:

    def __init__(
        self,
        request_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        outputs: List[CompletionOutput],
        done: bool = False,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.outputs = outputs
        self.done = done

    @staticmethod
    def from_seq_group(
        seq_group: SequenceGroup,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> "RequestOutput":
        outputs: List[CompletionOutput] = []
        seqs = seq_group.get_seqs()
        for seq in seqs:
            output_token_ids = seq.data.output_token_ids
            output_str = tokenizer.decode(output_token_ids,
                                          skip_special_tokens=True)
            seq_logprobs = seq.data.cumulative_logprobs
            logprobs = seq.output_logprobs
            if seq_group.sampling_params.logprobs == 0:
                # NOTE: We need to take care of this case because the sequence
                # always has the logprobs of the sampled tokens even if the
                # logprobs are not requested.
                logprobs = {}
            output = CompletionOutput(output_str, output_token_ids,
                                      seq_logprobs, logprobs)
            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        prompt = seqs[0].prompt
        prompt_token_ids = seqs[0].data.prompt_token_ids
        return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
                             outputs, seq_group.is_finished())

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"outputs={self.outputs}, "
                f"done={self.done})")
```
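RequestOutput is what LLMServer.step() now returns and what the FastAPI entrypoint serializes. A short sketch of consuming one, following the pattern in cacheflow/entrypoints/fastapi_server.py above; the request_output value is assumed to come from a server step, and the helper name is made up for illustration.

```python
# Sketch: turn a RequestOutput into the response payload the HTTP server sends.
def format_response(request_output) -> dict:
    prompt = request_output.prompt
    # Each CompletionOutput carries one candidate continuation (n >= 1).
    text_outputs = [prompt + completion.text
                    for completion in request_output.outputs]
    return {"text": text_outputs, "error": 0}
```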
cacheflow/sampling_params.py

```diff
@@ -116,4 +116,4 @@ class SamplingParams:
                 f"use_beam_search={self.use_beam_search}, "
                 f"stop_token_ids={self.stop_token_ids}, "
                 f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs}")
+                f"logprobs={self.logprobs})")
```
cacheflow/sequence.py

```diff
@@ -115,12 +115,14 @@ class SequenceGroup:
     def __init__(
         self,
-        group_id: int,
+        request_id: str,
         seqs: List[Sequence],
+        sampling_params: SamplingParams,
         arrival_time: float,
     ) -> None:
-        self.group_id = group_id
+        self.request_id = request_id
         self.seqs = seqs
+        self.sampling_params = sampling_params
         self.arrival_time = arrival_time
 
     def get_seqs(
@@ -145,21 +147,22 @@ class SequenceGroup:
         return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)
 
     def __repr__(self) -> str:
-        return (f'SequenceGroup(group_id={self.group_id}, '
-                f'num_seqs={len(self.seqs)})')
+        return (f"SequenceGroup(request_id={self.request_id}, "
+                f"sampling_params={self.sampling_params}, "
+                f"num_seqs={len(self.seqs)})")
 
 
 class SequenceGroupMetadata:
 
     def __init__(
         self,
-        group_id: int,
+        request_id: str,
         is_prompt: bool,
         seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
     ) -> None:
-        self.group_id = group_id
+        self.request_id = request_id
         self.is_prompt = is_prompt
         self.seq_data = seq_data
         self.sampling_params = sampling_params
```
cacheflow/server/arg_utils.py (new file, 0 → 100644)

```python
import argparse
from typing import Tuple

from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster

_GiB = 1 << 30


def add_server_arguments(parser: argparse.ArgumentParser):
    """Shared arguments for CacheFlow servers."""
    # Model arguments
    parser.add_argument('--model', type=str, default='facebook/opt-125m',
                        help='model name')
    parser.add_argument('--download-dir', type=str, default=None,
                        help='directory to download and load the weights, '
                             'default to the default cache dir of huggingface')
    parser.add_argument('--use-np-weights', action='store_true',
                        help='save a numpy copy of model weights for faster loading')
    parser.add_argument('--use-dummy-weights', action='store_true',
                        help='use dummy values for model weights')
    # TODO(woosuk): Support FP32.
    parser.add_argument('--dtype', type=str, default='default',
                        choices=['default', 'half', 'bfloat16'],
                        help=('data type for model weights and activations. '
                              'The "default" option will use FP16 precision '
                              'for FP32 and FP16 models, and BF16 precision '
                              'for BF16 models.'))
    # Parallel arguments
    parser.add_argument('--use-ray', action='store_true',
                        help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1,
                        help='number of pipeline stages')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1,
                        help='number of tensor parallel replicas')
    # KV cache arguments
    parser.add_argument('--block-size', type=int, default=16,
                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                        help='token block size')
    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--swap-space', type=int, default=4,
                        help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95,
                        help='the percentage of GPU memory to be used for the model executor')
    parser.add_argument('--max-num-batched-tokens', type=int, default=2560,
                        help='maximum number of batched tokens per iteration')
    parser.add_argument('--max-num-seqs', type=int, default=256,
                        help='maximum number of sequences per iteration')
    parser.add_argument('--disable-log-stats', action='store_true',
                        help='disable logging statistics')
    return parser


def create_server_configs_from_args(
    args: argparse.Namespace,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
    # Post-process the parsed arguments.
    args.swap_space = args.swap_space * _GiB
    args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)

    # Initialize the configs.
    model_config = ModelConfig(
        args.model, args.download_dir, args.use_np_weights,
        args.use_dummy_weights, args.dtype, args.seed)
    cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
                               args.swap_space)
    parallel_config = ParallelConfig(args.pipeline_parallel_size,
                                     args.tensor_parallel_size, args.use_ray)
    scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
                                       args.max_num_seqs)
    return model_config, cache_config, parallel_config, scheduler_config


def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
    server_configs = create_server_configs_from_args(args)
    parallel_config = server_configs[2]

    # Initialize the cluster.
    distributed_init_method, devices = initialize_cluster(parallel_config)

    # Create the LLM server.
    server = LLMServer(*server_configs, distributed_init_method, devices,
                       log_stats=not args.disable_log_stats)
    return server
```
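Putting the new pieces together, here is a minimal end-to-end sketch of the offline flow these helpers enable, assuming the CLI defaults above and a synchronous caller. It mirrors what examples/simple_server.py would do rather than quoting it, and the SamplingParams keyword names are assumptions based on the attributes shown elsewhere in this diff.

```python
# Sketch only: parse the shared server arguments, build an LLMServer, submit a
# request, and step the server until it finishes.
import argparse

from cacheflow import (SamplingParams, add_server_arguments,
                       initialize_server_from_args)

parser = argparse.ArgumentParser()
parser = add_server_arguments(parser)
args = parser.parse_args([])  # take the defaults: facebook/opt-125m, 1 GPU

server = initialize_server_from_args(args)
server.add_request("request-0", "Hello, my name is",
                   SamplingParams(n=1, max_tokens=16))

# LLMServer has no background loop, so the caller keeps stepping it.
while server.has_unfinished_requests():
    for request_output in server.step():
        if request_output.done:
            print(request_output)
```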
cacheflow/server/llm_server.py
0 → 100644
View file @
c3442c1f
import time
from typing import Any, List, Optional

try:
    import ray
except ImportError:
    ray = None

from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)
from cacheflow.core.scheduler import Scheduler
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter
from cacheflow.worker.worker import Worker

logger = init_logger(__name__)


class LLMServer:

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
        stage_devices: List[List[Any]],
        log_stats: bool = True,
    ) -> None:
        logger.info(
            "Initializing an LLM server with config: "
            f"model={model_config.model!r}, "
            f"dtype={model_config.dtype}, "
            f"use_dummy_weights={model_config.use_dummy_weights}, "
            f"download_dir={model_config.download_dir!r}, "
            f"use_np_weights={model_config.use_np_weights}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self._verify_args()

        self.tokenizer = get_tokenizer(model_config.model)
        self.seq_counter = Counter()

        # Create the parallel GPU workers.
        self.workers: List[Worker] = []
        assert len(stage_devices) == 1, "Only support one stage for now."
        for rank, node_resource, _ in stage_devices[0]:
            worker_cls = Worker
            if self.parallel_config.use_ray:
                worker_cls = ray.remote(
                    num_cpus=0,
                    num_gpus=1,
                    resources={node_resource: 1e-5},
                )(worker_cls).remote
            worker = worker_cls(
                model_config,
                parallel_config,
                scheduler_config,
                rank,
                distributed_init_method,
            )
            self.workers.append(worker)

        # Profile the memory usage and initialize the cache.
        self._init_cache()

        # Create the scheduler.
        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
            get_all_outputs=True,
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
            cpu_swap_space=self.cache_config.swap_space,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        # FIXME(woosuk): Change to debug log.
        logger.info(f'# GPU blocks: {num_gpu_blocks}, '
                    f'# CPU blocks: {num_cpu_blocks}')
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self._run_workers("init_cache_engine", cache_config=self.cache_config)

    def add_request(
        self,
        request_id: str,
        prompt: str,
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
            prompt_token_ids = self.tokenizer.encode(prompt)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
            seqs.append(seq)

        # FIXME(woosuk)
        # Add the EOS token to the stop token list.
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def has_unfinished_requests(self) -> bool:
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
            # Nothing to do.
            return []

        # Execute the model.
        output = self._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )
        # Update the scheduler.
        updated_seq_groups = self.scheduler.update(output)

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for seq_group in updated_seq_groups:
            # TODO(woosuk): Batch-decode the outputs for speedup.
            request_output = RequestOutput.from_seq_group(seq_group,
                                                          self.tokenizer)
            request_outputs.append(request_output)
        return request_outputs

    def _run_workers(
        self,
        method: str,
        get_all_outputs: bool = False,
        *args,
        **kwargs,
    ) -> Any:
        all_outputs = []
        for worker in self.workers:
            executor = getattr(worker, method)
            if self.parallel_config.use_ray:
                executor = executor.remote
            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.parallel_config.use_ray:
            all_outputs = ray.get(all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output
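A hedged sketch of the request loop this class supports. Here `server` is assumed to come from `initialize_server_from_args`, and the `SamplingParams(n=1)` call and the `request_id` attribute on `RequestOutput` are assumptions about the API rather than facts from this diff; only `add_request`, `has_unfinished_requests`, and `step` are taken from the code above.

# Hedged usage sketch (not part of this commit): driving LLMServer with the
# add_request()/step() methods defined above.
from cacheflow import SamplingParams


def generate(server, prompts):
    # SamplingParams(n=1) is an assumed call; the server code above only
    # references the `n` and `stop_token_ids` fields of sampling params.
    for i, prompt in enumerate(prompts):
        server.add_request(request_id=str(i), prompt=prompt,
                           sampling_params=SamplingParams(n=1))

    latest = {}
    while server.has_unfinished_requests():
        # step() runs one scheduling + model-execution iteration and returns a
        # RequestOutput for every sequence group updated in that iteration.
        for request_output in server.step():
            # request_id is assumed to be carried on RequestOutput.
            latest[request_output.request_id] = request_output
    return list(latest.values())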
cacheflow/server/ray_utils.py
0 → 100644
View file @ c3442c1f
import random
from typing import List, Optional, Tuple

try:
    import ray
except ImportError:
    ray = None

from cacheflow.config import ParallelConfig

DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id


def initialize_cluster(
    parallel_config: ParallelConfig,
    address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]:
    if not parallel_config.use_ray:
        # Initialize cluster locally.
        port = random.randint(10000, 20000)
        # We need to setup the distributed init method to make sure
        # the distributed megatron code (e.g., get world size) works correctly.
        distributed_init_method = f"tcp://localhost:{port}"
        all_stage_devices = [[(0, None, 0)]]
        return distributed_init_method, all_stage_devices

    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")

    # Connect to a ray cluster.
    ray.init(address=address)

    # Assume we have a uniform cluster that each node has the same number of
    # GPUs for now.
    valid_node_resources = []
    num_devices_per_node = None
    for node in ray.nodes():
        if (not node['Alive']) or node['Resources']['GPU'] <= 0:
            continue
        if num_devices_per_node is None:
            num_devices_per_node = node['Resources']['GPU']
        else:
            assert num_devices_per_node == node['Resources']['GPU'], (
                "The number of GPUs per node is not uniform.")
        for key in node['Resources']:
            if key.startswith('node:'):
                valid_node_resources.append(key)

    # Verify the parallel config.
    num_nodes = len(valid_node_resources)
    if parallel_config.world_size > num_nodes * num_devices_per_node:
        raise ValueError(
            "The number of required GPUs exceeds the total number of "
            "available GPUs.")
    if parallel_config.tensor_parallel_size >= num_devices_per_node:
        if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
            raise ValueError(
                "The number of tensor parallelism is not divisible by the "
                "number of GPUs per node.")
    else:
        if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
            raise ValueError(
                "The number of GPUs per node is not divisible by the number "
                "of tensor parallelism.")

    # Assign GPUs to pipeline stages.
    rank = 0
    current_node_id = 0
    current_device_id = 0
    distributed_init_method = None
    all_stage_devices = []

    for _ in range(parallel_config.pipeline_parallel_size):
        stage_devices = []
        for _ in range(parallel_config.tensor_parallel_size):
            node_resource = valid_node_resources[current_node_id]
            stage_devices.append((rank, node_resource, current_device_id))
            if distributed_init_method is None:
                ip = node_resource.split("node:")[-1]
                port = random.randint(10000, 20000)
                distributed_init_method = f"tcp://{ip}:{port}"
            rank += 1
            current_device_id += 1
            if current_device_id >= num_devices_per_node:
                current_node_id += 1
                current_device_id = 0
        all_stage_devices.append(stage_devices)

    return distributed_init_method, all_stage_devices
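A hedged sketch of the local (no-Ray) path of initialize_cluster. The ParallelConfig argument order follows the call in arg_utils.py above; the commented return values are what the local branch of the function above produces.

# Hedged usage sketch (not part of this commit): single-GPU setup without Ray.
from cacheflow.config import ParallelConfig
from cacheflow.server.ray_utils import initialize_cluster

# Argument order follows arg_utils.py: pipeline_parallel_size,
# tensor_parallel_size, use_ray.
parallel_config = ParallelConfig(1, 1, False)
distributed_init_method, all_stage_devices = initialize_cluster(parallel_config)
# distributed_init_method -> "tcp://localhost:<random port>"
# all_stage_devices       -> [[(0, None, 0)]], i.e. one stage with a single
#                            (rank, node resource, device id) entry.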
cacheflow/frontend/utils.py → cacheflow/server/tokenizer_utils.py
View file @ c3442c1f
... ...
@@ -3,7 +3,6 @@ from typing import Union
from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
    # LLaMA fast tokenizer has a bug related to protobuf.
    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
... ...
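For context, a hedged sketch of how this renamed module is used: llm_server.py above calls get_tokenizer(model_config.model). The model name and the slow-tokenizer fallback behavior described below are assumptions inferred from the variable name _MODEL_TYPES_WITH_SLOW_TOKENIZER, not facts from this diff.

# Hedged usage sketch (not part of this commit): get_tokenizer is assumed to
# return a Hugging Face tokenizer, falling back to a slow tokenizer for model
# types listed in _MODEL_TYPES_WITH_SLOW_TOKENIZER (e.g., LLaMA, per the
# comment above).
from cacheflow.server.tokenizer_utils import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")  # illustrative model name
prompt_token_ids = tokenizer.encode("Hello, my name is")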
cacheflow/worker/cache_engine.py
View file @ c3442c1f
... ...
@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple
 import torch

 from cacheflow import cache_ops
+from cacheflow.config import CacheConfig, ModelConfig, ParallelConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]
... ...
@@ -18,27 +19,22 @@ class CacheEngine:
     def __init__(
         self,
-        worker_id: int,
-        num_layers: int,
-        num_heads: int,
-        head_size: int,
-        block_size: int,
-        num_gpu_blocks: int,
-        num_cpu_blocks: int,
-        dtype: torch.dtype,
+        cache_config: CacheConfig,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
     ) -> None:
-        if head_size % 16 != 0:
-            raise ValueError(
-                f'head_size ({head_size}) must be a multiple of 16.')
-        self.worker_id = worker_id
-        self.num_layers = num_layers
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.block_size = block_size
-        self.num_gpu_blocks = num_gpu_blocks
-        self.num_cpu_blocks = num_cpu_blocks
-        self.dtype = dtype
+        self.cache_config = cache_config
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.head_size = model_config.get_head_size()
+        self.num_layers = model_config.get_num_layers(parallel_config)
+        self.num_heads = model_config.get_num_heads(parallel_config)
+        self.dtype = model_config.dtype
+        self.block_size = cache_config.block_size
+        self.num_gpu_blocks = cache_config.num_gpu_blocks
+        self.num_cpu_blocks = cache_config.num_cpu_blocks

         # Initialize the cache.
         self.gpu_cache = self.allocate_gpu_cache()
... ...
@@ -48,7 +44,7 @@ class CacheEngine:
         self.cache_stream = torch.cuda.Stream()
         assert self.cache_stream != torch.cuda.current_stream()
         # Initialize the events for stream synchronization.
-        self.events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

     def get_key_block_shape(self) -> Tuple[int, int, int, int]:
         element_size = torch.tensor([], dtype=self.dtype).element_size()
... ...
@@ -133,3 +129,23 @@ class CacheEngine:
         value_caches = [value_cache for _, value_cache in self.gpu_cache]
         # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
         cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
+
+    @staticmethod
+    def get_cache_block_size(
+        block_size: int,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+    ) -> int:
+        head_size = model_config.get_head_size()
+        num_heads = model_config.get_num_heads(parallel_config)
+        num_layers = model_config.get_num_layers(parallel_config)
+
+        key_cache_block = block_size * num_heads * head_size
+        value_cache_block = key_cache_block
+        total = num_layers * (key_cache_block + value_cache_block)
+        dtype_size = _get_dtype_size(model_config.dtype)
+        return dtype_size * total
+
+
+def _get_dtype_size(dtype: torch.dtype) -> int:
+    return torch.tensor([], dtype=dtype).element_size()
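To make the new get_cache_block_size helper concrete, a hedged back-of-the-envelope calculation follows. The model shape (32 layers, 32 heads, head size 128), fp16 dtype, and block size of 16 are illustrative assumptions, not values read from any config in this commit.

# Hedged worked example (not part of this commit): per-block KV-cache size,
# using the same arithmetic as CacheEngine.get_cache_block_size() above.
block_size = 16        # tokens per KV-cache block (illustrative)
num_layers = 32        # illustrative 7B-class model shape
num_heads = 32
head_size = 128
dtype_size = 2         # bytes per element for fp16

key_cache_block = block_size * num_heads * head_size        # 65,536 elements
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)  # 4,194,304 elements
print(dtype_size * total)  # 8,388,608 bytes, i.e. 8 MiB of KV cache per block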
cacheflow/worker/controller.py
deleted 100644 → 0
View file @ 7297fa6f
from typing import List, Optional, Tuple, Union

try:
    import ray
except ImportError:
    ray = None

from cacheflow.core.scheduler import Scheduler
from cacheflow.worker.worker import Worker

DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id


class Controller:

    def __init__(
        self,
        stage_id: int,
        stage_devices: List[DeviceID],
        world_size: int,
        tensor_parallel_size: int,
        pipeline_parallel_size: int,
        distributed_init_method: str,
        model_name: str,
        dtype: str,
        seed: int,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        use_ray: bool,
    ) -> None:
        self.stage_id = stage_id
        self.stage_devices = stage_devices
        self.model_name = model_name
        self.use_ray = use_ray

        # Which pipeline stage is this node assigned to?
        self.is_first_stage = stage_id == 0
        self.is_last_stage = False

        self.workers: List[Worker] = []
        for rank, node_resource, device_id in stage_devices:
            if self.use_ray:
                worker_cls = ray.remote(
                    num_cpus=0,
                    num_gpus=1,
                    resources={node_resource: 1e-5})(Worker).remote
            else:
                worker_cls = Worker
            worker = worker_cls(
                model_name=model_name,
                dtype=dtype,
                seed=seed,
                distributed_init_method=distributed_init_method,
                rank=rank,
                world_size=world_size,
                tensor_parallel_size=tensor_parallel_size,
                pipeline_parallel_size=pipeline_parallel_size,
                cache_dir=cache_dir,
                use_dummy_weights=use_dummy_weights,
                use_np_cache=use_np_cache,
                max_num_batched_tokens=max_num_batched_tokens,
                max_num_sequences=max_num_sequences,
            )
            self.workers.append(worker)

    def get_num_available_blocks(
            self, block_size: int, cpu_swap_space: int,
            gpu_memory_utilization: float) -> List[Tuple[int, int]]:
        all_worker_results = []
        for worker in self.workers:
            executor = worker.get_num_available_blocks
            if self.use_ray:
                executor = executor.remote
            result = executor(
                block_size,
                cpu_swap_space,
                gpu_memory_utilization,
            )
            all_worker_results.append(result)
        if self.use_ray:
            all_worker_results = ray.get(all_worker_results)
        return all_worker_results

    def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
                          num_cpu_blocks: int):
        all_worker_futures = []
        for worker in self.workers:
            executor = worker.init_cache_engine
            if self.use_ray:
                executor = executor.remote
            future = executor(
                block_size,
                num_gpu_blocks,
                num_cpu_blocks,
            )
            all_worker_futures.append(future)
        if self.use_ray:
            ray.get(all_worker_futures)

    def set_next(
        self,
        next_node: Union['Controller', 'Scheduler'],
    ) -> None:
        self.next_node = next_node
        self.is_last_stage = isinstance(next_node, Scheduler)

    def execute_stage(self, *args, **kwargs) -> None:
        all_outputs = []
        for worker in self.workers:
            executor = (worker.execute_stage.remote
                        if self.use_ray else worker.execute_stage)
            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.use_ray:
            all_outputs = ray.get(all_outputs)

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output

        if self.is_last_stage:
            self.next_node.post_step(output)
        else:
            # TODO: Support pipeline parallelism.
            assert False