norm / vllm · Commits · 80a2f812

Unverified commit 80a2f812, authored Mar 29, 2023 by Woosuk Kwon, committed via GitHub on Mar 30, 2023.
Implement LLaMA (#9)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>

Parent: a1b3de86
Showing 7 changed files with 500 additions and 35 deletions (+500 -35):

- README.md (+4 -2)
- cacheflow/master/simple_frontend.py (+1 -0)
- cacheflow/models/llama.py (+357 -0)
- cacheflow/models/memory_analyzer.py (+130 -26)
- cacheflow/models/model_utils.py (+5 -1)
- cacheflow/models/opt.py (+2 -5)
- cacheflow/models/sample.py (+1 -1)
README.md

@@ -3,8 +3,10 @@
 ## Installation
 
 ```bash
-pip install psutil numpy torch transformers
-pip install flash-attn  # This may take up to 10 mins.
+pip install psutil numpy ray torch
+pip install git+https://github.com/huggingface/transformers  # Required for LLaMA.
+pip install sentencepiece  # Required for LlamaTokenizer.
+pip install flash-attn  # This may take up to 20 mins.
 pip install -e .
 ```
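The pinned git install of transformers is needed because LLaMA support had not yet shipped in a released version, and its tokenizer depends on sentencepiece. As a quick, illustrative sanity check that the installed packages expose what this commit relies on (the checkpoint path below is a placeholder, not something from this commit):

```python
# Illustrative only: verify the git install of transformers exposes the LLaMA classes.
# "/path/to/llama-7b" is a placeholder for a local HuggingFace-format checkpoint.
from transformers import LlamaConfig, LlamaTokenizer  # LlamaTokenizer requires sentencepiece

config = LlamaConfig.from_pretrained("/path/to/llama-7b")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama-7b")
print(config.num_hidden_layers, tokenizer.vocab_size)
```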
cacheflow/master/simple_frontend.py

@@ -61,4 +61,5 @@ class SimpleFrontend:
         for seq in seq_group.seqs:
             token_ids = seq.get_token_ids()
             output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+            output = output.strip()
             print(f'Seq {seq.seq_id}: {output!r}')
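The added `strip()` removes the leading whitespace that BPE-style tokenizers commonly reintroduce when decoding. A small standalone illustration (the tokenizer choice here is arbitrary and not part of this commit):

```python
from transformers import AutoTokenizer

# Any public BPE tokenizer shows the effect; "gpt2" is used purely for illustration.
tok = AutoTokenizer.from_pretrained("gpt2")
token_ids = tok(" Hello, world!")["input_ids"]
output = tok.decode(token_ids, skip_special_tokens=True)
print(repr(output))          # ' Hello, world!'  <- note the leading space
print(repr(output.strip()))  # 'Hello, world!'
```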
cacheflow/models/llama.py (new file, 0 → 100644)

"""1D LLaMA model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers import PreTrainedModel

from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


class LlamaRotaryEmbedding(torch.nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Create cos and sin embeddings.
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=self.inv_freq.dtype)
        sin = emb.sin().to(dtype=self.inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)
        return cos, sin


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # TODO: Optimize.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class LlamaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        # TODO: Merge the gate and down linear layers.
        self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
                                              bias=False, gather_output=False,
                                              perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                           bias=False, input_is_parallel=True,
                                           perform_initialization=False)
        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
                                            bias=False, gather_output=False,
                                            perform_initialization=False)
        assert hidden_act == 'silu'
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gate, _ = self.gate_proj(x)
        up, _ = self.up_proj(x)
        x = self.act_fn(gate) * up
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim ** -0.5

        # TODO: Merge the QKV linear layers.
        self.q_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.k_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.v_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
        # FIXME(woosuk): Rename this.
        self.attn = OPTCacheFlowAttention(scale=self.scaling)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        q, _ = self.q_proj(hidden_states)
        k, _ = self.k_proj(hidden_states)
        v, _ = self.v_proj(hidden_states)

        # Apply rotary embedding.
        # TODO: Optimize.
        q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
        cos, sin = self.rotary_emb(positions)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
        k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)

        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size, perform_initialization=False)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class LlamaForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LlamaModel(config)
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head.weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
                                "q_proj.weight", "k_proj.weight",
                                "v_proj.weight", "gate_proj.weight",
                                "up_proj.weight"]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self, weights_path: str):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, param in state_dict.items():
            loaded_weight = torch.from_numpy(
                np.load(os.path.join(weights_path, name)))
            for p in self._column_parallel_weights:
                if p in name:
                    shard_size = param.shape[0]
                    loaded_weight = loaded_weight[
                        shard_size * tensor_model_parallel_rank
                        :shard_size * (tensor_model_parallel_rank + 1)]
                    break
            for p in self._row_parallel_weights:
                if p in name:
                    shard_size = param.shape[1]
                    loaded_weight = loaded_weight[
                        :,
                        shard_size * tensor_model_parallel_rank
                        :shard_size * (tensor_model_parallel_rank + 1)]
                    break

            assert param.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)

    @staticmethod
    def get_weights(model_name: str, path: str):
        if not os.path.isfile(os.path.join(model_name, "config.json")):
            raise ValueError("LLaMA model's model_name has to be a path "
                             "to the huggingface model's directory.")
        path = os.path.join(model_name, f"np")
        path = os.path.abspath(os.path.expanduser(path))
        os.makedirs(path, exist_ok=True)
        lock_path = os.path.join(path, "file_lock")
        lock = filelock.FileLock(lock_path)

        with lock:
            test_weight_path = os.path.join(path, "model.embed_tokens.weight")
            if os.path.exists(test_weight_path):
                return path

            bin_files = glob.glob(os.path.join(model_name, "*.bin"))
            for bin_file in tqdm(bin_files, desc="Convert format"):
                state = torch.load(bin_file, map_location="cpu")
                for name, param in tqdm(state.items(), leave=False):
                    param_path = os.path.join(path, name)
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())

            return path
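For readers unfamiliar with RMSNorm, the `LlamaRMSNorm` above computes `weight * x / sqrt(mean(x**2) + eps)` with the statistics taken in float32. A minimal, self-contained sketch of that same formula against plain PyTorch (not the cacheflow module) is:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same arithmetic as LlamaRMSNorm.forward: root-mean-square normalization
    # computed in float32, then scaled by the learned weight.
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x

x = torch.randn(4, 8)
weight = torch.ones(8)
expected = weight * x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(rms_norm(x, weight), expected, atol=1e-5)
```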
cacheflow/models/memory_analyzer.py

@@ -15,11 +15,30 @@ class CacheFlowMemoryAnalyzer:
     ) -> int:
         raise NotImplementedError()
 
+    def get_workspace_size(self) -> int:
+        return 1 * _GiB
+
+    def get_cache_block_size(self) -> int:
+        raise NotImplementedError()
+
     def get_max_num_cpu_blocks(
         self,
-        memory_utilization: float,
+        swap_space: int,
     ) -> int:
-        raise NotImplementedError()
+        swap_space = swap_space * _GiB
+        cpu_memory = self.cpu_memory
+        if swap_space > 0.8 * cpu_memory:
+            raise ValueError(
+                f'The swap space ({swap_space / _GiB:.2f} GiB) '
+                'takes more than 80% of the available memory '
+                f'({cpu_memory / _GiB:.2f} GiB).'
+                'Please check the swap space size.')
+        if swap_space > 0.5 * cpu_memory:
+            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
+                  'takes more than 50% of the available memory '
+                  f'({cpu_memory / _GiB:.2f} GiB).'
+                  'This may slow the system performance.')
+        max_num_blocks = swap_space // self.get_cache_block_size()
+        return max_num_blocks
 
 
 class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
@@ -52,9 +71,9 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
     def _get_param_size(self) -> int:
         word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
-        if self.embedding_size != self.vocab_size:
+        if self.embedding_size != self.hidden_size:
             # Project in/out.
-            word_embedding += 2 * self.embedding_size * self.vocab_size
+            word_embedding += 2 * self.embedding_size * self.hidden_size
         position_embedding = self.max_position * self.hidden_size
         ln1 = 2 * self.hidden_size
@@ -89,15 +108,15 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
         # Double the activation size for input and output.
         max_act = 2 * (max(qkv, ffn) + residual)
+        # Size of output logits.
+        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
+        max_act = max(max_act, output_logits)
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * max_act
 
-    def _get_workspace_size(self) -> int:
-        return 1 * _GiB
-
-    def _get_cache_block_size(self) -> int:
-        key_cache_block = self.block_size * self.num_heads * self.head_size
-        value_cache_block = self.block_size * self.num_heads * self.head_size
+    def get_cache_block_size(self) -> int:
+        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
+        value_cache_block = key_cache_block
         total = self.num_layers * (key_cache_block + value_cache_block)
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * total
@@ -112,26 +131,111 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         param_size = self._get_param_size()
         act_size = self._get_max_act_size(max_num_batched_tokens)
-        workspace_size = self._get_workspace_size()
+        workspace_size = self.get_workspace_size()
 
         max_cache_size = usable_memory - (param_size + act_size + workspace_size)
-        max_num_blocks = max_cache_size // self._get_cache_block_size()
+        if max_cache_size <= 0:
+            raise RuntimeError('Not enough GPU memory.')
+        max_num_blocks = max_cache_size // self.get_cache_block_size()
         return max_num_blocks
 
-    def get_max_num_cpu_blocks(
-        self,
-        swap_space: int,
-    ) -> int:
-        swap_space = swap_space * _GiB
-        if swap_space > 0.8 * self.cpu_memory:
-            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
-                             'takes more than 80% of the available memory '
-                             f'({self.cpu_memory / _GiB:.2f} GiB).'
-                             'Please check the swap space size.')
-        if swap_space > 0.5 * self.cpu_memory:
-            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
-                  'takes more than 50% of the available memory '
-                  f'({self.cpu_memory / _GiB:.2f} GiB).'
-                  'This may slow the system performance.')
-        max_num_blocks = swap_space // self._get_cache_block_size()
-        return max_num_blocks
+
+class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
+
+    def __init__(
+        self,
+        model_name: str,
+        block_size: int,
+        dtype: torch.dtype,
+        gpu_memory: int,
+        cpu_memory: int,
+        tensor_parallel_size: int,
+    ) -> None:
+        self.model_name = model_name
+        self.block_size = block_size
+        self.dtype = dtype
+        self.gpu_memory = gpu_memory
+        self.cpu_memory = cpu_memory
+        self.tensor_parallel_size = tensor_parallel_size
+
+        config = AutoConfig.from_pretrained(model_name)
+        self.num_layers = config.num_hidden_layers
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_size = config.hidden_size // self.num_heads
+        self.ffn_size = config.intermediate_size
+        self.vocab_size = config.vocab_size
+        # FIXME
+        self.max_position = 2048
+
+    def _get_param_size(self) -> int:
+        word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
+        position_embedding = self.max_position * self.hidden_size
+
+        # NOTE: LLaMA does not have bias terms.
+        ln1 = self.hidden_size
+        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size
+        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size
+        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size
+        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size
+        # Rotary embedding.
+        # TODO(woosuk): Share the rotary embedding between layers.
+        rot = self.max_position * self.head_size
+        mha = ln1 + q + k + v + out + rot
+
+        ln2 = self.hidden_size
+        gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size
+        down = self.ffn_size * self.hidden_size // self.tensor_parallel_size
+        up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
+        ffn = ln2 + gate + down + up
+
+        total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def _get_max_act_size(
+        self,
+        max_num_batched_tokens: int,
+    ) -> int:
+        # NOTE: We approximately calculate the maximum activation size by
+        # estimating
+        # 1) the maximum activation tensor size during inference
+        # 2) the residual tensor size during inference
+        # Here, we assume that FlashAttention is used and
+        # thus the attention maps are never materialized in GPU DRAM.
+        residual = max_num_batched_tokens * self.hidden_size
+        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
+        ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
+        # Double the activation size for input and output.
+        max_act = 2 * (max(qkv, ffn) + residual)
+        # Size of output logits.
+        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
+        max_act = max(max_act, output_logits)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * max_act
+
+    def get_cache_block_size(self) -> int:
+        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
+        value_cache_block = key_cache_block
+        total = self.num_layers * (key_cache_block + value_cache_block)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def get_max_num_gpu_blocks(
+        self,
+        max_num_batched_tokens: int,
+        memory_utilization: float = 0.95,
+    ) -> int:
+        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
+        gpu_memory = self.gpu_memory
+        usable_memory = int(memory_utilization * gpu_memory)
+
+        param_size = self._get_param_size()
+        act_size = self._get_max_act_size(max_num_batched_tokens)
+        workspace_size = self.get_workspace_size()
+
+        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
+        if max_cache_size <= 0:
+            raise RuntimeError('Not enough GPU memory.')
+        max_num_blocks = max_cache_size // self.get_cache_block_size()
+        return max_num_blocks
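To make `get_cache_block_size` concrete, here is a rough worked example with LLaMA-7B-like dimensions (num_layers=32, hidden_size=4096, fp16). The block size of 8 tokens and tensor_parallel_size of 1 are assumed values for illustration, not taken from this commit:

```python
# Hypothetical numbers; only the formula mirrors LlamaMemoryAnalyzer.get_cache_block_size.
num_layers = 32          # LLaMA-7B
hidden_size = 4096       # LLaMA-7B
tensor_parallel_size = 1
block_size = 8           # assumed number of tokens per KV-cache block
dtype_size = 2           # bytes per element in fp16

key_cache_block = block_size * hidden_size // tensor_parallel_size
value_cache_block = key_cache_block
cache_block_size = dtype_size * num_layers * (key_cache_block + value_cache_block)
print(cache_block_size)  # 4194304 bytes, i.e. 4 MiB per block
```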
cacheflow/models/model_utils.py

@@ -6,16 +6,20 @@ import torch.nn as nn
 from transformers import AutoConfig
 
 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
+from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
 from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
+from cacheflow.models.llama import LlamaForCausalLM
 from cacheflow.models.opt import OPTForCausalLM
 from cacheflow.models.utils import get_torch_dtype
 
 
 _MODELS = {
+    'llama': LlamaForCausalLM,
     'opt': OPTForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
+    'llama': LlamaMemoryAnalyzer,
     'opt': OPTMemoryAnalyzer,
 }
@@ -31,7 +35,7 @@ def get_model(
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             # Download model weights if it's not cached.
-            weights_dir = model_class.download_weights(model_name, path=path)
+            weights_dir = model_class.get_weights(model_name, path=path)
             # Create a model instance.
             model = model_class(config)
             # Load the weights from the cached or downloaded files.
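Note that `_MODELS` is keyed by substring match against the requested model name, so a name like "llama-7b" resolves to `LlamaForCausalLM`. A minimal sketch of that dispatch with stand-in classes (not the actual `get_model` implementation):

```python
class LlamaForCausalLM: ...
class OPTForCausalLM: ...

_MODELS = {
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
}

def resolve_model_class(model_name: str):
    # Mirrors the loop in get_model(): return the first registered class whose
    # key appears as a substring of the requested model name.
    for key, cls in _MODELS.items():
        if key in model_name:
            return cls
    raise ValueError(f'Unsupported model name: {model_name}')

print(resolve_model_class('llama-7b').__name__)  # LlamaForCausalLM
```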
cacheflow/models/opt.py

@@ -299,7 +299,7 @@ class OPTForCausalLM(nn.Module):
             param.data.copy_(loaded_weight)
 
     @staticmethod
-    def download_weights(model_name: str, path: str):
+    def get_weights(model_name: str, path: str):
         path = os.path.join(path, f"{model_name}-np")
         path = os.path.abspath(os.path.expanduser(path))
         os.makedirs(path, exist_ok=True)
@@ -316,11 +316,8 @@ class OPTForCausalLM(nn.Module):
                 cache_dir=os.path.join(path, "cache"))
             bin_files = glob.glob(os.path.join(folder, "*.bin"))
 
-            if "/" in model_name:
-                model_name = model_name.split("/")[1].lower()
-
             for bin_file in tqdm(bin_files, desc="Convert format"):
-                state = torch.load(bin_file)
+                state = torch.load(bin_file, map_location="cpu")
                 for name, param in tqdm(state.items(), leave=False):
                     if name.startswith("decoder."):
                         name = "model." + name
cacheflow/models/sample.py

@@ -39,7 +39,7 @@ class Sampler(nn.Module):
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # Compute the log probabilities (before applying top-p).
-        logprobs = torch.log(probs)
+        logprobs = torch.log(probs, out=logits)
         # Apply top-p truncation.
         top_ps = _get_top_ps(input_metadata)
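The `out=logits` argument writes the log-probabilities into the existing logits buffer instead of allocating another vocabulary-sized tensor. A small illustration of that in-place reuse (shapes chosen arbitrarily):

```python
import torch

logits = torch.randn(4, 32000)
probs = torch.softmax(logits, dim=-1, dtype=torch.float)

# torch.log writes its result into `logits`, so no second [4, 32000] tensor is allocated.
logprobs = torch.log(probs, out=logits)
assert logprobs.data_ptr() == logits.data_ptr()
assert torch.allclose(logprobs, torch.log(probs))
```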