norm / vllm, commit a96d63c2

Unverified commit a96d63c2, authored Apr 28, 2023 by Woosuk Kwon, committed by GitHub on Apr 28, 2023.

    Add support for GPT-NeoX (Pythia) (#50)

Parent: aa50b17c

Showing 9 changed files with 436 additions and 71 deletions.
cacheflow/models/attention.py        +11   -5
cacheflow/models/gpt_neox.py         +278  -0
cacheflow/models/llama.py            +1    -1
cacheflow/models/memory_analyzer.py  +107  -50
cacheflow/models/model_utils.py      +6    -1
cacheflow/models/opt.py              +1    -1
csrc/pos_encoding.cpp                +1    -0
csrc/pos_encoding_kernels.cu         +9    -6
tests/kernels/pos_encoding.py        +22   -7
cacheflow/models/attention.py

@@ -150,20 +150,20 @@ class OPTCacheFlowAttention(GPTCacheFlowAttention):
         super().__init__(scale)
 
 
-class LlamaCacheFlowAttention(GPTCacheFlowAttention):
-    """Llama uses GPT-NeoX style rotary embedding."""
+class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
+    """Attention with GPT-NeoX style rotary embedding."""
 
     def __init__(
         self,
         scale: float,
-        head_size: int,
+        rotary_dim: int,
         max_position: int = 8192,
         base: int = 10000,
     ) -> None:
         super().__init__(scale)
 
         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
+        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
         t = torch.arange(max_position).float()
         freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
         cos = freqs.cos()

@@ -174,7 +174,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
         # initializing the model. Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
-        # Embedding size: [max_position, head_size]
+        # Embedding size: [max_position, rotary_dim]
         self.register_buffer('cos_sin_cache', cache, persistent=False)
 
     def forward(

@@ -190,10 +190,12 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
     ) -> torch.Tensor:                      # [num_tokens, num_heads * head_size]
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
+        head_size = value_cache.shape[2]
         pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
+           head_size,
            self.cos_sin_cache,
        )
         return super().forward(

@@ -205,3 +207,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
            input_metadata,
            cache_event,
        )
+
+
+class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
+    """LLaMA uses the GPT-NeoX style rotary embedding."""
cacheflow/models/gpt_neox.py (new file, mode 100644)

"""1D GPT-NeoX model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from huggingface_hub import snapshot_download

from cacheflow.models import InputMetadata
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
                                                      ColumnParallelLinear,
                                                      RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTNeoXAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size

        self.query_key_value = ColumnParallelLinear(config.hidden_size,
                                                    3 * config.hidden_size,
                                                    gather_output=False,
                                                    perform_initialization=False)
        self.dense = RowParallelLinear(config.hidden_size,
                                       config.hidden_size,
                                       input_is_parallel=True,
                                       perform_initialization=False)

        scaling = self.head_size ** -0.5
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim)

    def forward(
        self,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(
            position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
                                                  config.intermediate_size,
                                                  gather_output=False,
                                                  perform_initialization=False)
        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                               config.hidden_size,
                                               input_is_parallel=True,
                                               perform_initialization=False)
        if config.hidden_act != 'gelu':
            raise ValueError(f'Unsupported activation: {config.hidden_act}. '
                             'Only gelu is supported for now.')
        self.act = torch.nn.GELU()

    def forward(self, hidden_states):
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config)
        self.mlp = GPTNeoXMLP(config)

    def forward(
        self,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_in = VocabParallelEmbedding(config.vocab_size,
                                               config.hidden_size,
                                               perform_initialization=False)
        self.layers = nn.ModuleList(
            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gpt_neox = GPTNeoXModel(config)
        self.embed_out = ColumnParallelLinear(config.hidden_size,
                                              config.vocab_size,
                                              bias=False,
                                              gather_output=False,
                                              perform_initialization=False)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.gpt_neox(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.embed_out.weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_in.weight", "embed_out.weight",
                                "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self, weights_path: str):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, param in state_dict.items():
            if "query_key_value" in name:
                # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
                # [num_heads * 3 * head_size, num_heads * head_size], while the
                # required shape is [3 * num_heads * head_size, num_heads * head_size].
                # Thus, we need weight conversion.
                loaded_weight = torch.from_numpy(
                    np.load(os.path.join(weights_path, name)))
                shard_size = param.shape[0]
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank
                    :shard_size * (tensor_model_parallel_rank + 1)]

                num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // num_heads
                if 'query_key_value.weight' in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, hidden_size).contiguous()
                elif 'query_key_value.bias' in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1).contiguous()
                else:
                    assert False
            else:
                loaded_weight = torch.from_numpy(
                    np.load(os.path.join(weights_path, name)))
                for p in self._column_parallel_weights:
                    if p in name:
                        shard_size = param.shape[0]
                        loaded_weight = loaded_weight[
                            shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break
                for p in self._row_parallel_weights:
                    if p in name:
                        shard_size = param.shape[1]
                        loaded_weight = loaded_weight[
                            :,
                            shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break

            assert param.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)

    @staticmethod
    def get_weights(model_name: str, path: str):
        path = os.path.join(path, f"{model_name}-np")
        path = os.path.abspath(os.path.expanduser(path))
        os.makedirs(path, exist_ok=True)
        lock_path = os.path.join(path, "file_lock")
        lock = filelock.FileLock(lock_path)

        with lock:
            test_weight_path = os.path.join(path, "gpt_neox.embed_in.weight")
            if os.path.exists(test_weight_path):
                return path

            folder = snapshot_download(model_name, allow_patterns="*.bin",
                                       cache_dir=os.path.join(path, "cache"))
            bin_files = glob.glob(os.path.join(folder, "*.bin"))

            for bin_file in tqdm(bin_files, desc="Convert format"):
                state = torch.load(bin_file, map_location="cpu")
                for name, param in tqdm(state.items(), leave=False):
                    param_path = os.path.join(path, name)
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())

            return path

    def initialize_dummy_weights(self) -> None:
        for param in self.state_dict().values():
            param.data.uniform_(-1e-3, 1e-3)
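Reviewer note: the weight-conversion comment in load_weights is the subtle part of this file. HuggingFace's GPT-NeoX checkpoint stores the fused QKV projection with the three projections interleaved per head, while the ColumnParallelLinear here expects all query rows first, then all key rows, then all value rows. A small self-contained sketch (toy sizes, not loaded from a real checkpoint) confirms that the view/transpose/reshape sequence used above performs exactly that regrouping:

import torch

num_heads, head_size, hidden = 4, 8, 32          # toy sizes; hidden == num_heads * head_size
hf_qkv = torch.arange(num_heads * 3 * head_size * hidden, dtype=torch.float32)
hf_qkv = hf_qkv.reshape(num_heads * 3 * head_size, hidden)   # HF layout: per head, [q, k, v]

converted = (hf_qkv.view(-1, 3, head_size, hidden)   # [num_heads, 3, head_size, hidden]
                   .transpose(0, 1)                  # [3, num_heads, head_size, hidden]
                   .reshape(-1, hidden))             # [3 * num_heads * head_size, hidden]

# The first num_heads * head_size rows of the converted matrix are all query rows.
assert torch.equal(converted[:num_heads * head_size].view(num_heads, head_size, hidden),
                   hf_qkv.view(num_heads, 3, head_size, hidden)[:, 0])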
cacheflow/models/llama.py

@@ -289,4 +289,4 @@ class LlamaForCausalLM(nn.Module):
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
-            param.data.uniform_(-0.1, 0.1)
+            param.data.uniform_(-1e-3, 1e-3)
cacheflow/models/memory_analyzer.py

@@ -40,6 +40,37 @@ class CacheFlowMemoryAnalyzer:
         max_num_blocks = swap_space // self.get_cache_block_size()
         return max_num_blocks
 
+    def get_param_size(self) -> int:
+        raise NotImplementedError()
+
+    def get_max_act_size(self, max_num_batched_tokens: int) -> int:
+        raise NotImplementedError()
+
+    def get_cache_block_size(self) -> int:
+        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
+        value_cache_block = key_cache_block
+        total = self.num_layers * (key_cache_block + value_cache_block)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def get_max_num_gpu_blocks(
+        self,
+        max_num_batched_tokens: int,
+        memory_utilization: float = 0.95,
+    ) -> int:
+        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
+        usable_memory = int(memory_utilization * self.gpu_memory)
+
+        param_size = self.get_param_size()
+        act_size = self.get_max_act_size(max_num_batched_tokens)
+        workspace_size = self.get_workspace_size()
+
+        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
+        if max_cache_size <= 0:
+            raise RuntimeError('Not enough GPU memory.')
+        max_num_blocks = max_cache_size // self.get_cache_block_size()
+        return max_num_blocks
+
 
 class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):

@@ -69,7 +100,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         self.vocab_size = config.vocab_size
         self.max_position = config.max_position_embeddings
 
-    def _get_param_size(self) -> int:
+    def get_param_size(self) -> int:
         word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
         if self.embedding_size != self.hidden_size:
             # Project in/out.

@@ -93,7 +124,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * total
 
-    def _get_max_act_size(
+    def get_max_act_size(
         self,
         max_num_batched_tokens: int,
     ) -> int:

@@ -114,31 +145,6 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * max_act
 
-    def get_cache_block_size(self) -> int:
-        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
-        value_cache_block = key_cache_block
-        total = self.num_layers * (key_cache_block + value_cache_block)
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * total
-
-    def get_max_num_gpu_blocks(
-        self,
-        max_num_batched_tokens: int,
-        memory_utilization: float = 0.95,
-    ) -> int:
-        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
-        usable_memory = int(memory_utilization * self.gpu_memory)
-
-        param_size = self._get_param_size()
-        act_size = self._get_max_act_size(max_num_batched_tokens)
-        workspace_size = self.get_workspace_size()
-
-        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
-        if max_cache_size <= 0:
-            raise RuntimeError('Not enough GPU memory.')
-        max_num_blocks = max_cache_size // self.get_cache_block_size()
-        return max_num_blocks
-
 
 class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):

@@ -167,9 +173,10 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         self.vocab_size = config.vocab_size
         self.max_position = 8192
 
-    def _get_param_size(self) -> int:
+    def get_param_size(self) -> int:
+        # NOTE: LLaMA does not tie the two embeddings.
         word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
-        position_embedding = self.max_position * self.hidden_size
+        lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
         # NOTE: LLaMA does not have bias terms.
         ln1 = self.hidden_size

@@ -188,11 +195,11 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
         ffn = ln2 + gate + down + up
 
-        total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
+        total = word_embedding + self.num_layers * (mha + ffn) + lm_head
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * total
 
-    def _get_max_act_size(
+    def get_max_act_size(
         self,
         max_num_batched_tokens: int,
     ) -> int:

@@ -213,28 +220,78 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * max_act
 
-    def get_cache_block_size(self) -> int:
-        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
-        value_cache_block = key_cache_block
-        total = self.num_layers * (key_cache_block + value_cache_block)
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * total
-
-    def get_max_num_gpu_blocks(
-        self,
-        max_num_batched_tokens: int,
-        memory_utilization: float = 0.95,
-    ) -> int:
-        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
-        gpu_memory = self.gpu_memory
-        usable_memory = int(memory_utilization * gpu_memory)
-
-        param_size = self._get_param_size()
-        act_size = self._get_max_act_size(max_num_batched_tokens)
-        workspace_size = self.get_workspace_size()
-
-        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
-        if max_cache_size <= 0:
-            raise RuntimeError('Not enough GPU memory.')
-        max_num_blocks = max_cache_size // self.get_cache_block_size()
-        return max_num_blocks
+
+class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
+
+    def __init__(
+        self,
+        model_name: str,
+        block_size: int,
+        dtype: torch.dtype,
+        gpu_memory: int,
+        cpu_memory: int,
+        tensor_parallel_size: int,
+    ) -> None:
+        self.model_name = model_name
+        self.block_size = block_size
+        self.dtype = dtype
+        self.gpu_memory = gpu_memory
+        self.cpu_memory = cpu_memory
+        self.tensor_parallel_size = tensor_parallel_size
+
+        config = AutoConfig.from_pretrained(model_name)
+        self.num_layers = config.num_hidden_layers
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_size = config.hidden_size // self.num_heads
+        self.ffn_size = config.intermediate_size
+        self.vocab_size = config.vocab_size
+        self.max_position = 8192
+        self.tie_word_embeddings = config.tie_word_embeddings
+
+    def get_param_size(self) -> int:
+        word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
+        if self.tie_word_embeddings:
+            lm_head = 0
+        else:
+            lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
+        ln1 = 2 * self.hidden_size
+        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        # Rotary embedding.
+        # TODO(woosuk): Share the rotary embedding between layers.
+        rot = self.max_position * self.head_size
+        mha = ln1 + q + k + v + out + rot
+
+        ln2 = 2 * self.hidden_size
+        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
+        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        ffn = ln2 + ffn1 + ffn2
+
+        total = word_embedding + self.num_layers * (mha + ffn) + lm_head
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def get_max_act_size(
+        self,
+        max_num_batched_tokens: int,
+    ) -> int:
+        # NOTE: We approximately calculate the maximum activation size by
+        # estimating
+        # 1) the maximum activation tensor size during inference
+        # 2) the residual tensor size during inference
+        # Here, we assume that FlashAttention is used and
+        # thus the attention maps are never materialized in GPU DRAM.
+        residual = max_num_batched_tokens * self.hidden_size
+        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
+        ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
+        # Double the activation size for input and output.
+        max_act = 2 * (max(qkv, ffn) + residual)
+        # Size of output logits.
+        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
+        max_act = max(max_act, output_logits)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * max_act
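Reviewer note: as a sanity check of the get_cache_block_size formula that moved into the base class, here is a worked example with assumed Pythia-1.4B-like numbers (hidden_size 2048, 24 layers, fp16, no tensor parallelism, block_size 8); the figures are illustrative, not measurements from this commit.

block_size, hidden_size, num_layers = 8, 2048, 24
tensor_parallel_size, dtype_size = 1, 2                              # fp16: 2 bytes per element

key_cache_block = block_size * hidden_size // tensor_parallel_size   # 16384 elements
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)           # 786432 elements
print(dtype_size * total)                                            # 1572864 bytes, i.e. 1.5 MiB per block

get_max_num_gpu_blocks then divides the memory left over after parameters, activations, and workspace by this per-block cost.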
cacheflow/models/model_utils.py

 from typing import Union
 
 import numpy as np
 import torch
 import torch.nn as nn
 from transformers import AutoConfig
 
 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
+from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
 from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
 from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
+from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
 from cacheflow.models.llama import LlamaForCausalLM
 from cacheflow.models.opt import OPTForCausalLM
 from cacheflow.models.utils import get_torch_dtype

@@ -16,11 +17,15 @@ from cacheflow.models.utils import get_torch_dtype
 
 _MODELS = {
     'llama': LlamaForCausalLM,
     'opt': OPTForCausalLM,
+    'stablelm': GPTNeoXForCausalLM,
+    'pythia': GPTNeoXForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
     'llama': LlamaMemoryAnalyzer,
     'opt': OPTMemoryAnalyzer,
+    'stablelm': GPTNeoXMemoryAnalyzer,
+    'pythia': GPTNeoXMemoryAnalyzer,
 }
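Reviewer note: the new 'stablelm' and 'pythia' entries take effect because the registry keys are looked up from the HuggingFace model name elsewhere in this module. The snippet below is a hypothetical stand-alone sketch of substring dispatch of that kind, not the module's actual helper; resolve_model_class and the string values are placeholders.

_MODELS = {'llama': 'LlamaForCausalLM', 'opt': 'OPTForCausalLM',
           'stablelm': 'GPTNeoXForCausalLM', 'pythia': 'GPTNeoXForCausalLM'}

def resolve_model_class(model_name: str) -> str:
    for key, model_class in _MODELS.items():
        if key in model_name.lower():
            return model_class
    raise ValueError(f'Unsupported model name: {model_name}')

print(resolve_model_class('EleutherAI/pythia-1.4b'))   # GPTNeoXForCausalLM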
cacheflow/models/opt.py

@@ -327,4 +327,4 @@ class OPTForCausalLM(nn.Module):
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
-            param.data.uniform_(-0.1, 0.1)
+            param.data.uniform_(-1e-3, 1e-3)
csrc/pos_encoding.cpp

@@ -4,6 +4,7 @@ void rotary_embedding_neox(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
+  int head_size,
   torch::Tensor& cos_sin_cache);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
csrc/pos_encoding_kernels.cu

@@ -8,16 +8,17 @@ __global__ void rotary_embedding_neox_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, head_size // 2]
+  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
   const int stride,
   const int num_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
-  const int embed_dim = head_size / 2;
+  const int embed_dim = rot_dim / 2;
   const int n = num_heads * embed_dim;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
     const int head_idx = i / embed_dim;

@@ -51,16 +52,17 @@ void rotary_embedding_neox(
   torch::Tensor& positions,         // [num_tokens]
   torch::Tensor& query,             // [num_tokens, num_heads * head_size]
   torch::Tensor& key,               // [num_tokens, num_heads * head_size]
-  torch::Tensor& cos_sin_cache)     // [max_position, head_size]
+  int head_size,
+  torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
 {
   int num_tokens = query.size(0);
-  int head_size = cos_sin_cache.size(1);
+  int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
   int stride = query.stride(0);
   TORCH_CHECK(stride == key.stride(0));
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size / 2, 512));
+  dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     query.scalar_type(),

@@ -71,6 +73,7 @@ void rotary_embedding_neox(
       query.data_ptr<scalar_t>(),
       key.data_ptr<scalar_t>(),
       cos_sin_cache.data_ptr<scalar_t>(),
+      rot_dim,
       stride,
       num_heads,
       head_size);
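Reviewer note: the kernel change is easiest to read as follows: each thread block still handles one token, but the threads now iterate over num_heads * (rot_dim / 2) rotation pairs instead of num_heads * (head_size / 2), leaving the last head_size - rot_dim dimensions of every head untouched. A plain-Python sketch of that per-token loop (illustrative only, not the CUDA code; cos_row and sin_row stand in for the two halves of the cached entry at this position):

def rotate_one_token(query_row, key_row, cos_row, sin_row, num_heads, head_size, rot_dim):
    embed_dim = rot_dim // 2
    for i in range(num_heads * embed_dim):        # the work the CUDA threads split up
        head_idx, pair_idx = divmod(i, embed_dim)
        x1 = head_idx * head_size + pair_idx      # first element of the rotation pair
        x2 = x1 + embed_dim                       # second element, half a rot_dim away
        c, s = cos_row[pair_idx], sin_row[pair_idx]
        for row in (query_row, key_row):
            a, b = row[x1], row[x2]
            row[x1], row[x2] = a * c - b * s, b * c + a * s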
tests/kernels/pos_encoding.py

@@ -34,6 +34,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
         base: int = 10000,
     ) -> None:
         super().__init__()
+        self.rotary_dim = dim
         self.max_position_embeddings = max_position_embeddings
 
         # Create cos and sin embeddings.

@@ -52,13 +53,24 @@ class RefRotaryEmbeddingNeox(nn.Module):
         query: torch.Tensor,        # [num_tokens, num_heads, head_size]
         key: torch.Tensor,          # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+
+        query_rot = query_rot.transpose(0, 1)
+        key_rot = key_rot.transpose(0, 1)
         cos = F.embedding(positions, self.cos_cached)
         sin = F.embedding(positions, self.sin_cached)
-        query = query.transpose(0, 1)
-        key = key.transpose(0, 1)
-        query, key = apply_rotary_pos_emb(query, key, cos, sin)
-        query = query.transpose(0, 1).contiguous()
-        key = key.transpose(0, 1).contiguous()
+
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+        query_rot = query_rot.transpose(0, 1).contiguous()
+        key_rot = key_rot.transpose(0, 1).contiguous()
+
+        query = torch.cat((query_rot, query_pass), dim=-1)
+        key = torch.cat((key_rot, key_pass), dim=-1)
+
         # Output query/key shape: [num_tokens, num_heads, head_size]
         return query, key

@@ -69,6 +81,7 @@ def test_rotary_embedding_neox(
     num_heads: int,
     head_size: int,
     max_position: int,
+    rotary_dim: int,
     dtype: torch.dtype,
     base: int = 10000,
 ) -> None:

@@ -77,7 +90,7 @@ def test_rotary_embedding_neox(
     key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
 
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
+    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
     t = torch.arange(max_position).float()
     freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
     cos = freqs.cos()

@@ -92,12 +105,13 @@ def test_rotary_embedding_neox(
         positions,
         out_query,
         out_key,
+        head_size,
         cos_sin_cache,
     )
 
     # Run the reference implementation.
     ref_rotary_embedding = RefRotaryEmbeddingNeox(
-        dim=head_size,
+        dim=rotary_dim,
         max_position_embeddings=max_position,
         base=base,
     ).to(dtype=dtype, device='cuda')

@@ -123,5 +137,6 @@ if __name__ == '__main__':
             num_heads=5,
             head_size=head_size,
             max_position=8192,
+            rotary_dim=int(head_size * 0.25),
             dtype=dtype,
         )