Project: xdb4_94051 / vllm

Commit a96d63c2 (unverified)
Add support for GPT-NeoX (Pythia) (#50)

Authored Apr 28, 2023 by Woosuk Kwon; committed by GitHub on Apr 28, 2023.
Parent: aa50b17c

Showing 9 changed files with 436 additions and 71 deletions (+436 / -71).
cacheflow/models/attention.py        +11   -5
cacheflow/models/gpt_neox.py         +278  -0
cacheflow/models/llama.py            +1    -1
cacheflow/models/memory_analyzer.py  +107  -50
cacheflow/models/model_utils.py      +6    -1
cacheflow/models/opt.py              +1    -1
csrc/pos_encoding.cpp                +1    -0
csrc/pos_encoding_kernels.cu         +9    -6
tests/kernels/pos_encoding.py        +22   -7
cacheflow/models/attention.py

@@ -150,20 +150,20 @@ class OPTCacheFlowAttention(GPTCacheFlowAttention):
         super().__init__(scale)
 
 
-class LlamaCacheFlowAttention(GPTCacheFlowAttention):
-    """Llama uses GPT-NeoX style rotary embedding."""
+class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
+    """Attention with GPT-NeoX style rotary embedding."""
 
     def __init__(
         self,
         scale: float,
-        head_size: int,
+        rotary_dim: int,
         max_position: int = 8192,
         base: int = 10000,
     ) -> None:
         super().__init__(scale)
 
         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
+        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
         t = torch.arange(max_position).float()
         freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
         cos = freqs.cos()
@@ -174,7 +174,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
         # initializing the model. Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
-        # Embedding size: [max_position, head_size]
+        # Embedding size: [max_position, rotary_dim]
         self.register_buffer('cos_sin_cache', cache, persistent=False)
 
     def forward(
@@ -190,10 +190,12 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
     ) -> torch.Tensor:                      # [num_tokens, num_heads * head_size]
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
+        head_size = value_cache.shape[2]
         pos_encoding_ops.rotary_embedding_neox(
             positions,
             query,
             key,
+            head_size,
             self.cos_sin_cache,
         )
         return super().forward(
@@ -205,3 +207,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
             input_metadata,
             cache_event,
         )
+
+
+class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
+    """LLaMA uses the GPT-NeoX style rotary embedding."""
cacheflow/models/gpt_neox.py (new file, mode 100644; all 278 lines added)

"""1D GPT-NeoX model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from huggingface_hub import snapshot_download

from cacheflow.models import InputMetadata
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
                                                      ColumnParallelLinear,
                                                      RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTNeoXAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size

        self.query_key_value = ColumnParallelLinear(config.hidden_size,
                                                    3 * config.hidden_size,
                                                    gather_output=False,
                                                    perform_initialization=False)
        self.dense = RowParallelLinear(config.hidden_size,
                                       config.hidden_size,
                                       input_is_parallel=True,
                                       perform_initialization=False)

        scaling = self.head_size ** -0.5
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim)

    def forward(
        self,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(
            position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
                                                  config.intermediate_size,
                                                  gather_output=False,
                                                  perform_initialization=False)
        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                               config.hidden_size,
                                               input_is_parallel=True,
                                               perform_initialization=False)
        if config.hidden_act != 'gelu':
            raise ValueError(f'Unsupported activation: {config.hidden_act}. '
                             'Only gelu is supported for now.')
        self.act = torch.nn.GELU()

    def forward(self, hidden_states):
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config)
        self.mlp = GPTNeoXMLP(config)

    def forward(
        self,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(config.vocab_size,
                                               config.hidden_size,
                                               perform_initialization=False)
        self.layers = nn.ModuleList(
            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gpt_neox = GPTNeoXModel(config)
        self.embed_out = ColumnParallelLinear(config.hidden_size,
                                              config.vocab_size,
                                              bias=False,
                                              gather_output=False,
                                              perform_initialization=False)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.gpt_neox(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.embed_out.weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_in.weight", "embed_out.weight",
                                "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self, weights_path: str):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, param in state_dict.items():
            if "query_key_value" in name:
                # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
                # [num_heads * 3 * head_size, num_heads * head_size], while the
                # required shape is [3 * num_heads * head_size, num_heads * head_size].
                # Thus, we need weight conversion.
                loaded_weight = torch.from_numpy(
                    np.load(os.path.join(weights_path, name)))
                shard_size = param.shape[0]
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank
                    :shard_size * (tensor_model_parallel_rank + 1)]

                num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // num_heads
                if 'query_key_value.weight' in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, hidden_size).contiguous()
                elif 'query_key_value.bias' in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1).contiguous()
                else:
                    assert False
            else:
                loaded_weight = torch.from_numpy(
                    np.load(os.path.join(weights_path, name)))
                for p in self._column_parallel_weights:
                    if p in name:
                        shard_size = param.shape[0]
                        loaded_weight = loaded_weight[
                            shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break
                for p in self._row_parallel_weights:
                    if p in name:
                        shard_size = param.shape[1]
                        loaded_weight = loaded_weight[
                            :, shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break

            assert param.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)

    @staticmethod
    def get_weights(model_name: str, path: str):
        path = os.path.join(path, f"{model_name}-np")
        path = os.path.abspath(os.path.expanduser(path))
        os.makedirs(path, exist_ok=True)
        lock_path = os.path.join(path, "file_lock")
        lock = filelock.FileLock(lock_path)

        with lock:
            test_weight_path = os.path.join(path, "gpt_neox.embed_in.weight")
            if os.path.exists(test_weight_path):
                return path

            folder = snapshot_download(model_name, allow_patterns="*.bin",
                                       cache_dir=os.path.join(path, "cache"))
            bin_files = glob.glob(os.path.join(folder, "*.bin"))

            for bin_file in tqdm(bin_files, desc="Convert format"):
                state = torch.load(bin_file, map_location="cpu")
                for name, param in tqdm(state.items(), leave=False):
                    param_path = os.path.join(path, name)
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())

            return path

    def initialize_dummy_weights(self) -> None:
        for param in self.state_dict().values():
            param.data.uniform_(-1e-3, 1e-3)
cacheflow/models/llama.py

@@ -289,4 +289,4 @@ class LlamaForCausalLM(nn.Module):
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
-            param.data.uniform_(-0.1, 0.1)
+            param.data.uniform_(-1e-3, 1e-3)
cacheflow/models/memory_analyzer.py

@@ -40,6 +40,37 @@ class CacheFlowMemoryAnalyzer:
         max_num_blocks = swap_space // self.get_cache_block_size()
         return max_num_blocks
 
+    def get_param_size(self) -> int:
+        raise NotImplementedError()
+
+    def get_max_act_size(self, max_num_batched_tokens: int) -> int:
+        raise NotImplementedError()
+
+    def get_cache_block_size(self) -> int:
+        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
+        value_cache_block = key_cache_block
+        total = self.num_layers * (key_cache_block + value_cache_block)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def get_max_num_gpu_blocks(
+        self,
+        max_num_batched_tokens: int,
+        memory_utilization: float = 0.95,
+    ) -> int:
+        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
+        usable_memory = int(memory_utilization * self.gpu_memory)
+
+        param_size = self.get_param_size()
+        act_size = self.get_max_act_size(max_num_batched_tokens)
+        workspace_size = self.get_workspace_size()
+
+        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
+        if max_cache_size <= 0:
+            raise RuntimeError('Not enough GPU memory.')
+        max_num_blocks = max_cache_size // self.get_cache_block_size()
+        return max_num_blocks
+
 
 class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
@@ -69,7 +100,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         self.vocab_size = config.vocab_size
         self.max_position = config.max_position_embeddings
 
-    def _get_param_size(self) -> int:
+    def get_param_size(self) -> int:
         word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
         if self.embedding_size != self.hidden_size:
             # Project in/out.
@@ -93,7 +124,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * total
 
-    def _get_max_act_size(
+    def get_max_act_size(
         self,
         max_num_batched_tokens: int,
     ) -> int:
@@ -114,31 +145,6 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * max_act
 
-    def get_cache_block_size(self) -> int:
-        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
-        value_cache_block = key_cache_block
-        total = self.num_layers * (key_cache_block + value_cache_block)
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * total
-
-    def get_max_num_gpu_blocks(
-        self,
-        max_num_batched_tokens: int,
-        memory_utilization: float = 0.95,
-    ) -> int:
-        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
-        usable_memory = int(memory_utilization * self.gpu_memory)
-
-        param_size = self._get_param_size()
-        act_size = self._get_max_act_size(max_num_batched_tokens)
-        workspace_size = self.get_workspace_size()
-
-        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
-        if max_cache_size <= 0:
-            raise RuntimeError('Not enough GPU memory.')
-        max_num_blocks = max_cache_size // self.get_cache_block_size()
-        return max_num_blocks
-
 
 class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
@@ -167,9 +173,10 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         self.vocab_size = config.vocab_size
         self.max_position = 8192
 
-    def _get_param_size(self) -> int:
+    def get_param_size(self) -> int:
         # NOTE: LLaMA does not tie the two embeddings.
         word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
-        position_embedding = self.max_position * self.hidden_size
+        lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
         # NOTE: LLaMA does not have bias terms.
         ln1 = self.hidden_size
@@ -188,11 +195,11 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
         ffn = ln2 + gate + down + up
-        total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
+        total = word_embedding + self.num_layers * (mha + ffn) + lm_head
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * total
 
-    def _get_max_act_size(
+    def get_max_act_size(
         self,
         max_num_batched_tokens: int,
     ) -> int:
@@ -213,28 +220,78 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         dtype_size = get_dtype_size(self.dtype)
         return dtype_size * max_act
 
-    def get_cache_block_size(self) -> int:
-        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
-        value_cache_block = key_cache_block
-        total = self.num_layers * (key_cache_block + value_cache_block)
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * total
-
-    def get_max_num_gpu_blocks(
-        self,
-        max_num_batched_tokens: int,
-        memory_utilization: float = 0.95,
-    ) -> int:
-        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
-        gpu_memory = self.gpu_memory
-        usable_memory = int(memory_utilization * gpu_memory)
-
-        param_size = self._get_param_size()
-        act_size = self._get_max_act_size(max_num_batched_tokens)
-        workspace_size = self.get_workspace_size()
-
-        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
-        if max_cache_size <= 0:
-            raise RuntimeError('Not enough GPU memory.')
-        max_num_blocks = max_cache_size // self.get_cache_block_size()
-        return max_num_blocks
+
+class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
+
+    def __init__(
+        self,
+        model_name: str,
+        block_size: int,
+        dtype: torch.dtype,
+        gpu_memory: int,
+        cpu_memory: int,
+        tensor_parallel_size: int,
+    ) -> None:
+        self.model_name = model_name
+        self.block_size = block_size
+        self.dtype = dtype
+        self.gpu_memory = gpu_memory
+        self.cpu_memory = cpu_memory
+        self.tensor_parallel_size = tensor_parallel_size
+
+        config = AutoConfig.from_pretrained(model_name)
+        self.num_layers = config.num_hidden_layers
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_size = config.hidden_size // self.num_heads
+        self.ffn_size = config.intermediate_size
+        self.vocab_size = config.vocab_size
+        self.max_position = 8192
+        self.tie_word_embeddings = config.tie_word_embeddings
+
+    def get_param_size(self) -> int:
+        word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
+        if self.tie_word_embeddings:
+            lm_head = 0
+        else:
+            lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
+        ln1 = 2 * self.hidden_size
+        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        # Rotary embedding.
+        # TODO(woosuk): Share the rotary embedding between layers.
+        rot = self.max_position * self.head_size
+        mha = ln1 + q + k + v + out + rot
+
+        ln2 = 2 * self.hidden_size
+        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
+        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        ffn = ln2 + ffn1 + ffn2
+
+        total = word_embedding + self.num_layers * (mha + ffn) + lm_head
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def get_max_act_size(
+        self,
+        max_num_batched_tokens: int,
+    ) -> int:
+        # NOTE: We approximately calculate the maximum activation size by
+        # estimating
+        # 1) the maximum activation tensor size during inference
+        # 2) the residual tensor size during inference
+        # Here, we assume that FlashAttention is used and
+        # thus the attention maps are never materialized in GPU DRAM.
+        residual = max_num_batched_tokens * self.hidden_size
+        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
+        ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
+        # Double the activation size for input and output.
+        max_act = 2 * (max(qkv, ffn) + residual)
+        # Size of output logits.
+        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
+        max_act = max(max_act, output_logits)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * max_act
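As a quick sanity check on the shared get_cache_block_size formula above (one block stores block_size tokens of keys plus values across all layers), the following standalone sketch plugs in illustrative numbers. The configuration values are assumptions chosen to roughly match a Pythia-6.9B-like model (hidden_size 4096, 32 layers), fp16, block_size 16, and no tensor parallelism; they are not taken from this commit.

def cache_block_size(block_size: int, hidden_size: int, num_layers: int,
                     tensor_parallel_size: int, dtype_size: int) -> int:
    # Mirrors CacheFlowMemoryAnalyzer.get_cache_block_size above.
    key_cache_block = block_size * hidden_size // tensor_parallel_size
    value_cache_block = key_cache_block
    return dtype_size * num_layers * (key_cache_block + value_cache_block)

bytes_per_block = cache_block_size(block_size=16, hidden_size=4096,
                                   num_layers=32, tensor_parallel_size=1,
                                   dtype_size=2)
print(bytes_per_block)   # 8388608 bytes, i.e. 8 MiB of KV cache per 16-token block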
cacheflow/models/model_utils.py

 from typing import Union
 
 import numpy as np
 import torch
 import torch.nn as nn
 from transformers import AutoConfig
 
 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
+from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
 from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
 from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
+from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
 from cacheflow.models.llama import LlamaForCausalLM
 from cacheflow.models.opt import OPTForCausalLM
 from cacheflow.models.utils import get_torch_dtype
@@ -16,11 +17,15 @@ from cacheflow.models.utils import get_torch_dtype
 _MODELS = {
     'llama': LlamaForCausalLM,
     'opt': OPTForCausalLM,
+    'stablelm': GPTNeoXForCausalLM,
+    'pythia': GPTNeoXForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
     'llama': LlamaMemoryAnalyzer,
     'opt': OPTMemoryAnalyzer,
+    'stablelm': GPTNeoXMemoryAnalyzer,
+    'pythia': GPTNeoXMemoryAnalyzer,
 }
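The lookup code that consumes _MODELS is not shown in this diff, so the sketch below is an assumption about how the dispatch presumably works (matching a registry key against the HuggingFace model name), written with placeholder string values so it runs standalone. It is meant only to show why adding the 'stablelm' and 'pythia' keys is enough to route those checkpoints to the new GPT-NeoX implementation.

# Illustrative only; not code from this repository.
_MODELS = {
    'llama': 'LlamaForCausalLM',
    'opt': 'OPTForCausalLM',
    'stablelm': 'GPTNeoXForCausalLM',
    'pythia': 'GPTNeoXForCausalLM',
}

def resolve_model_class(model_name: str) -> str:
    # Assume the registry key is matched as a substring of the model name.
    for key, model_class in _MODELS.items():
        if key in model_name.lower():
            return model_class
    raise ValueError(f'Unsupported model name: {model_name}')

assert resolve_model_class('EleutherAI/pythia-1b') == 'GPTNeoXForCausalLM'
assert resolve_model_class('facebook/opt-125m') == 'OPTForCausalLM'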
cacheflow/models/opt.py

@@ -327,4 +327,4 @@ class OPTForCausalLM(nn.Module):
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
-            param.data.uniform_(-0.1, 0.1)
+            param.data.uniform_(-1e-3, 1e-3)
csrc/pos_encoding.cpp

@@ -4,6 +4,7 @@ void rotary_embedding_neox(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
+  int head_size,
   torch::Tensor& cos_sin_cache);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
csrc/pos_encoding_kernels.cu

@@ -8,16 +8,17 @@ __global__ void rotary_embedding_neox_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, head_size // 2]
+  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
   const int stride,
   const int num_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
-  const int embed_dim = head_size / 2;
+  const int embed_dim = rot_dim / 2;
   const int n = num_heads * embed_dim;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
     const int head_idx = i / embed_dim;
@@ -51,16 +52,17 @@ void rotary_embedding_neox(
   torch::Tensor& positions,         // [num_tokens]
   torch::Tensor& query,             // [num_tokens, num_heads * head_size]
   torch::Tensor& key,               // [num_tokens, num_heads * head_size]
-  torch::Tensor& cos_sin_cache)     // [max_position, head_size]
+  int head_size,
+  torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
 {
   int num_tokens = query.size(0);
-  int head_size = cos_sin_cache.size(1);
+  int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
   int stride = query.stride(0);
   TORCH_CHECK(stride == key.stride(0));
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size / 2, 512));
+  dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     query.scalar_type(),
@@ -71,6 +73,7 @@ void rotary_embedding_neox(
       query.data_ptr<scalar_t>(),
       key.data_ptr<scalar_t>(),
       cos_sin_cache.data_ptr<scalar_t>(),
+      rot_dim,
       stride,
       num_heads,
       head_size);
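The work partitioning implied by the updated launch configuration can be read off the hunks above: one thread block per token (grid = num_tokens) and up to min(num_heads * rot_dim / 2, 512) threads per block, each walking the (head, rotary-pair) indices in a stride loop. The Python sketch below (not part of this commit) just enumerates that mapping; rot_offset is an assumption about the elided loop body, not reproduced from the kernel.

def iterate_rotary_pairs(num_tokens: int, num_heads: int, rot_dim: int,
                         max_block_dim: int = 512):
    embed_dim = rot_dim // 2
    n = num_heads * embed_dim
    block_dim = min(n, max_block_dim)            # dim3 block(std::min(num_heads * rot_dim / 2, 512))
    for token_idx in range(num_tokens):          # blockIdx.x: one thread block per token
        for thread_idx in range(block_dim):      # threadIdx.x
            for i in range(thread_idx, n, block_dim):   # stride loop over the rotary pairs
                head_idx = i // embed_dim
                rot_offset = i % embed_dim       # assumed pairing within the rotated dims
                yield token_idx, head_idx, rot_offset

# Example: 2 tokens, 4 heads, rot_dim=8 -> 4 rotary pairs per head, 16 updates per token.
pairs = list(iterate_rotary_pairs(num_tokens=2, num_heads=4, rot_dim=8))
assert len(pairs) == 2 * 4 * 4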
tests/kernels/pos_encoding.py

@@ -34,6 +34,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
         base: int = 10000,
     ) -> None:
         super().__init__()
+        self.rotary_dim = dim
         self.max_position_embeddings = max_position_embeddings
 
         # Create cos and sin embeddings.
@@ -52,13 +53,24 @@ class RefRotaryEmbeddingNeox(nn.Module):
         query: torch.Tensor,    # [num_tokens, num_heads, head_size]
         key: torch.Tensor,      # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+
+        query_rot = query_rot.transpose(0, 1)
+        key_rot = key_rot.transpose(0, 1)
         cos = F.embedding(positions, self.cos_cached)
         sin = F.embedding(positions, self.sin_cached)
-        query = query.transpose(0, 1)
-        key = key.transpose(0, 1)
-        query, key = apply_rotary_pos_emb(query, key, cos, sin)
-        query = query.transpose(0, 1).contiguous()
-        key = key.transpose(0, 1).contiguous()
+
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+        query_rot = query_rot.transpose(0, 1).contiguous()
+        key_rot = key_rot.transpose(0, 1).contiguous()
+
+        query = torch.cat((query_rot, query_pass), dim=-1)
+        key = torch.cat((key_rot, key_pass), dim=-1)
+
         # Output query/key shape: [num_tokens, num_heads, head_size]
         return query, key
@@ -69,6 +81,7 @@ def test_rotary_embedding_neox(
     num_heads: int,
     head_size: int,
     max_position: int,
+    rotary_dim: int,
     dtype: torch.dtype,
     base: int = 10000,
 ) -> None:
@@ -77,7 +90,7 @@ def test_rotary_embedding_neox(
     key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
 
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
+    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
     t = torch.arange(max_position).float()
     freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
     cos = freqs.cos()
@@ -92,12 +105,13 @@ def test_rotary_embedding_neox(
         positions,
         out_query,
         out_key,
+        head_size,
         cos_sin_cache,
     )
 
     # Run the reference implementation.
     ref_rotary_embedding = RefRotaryEmbeddingNeox(
-        dim=head_size,
+        dim=rotary_dim,
         max_position_embeddings=max_position,
         base=base,
    ).to(dtype=dtype, device='cuda')
@@ -123,5 +137,6 @@ if __name__ == '__main__':
             num_heads=5,
             head_size=head_size,
             max_position=8192,
+            rotary_dim=int(head_size * 0.25),
             dtype=dtype,
         )
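The 0.25 factor in the test parametrization mirrors how GPTNeoXAttention derives the rotary width from the model config (rotary_dim = int(head_size * config.rotary_pct)); 0.25 is the partial-rotary fraction used by Pythia-style configs. A small standalone check of that relation (not part of this commit):

def rotary_dim_for(head_size: int, rotary_pct: float) -> int:
    # Same derivation as GPTNeoXAttention.__init__ in cacheflow/models/gpt_neox.py.
    rotary_dim = int(head_size * rotary_pct)
    assert rotary_dim % 2 == 0, 'rotary_dim must be even so the cos/sin pairs line up'
    return rotary_dim

assert rotary_dim_for(head_size=64, rotary_pct=0.25) == 16
assert rotary_dim_for(head_size=128, rotary_pct=0.25) == 32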