vllm · Commit e9d3f2ff (unverified)

Authored Mar 11, 2023 by Woosuk Kwon; committed by GitHub on Mar 11, 2023
Parent: 1a7eb7da

Add memory analyzer & automatically configure KV cache size (#6)

Showing 7 changed files with 216 additions and 34 deletions:
  README.md                            +1    -1
  cacheflow/master/scheduler.py        +3    -3
  cacheflow/models/__init__.py         +4    -2
  cacheflow/models/memory_analyzer.py  +125  -0
  cacheflow/models/model_utils.py      +23   -21
  cacheflow/models/utils.py            +43   -0
  server.py                            +17   -7
README.md

@@ -3,7 +3,7 @@
 ## Installation
 
 ```bash
-pip install cmake torch transformers
+pip install psutil numpy torch transformers
 pip install flash-attn  # This may take up to 10 mins.
 pip install -e .
 ```
cacheflow/master/scheduler.py

@@ -9,8 +9,6 @@ from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus
 
-_MAX_NUM_BATCHED_TOKENS = 2048
-
 
 class Scheduler:
@@ -21,12 +19,14 @@ class Scheduler:
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
+        max_num_batched_tokens: int,
     ) -> None:
         self.frontend = frontend
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
+        self.max_num_batched_tokens = max_num_batched_tokens
 
         # Create the block space manager.
         self.block_manager = BlockSpaceManager(
@@ -164,7 +164,7 @@ class Scheduler:
             num_prompt_tokens = seq_group.seqs[0].get_len()
             if self.block_manager.can_allocate(seq_group):
                 if (num_batched_tokens + num_prompt_tokens
-                        <= _MAX_NUM_BATCHED_TOKENS):
+                        <= self.max_num_batched_tokens):
                     self._allocate(seq_group)
                     num_batched_tokens += num_prompt_tokens
                     continue
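With this change, the prompt-admission cap becomes a per-Scheduler setting (fed from the server's --max-batch-size flag further below) instead of a hard-coded 2048. A paraphrase of the check, not the commit's code, just to make the semantics explicit:

```python
# Paraphrase of the admission check in Scheduler above: a prompt group is
# scheduled only while the batch stays within the configured token budget.
def can_admit(num_batched_tokens: int, num_prompt_tokens: int,
              max_num_batched_tokens: int) -> bool:
    return num_batched_tokens + num_prompt_tokens <= max_num_batched_tokens

assert can_admit(1500, 500, max_num_batched_tokens=2048)      # 2000 <= 2048: admit
assert not can_admit(1800, 500, max_num_batched_tokens=2048)  # 2300 >  2048: defer
```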
cacheflow/models/__init__.py

 from cacheflow.models.input_metadata import InputMetadata
+from cacheflow.models.model_utils import get_memory_analyzer
 from cacheflow.models.model_utils import get_model
-from cacheflow.models.model_utils import set_seed
+from cacheflow.models.utils import set_seed
 
 
 __all__ = [
     'InputMetadata',
+    'get_memory_analyzer',
     'get_model',
-    'set_seed'
+    'set_seed',
 ]
cacheflow/models/memory_analyzer.py (new file)

import torch
from transformers import AutoConfig

from cacheflow.models.utils import get_cpu_memory
from cacheflow.models.utils import get_dtype_size
from cacheflow.models.utils import get_gpu_memory

_GiB = 1 << 30


class CacheFlowMemoryAnalyzer:

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float,
    ) -> int:
        raise NotImplementedError()

    def get_max_num_cpu_blocks(
        self,
        memory_utilization: float,
    ) -> int:
        raise NotImplementedError()


class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
    ) -> None:
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype

        # TODO(woosuk): Support tensor parallelism.
        config = AutoConfig.from_pretrained(model_name)
        self.num_layers = config.num_hidden_layers
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // self.num_heads
        self.ffn_size = config.ffn_dim
        self.embedding_size = config.word_embed_proj_dim
        self.vocab_size = config.vocab_size
        self.max_position = config.max_position_embeddings

    def _get_param_size(self) -> int:
        # TODO(woosuk): Support tensor parallelism.
        word_embedding = self.vocab_size * self.embedding_size
        if self.embedding_size != self.vocab_size:
            # Project in/out.
            word_embedding += 2 * self.embedding_size * self.vocab_size
        position_embedding = self.max_position * self.hidden_size

        ln1 = 2 * self.hidden_size
        q = self.hidden_size * self.hidden_size + self.hidden_size
        k = self.hidden_size * self.hidden_size + self.hidden_size
        v = self.hidden_size * self.hidden_size + self.hidden_size
        out = self.hidden_size * self.hidden_size + self.hidden_size
        mha = ln1 + q + k + v + out

        ln2 = 2 * self.hidden_size
        ffn1 = self.hidden_size * self.ffn_size + self.ffn_size
        ffn2 = self.ffn_size * self.hidden_size + self.hidden_size
        ffn = ln2 + ffn1 + ffn2

        total = (word_embedding + position_embedding
                 + self.num_layers * (mha + ffn))
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def _get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        # TODO(woosuk): Support tensor parallelism.
        # NOTE: We approximately calculate the maximum activation size by
        # 1) estimating the maximum activation tensor size during inference, and
        # 2) multiplying it by 4.
        # Here, we assume that FlashAttention is used and
        # thus the attention maps are never materialized in GPU DRAM.
        qkv = 3 * (max_num_batched_tokens * self.hidden_size)
        ffn = max_num_batched_tokens * self.ffn_size
        max_act = 4 * max(qkv, ffn)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * max_act

    def _get_workspace_size(self) -> int:
        return 1 * _GiB

    def _get_cache_block_size(self) -> int:
        key_cache_block = self.block_size * self.num_heads * self.head_size
        value_cache_block = self.block_size * self.num_heads * self.head_size
        total = self.num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float = 0.95,
    ) -> int:
        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
        gpu_memory = get_gpu_memory()
        usable_memory = int(memory_utilization * gpu_memory)

        param_size = self._get_param_size()
        act_size = self._get_max_act_size(max_num_batched_tokens)
        workspace_size = self._get_workspace_size()

        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
        max_num_blocks = max_cache_size // self._get_cache_block_size()
        return max_num_blocks

    def get_max_num_cpu_blocks(
        self,
        memory_utilization: float = 0.25,
    ) -> int:
        cpu_memory = get_cpu_memory()
        usable_memory = int(memory_utilization * cpu_memory)
        max_num_blocks = usable_memory // self._get_cache_block_size()
        return max_num_blocks
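For intuition (not part of the commit), here is how the `_get_cache_block_size` arithmetic works out for the default `--block-size 8` with facebook/opt-125m in fp16, assuming its usual config of 12 layers and 12 attention heads with head size 64:

```python
# Hypothetical sanity check of _get_cache_block_size for facebook/opt-125m.
num_layers, num_heads, head_size = 12, 12, 64   # assumed OPT-125m config
block_size = 8                                  # tokens per KV-cache block
dtype_size = 2                                  # bytes per fp16 element

key_cache_block = block_size * num_heads * head_size      # 6,144 elements
value_cache_block = block_size * num_heads * head_size    # 6,144 elements
cache_block_size = dtype_size * num_layers * (key_cache_block + value_cache_block)
print(cache_block_size)  # 294912 bytes, i.e. 288 KiB per 8-token block
```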
cacheflow/models/model_utils.py

 import random
 from typing import Union
 
 import numpy as np
 import torch
 import torch.nn as nn
 
+from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
+from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
 from cacheflow.models.opt import OPTForCausalLM
+from cacheflow.models.utils import get_torch_dtype
 
-MODEL_CLASSES = {
+_MODELS = {
     'opt': OPTForCausalLM,
 }
 
-STR_DTYPE_TO_TORCH_DTYPE = {
-    'half': torch.half,
-    'float': torch.float,
-    'float16': torch.float16,
-    'float32': torch.float32,
+_MEMORY_ANALYZERS = {
+    'opt': OPTMemoryAnalyzer,
 }
@@ -23,20 +22,23 @@ def get_model(
     model_name: str,
     dtype: Union[torch.dtype, str],
 ) -> nn.Module:
-    if isinstance(dtype, str):
-        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
-    else:
-        torch_dtype = dtype
-    for model_class, hf_model in MODEL_CLASSES.items():
+    torch_dtype = get_torch_dtype(dtype)
+    for model_class, hf_model in _MODELS.items():
         if model_class in model_name:
             model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
             return model.eval()
-    raise ValueError(f'Invalid model name: {model_name}')
+    raise ValueError(f'Unsupported model name: {model_name}')
 
 
 def set_seed(seed: int) -> None:
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
 
 
+def get_memory_analyzer(
+    model_name: str,
+    block_size: int,
+    dtype: Union[torch.dtype, str],
+) -> CacheFlowMemoryAnalyzer:
+    torch_dtype = get_torch_dtype(dtype)
+    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
+        if model_class in model_name:
+            return memory_analyzer(model_name, block_size, torch_dtype)
+    raise ValueError(f'Unsupported model name: {model_name}')
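Like `get_model`, the new `get_memory_analyzer` picks its implementation by substring match against the registry keys. A minimal usage sketch, not part of the commit, assuming the package is importable:

```python
from cacheflow.models import get_memory_analyzer

# 'facebook/opt-125m' contains the registry key 'opt', so the OPT analyzer is chosen.
analyzer = get_memory_analyzer(
    model_name='facebook/opt-125m', block_size=8, dtype='half')
print(type(analyzer).__name__)  # OPTMemoryAnalyzer
```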
cacheflow/models/utils.py (new file)

from typing import Union

import random
import numpy as np
import psutil
import torch

_STR_DTYPE_TO_TORCH_DTYPE = {
    'half': torch.half,
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
}


def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
    if isinstance(dtype, str):
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
    else:
        torch_dtype = dtype
    return torch_dtype


def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
    torch_dtype = get_torch_dtype(dtype)
    return torch.tensor([], dtype=torch_dtype).element_size()


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_gpu_memory(gpu: int = 0) -> int:
    return torch.cuda.get_device_properties(gpu).total_memory


def get_cpu_memory() -> int:
    return psutil.virtual_memory().total
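A quick sanity check of the dtype helpers, not part of the commit; the memory analyzer relies on `get_dtype_size` to convert element counts into bytes:

```python
import torch

from cacheflow.models.utils import get_dtype_size, get_torch_dtype

assert get_torch_dtype('half') is torch.half
assert get_dtype_size('half') == 2         # fp16: 2 bytes per element
assert get_dtype_size(torch.float32) == 4  # fp32: 4 bytes per element
```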
server.py

@@ -3,6 +3,7 @@ from typing import List
 from cacheflow.master.frontend import Frontend
 from cacheflow.master.scheduler import Scheduler
+from cacheflow.models import get_memory_analyzer
 from cacheflow.worker.controller import Controller
 
 parser = argparse.ArgumentParser(description='CacheFlow server')
@@ -10,17 +11,25 @@ parser.add_argument('--model', type=str, default='facebook/opt-125m', help='mode
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
 parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
-# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
-parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
-parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
 # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
 parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
 parser.add_argument('--seed', type=int, default=0, help='random seed')
+parser.add_argument('--max-batch-size', type=int, default=2048, help='maximum number of batched tokens')
 args = parser.parse_args()
 
 
 def main():
+    memory_analyzer = get_memory_analyzer(
+        model_name=args.model,
+        block_size=args.block_size,
+        dtype=args.dtype,
+    )
+    num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
+        max_num_batched_tokens=args.max_batch_size)
+    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks()
+    print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')
+
     # Create a controller for each node.
     controllers: List[Controller] = []
     for i in range(args.num_nodes):
@@ -29,8 +38,8 @@ def main():
             num_workers=args.num_workers,
             model_name=args.model,
             block_size=args.block_size,
-            num_gpu_blocks=args.num_gpu_blocks,
-            num_cpu_blocks=args.num_cpu_blocks,
+            num_gpu_blocks=num_gpu_blocks,
+            num_cpu_blocks=num_cpu_blocks,
             dtype=args.dtype,
             seed=args.seed,
         )
@@ -47,8 +56,9 @@ def main():
         frontend=frontend,
         controllers=controllers,
         block_size=args.block_size,
-        num_gpu_blocks=args.num_gpu_blocks,
-        num_cpu_blocks=args.num_cpu_blocks,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=num_cpu_blocks,
+        max_num_batched_tokens=args.max_batch_size,
     )
     # Connect the controllers.
     for i in range(len(controllers) - 1):
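For a sense of scale, here is the budget inside `get_max_num_gpu_blocks` spelled out with assumed numbers, not taken from the commit: facebook/opt-125m in fp16, the default `--max-batch-size 2048`, and a hypothetical 16 GiB GPU.

```python
# Hypothetical worked example of the budget in get_max_num_gpu_blocks.
# Assumed: facebook/opt-125m (~125M params, hidden 768, ffn 3072), fp16,
# a 16 GiB GPU, and the 288 KiB cache block size computed earlier.
GiB = 1 << 30

usable_memory = int(0.95 * 16 * GiB)                 # memory_utilization = 0.95
param_size = 125_000_000 * 2                         # ~0.23 GiB of fp16 weights
act_size = 4 * max(3 * 2048 * 768, 2048 * 3072) * 2  # ~48 MiB peak activations
workspace_size = 1 * GiB                             # fixed 1 GiB workspace

max_cache_size = usable_memory - (param_size + act_size + workspace_size)
print(max_cache_size // 294_912)  # ~50,000 blocks, i.e. roughly 400,000 cacheable tokens
```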