SIYIXNI / vllm
"examples/community/wildcard_stable_diffusion.py" did not exist on "9ebaea545ffddf2e9079994f2ea657a7fa5f358c"
Unverified commit e548c148, authored May 04, 2023 by Woosuk Kwon, committed by GitHub on May 04, 2023.

Add support for GPT-2 (#60)

Parent: 130d5fd8
Showing 7 changed files with 350 additions and 8 deletions (+350, -8):
cacheflow/models/gpt2.py             +265  -0
cacheflow/models/gpt_neox.py           +5  -5
cacheflow/models/llama.py              +1  -1
cacheflow/models/memory_analyzer.py   +70  -0
cacheflow/models/model_utils.py        +4  -0
cacheflow/models/opt.py                +1  -1
cacheflow/models/sample.py             +4  -1
cacheflow/models/gpt2.py (new file, 0 → 100644)
"""1D GPT-2 model compatible with HuggingFace weights."""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
GPT2Config
from
cacheflow.models
import
InputMetadata
from
cacheflow.models.attention
import
GPTCacheFlowAttention
from
cacheflow.models.sample
import
Sampler
from
cacheflow.models.utils
import
(
hf_model_weights_iterator
,
load_tensor_parallel_weights
)
from
cacheflow.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
cacheflow.parallel_utils.tensor_parallel
import
(
VocabParallelEmbedding
,
ColumnParallelLinear
,
RowParallelLinear
)
from
cacheflow.sequence
import
SequenceOutputs
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GPT2Attention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GPT2Config
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
total_num_heads
=
config
.
num_attention_heads
tensor_model_parallel_world_size
=
get_tensor_model_parallel_world_size
()
assert
total_num_heads
%
tensor_model_parallel_world_size
==
0
self
.
num_heads
=
total_num_heads
//
tensor_model_parallel_world_size
self
.
head_dim
=
self
.
hidden_size
//
total_num_heads
self
.
scale
=
self
.
head_dim
**
-
0.5
self
.
c_attn
=
ColumnParallelLinear
(
self
.
hidden_size
,
3
*
self
.
hidden_size
,
bias
=
True
,
gather_output
=
False
,
perform_initialization
=
False
)
self
.
c_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
input_is_parallel
=
True
,
perform_initialization
=
False
)
self
.
attn
=
GPTCacheFlowAttention
(
scale
=
self
.
scale
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input_metadata
:
InputMetadata
,
cache_event
:
Optional
[
torch
.
cuda
.
Event
],
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
key_cache
,
value_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
key_cache
,
value_cache
,
input_metadata
,
cache_event
)
attn_output
,
_
=
self
.
c_proj
(
attn_output
)
return
attn_output
class
GPT2MLP
(
nn
.
Module
):
def
__init__
(
self
,
intermediate_size
:
int
,
config
:
GPT2Config
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
self
.
c_fc
=
ColumnParallelLinear
(
hidden_size
,
intermediate_size
,
bias
=
True
,
gather_output
=
False
,
perform_initialization
=
False
)
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
True
,
input_is_parallel
=
True
,
perform_initialization
=
False
)
act_fn
=
config
.
activation_function
if
act_fn
!=
"gelu_new"
:
raise
ValueError
(
f
"Unsupported activation:
{
act_fn
}
. "
"GPT-2 only supports gelu_new for now."
)
self
.
act
=
torch
.
nn
.
GELU
(
approximate
=
"tanh"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
c_fc
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
,
_
=
self
.
c_proj
(
hidden_states
)
return
hidden_states
class
GPT2Block
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GPT2Config
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
inner_dim
=
config
.
n_inner
if
config
.
n_inner
is
not
None
else
4
*
hidden_size
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPT2Attention
(
config
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
GPT2MLP
(
inner_dim
,
config
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input_metadata
:
InputMetadata
,
cache_event
:
Optional
[
torch
.
cuda
.
Event
],
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
attn_output
=
self
.
attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input_metadata
=
input_metadata
,
cache_event
=
cache_event
,
)
# residual connection
hidden_states
=
attn_output
+
residual
residual
=
hidden_states
hidden_states
=
self
.
ln_2
(
hidden_states
)
feed_forward_hidden_states
=
self
.
mlp
(
hidden_states
)
# residual connection
hidden_states
=
residual
+
feed_forward_hidden_states
return
hidden_states
class
GPT2Model
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GPT2Config
):
super
().
__init__
()
self
.
config
=
config
assert
config
.
add_cross_attention
==
False
assert
config
.
scale_attn_by_inverse_layer_idx
==
False
assert
config
.
reorder_and_upcast_attn
==
False
self
.
embed_dim
=
config
.
hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size
=
((
config
.
vocab_size
+
63
)
//
64
)
*
64
self
.
wte
=
VocabParallelEmbedding
(
vocab_size
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
h
=
nn
.
ModuleList
(
[
GPT2Block
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
,
position_ids
:
torch
.
LongTensor
,
kv_caches
:
List
[
KVCache
],
input_metadata
:
InputMetadata
,
cache_events
:
Optional
[
List
[
torch
.
cuda
.
Event
]],
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
hidden_states
=
inputs_embeds
+
position_embeds
for
i
in
range
(
len
(
self
.
h
)):
if
cache_events
is
None
:
cache_event
=
None
else
:
cache_event
=
cache_events
[
i
]
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
input_metadata
,
cache_event
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
class
GPT2LMHeadModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GPT2Config
):
super
().
__init__
()
self
.
config
=
config
self
.
transformer
=
GPT2Model
(
config
)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
torch
.
LongTensor
,
kv_caches
:
List
[
KVCache
],
input_metadata
:
InputMetadata
,
cache_events
:
Optional
[
List
[
torch
.
cuda
.
Event
]],
)
->
Dict
[
int
,
SequenceOutputs
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input_metadata
,
cache_events
)
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
input_metadata
)
return
next_tokens
_column_parallel_weights
=
[
"wte.weight"
,
"c_fc.weight"
,
"c_fc.bias"
]
_row_parallel_weights
=
[
"c_proj.weight"
]
def
load_weights
(
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
use_np_cache
:
bool
=
False
):
tensor_model_parallel_world_size
=
get_tensor_model_parallel_world_size
()
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
state_dict
=
self
.
state_dict
()
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
use_np_cache
):
if
"lm_head.weight"
in
name
:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if
".attn.bias"
in
name
:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
name
=
"transformer."
+
name
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
for
conv1d_weight_name
in
[
"c_attn"
,
"c_proj"
,
"c_fc"
]:
if
conv1d_weight_name
not
in
name
:
continue
if
not
name
.
endswith
(
".weight"
):
continue
loaded_weight
=
loaded_weight
.
t
()
param
=
state_dict
[
name
]
if
name
==
"transformer.wte.weight"
:
# Consider padding in the vocab size.
padded_vocab_size
=
param
.
shape
[
0
]
*
tensor_model_parallel_world_size
num_extra_rows
=
padded_vocab_size
-
self
.
config
.
vocab_size
extra_rows
=
torch
.
empty
(
num_extra_rows
,
loaded_weight
.
shape
[
1
])
extra_rows
=
extra_rows
.
to
(
loaded_weight
)
loaded_weight
=
torch
.
cat
([
loaded_weight
,
extra_rows
],
dim
=
0
)
# For the fused QKV linear layer, manually shard the weights.
if
"c_attn"
in
name
:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along the head dimension.
total_num_heads
=
self
.
config
.
num_attention_heads
hidden_size
=
self
.
config
.
hidden_size
head_size
=
hidden_size
//
total_num_heads
num_heads
=
total_num_heads
//
tensor_model_parallel_world_size
head_start
=
tensor_model_parallel_rank
*
num_heads
head_end
=
(
tensor_model_parallel_rank
+
1
)
*
num_heads
if
name
.
endswith
(
".weight"
):
loaded_weight
=
loaded_weight
.
view
(
3
,
total_num_heads
,
head_size
,
hidden_size
)
loaded_weight
=
loaded_weight
[:,
head_start
:
head_end
,
:,
:]
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
hidden_size
)
elif
name
.
endswith
(
".bias"
):
loaded_weight
=
loaded_weight
.
view
(
3
,
total_num_heads
,
head_size
)
loaded_weight
=
loaded_weight
[:,
head_start
:
head_end
,
:]
loaded_weight
=
loaded_weight
.
reshape
(
-
1
)
else
:
raise
ValueError
(
f
"Unexpected parameter name
{
name
}
"
)
load_tensor_parallel_weights
(
param
,
loaded_weight
,
name
,
self
.
_column_parallel_weights
,
self
.
_row_parallel_weights
)
def
initialize_dummy_weights
(
self
)
->
None
:
for
param
in
self
.
state_dict
().
values
():
param
.
data
.
uniform_
(
-
1e-3
,
1e-3
)
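Two parts of the new file deserve a closer look: the vocab padding in GPT2Model.__init__ and the manual sharding of the fused c_attn weight in load_weights. The following standalone sketch (plain PyTorch only; the toy sizes are assumptions for illustration, not values from the commit) reproduces both computations so the indexing can be checked in isolation.

import torch

# Vocab padding: GPT-2's vocab size 50257 is rounded up to the next multiple
# of 64 (50304), which keeps matmul shapes GPU-friendly and lets the embedding
# be sharded evenly across 2, 4, 8, ... GPUs.
vocab_size = 50257
padded_vocab_size = ((vocab_size + 63) // 64) * 64
assert padded_vocab_size == 50304

# Fused QKV sharding: after the Conv1D transpose, c_attn.weight has shape
# [3 * hidden_size, hidden_size]. Toy sizes (assumptions for illustration):
hidden_size, total_num_heads, world_size = 8, 4, 2
head_size = hidden_size // total_num_heads       # 2
num_heads = total_num_heads // world_size        # 2 heads per rank
weight = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32)
weight = weight.reshape(3 * hidden_size, hidden_size)

for rank in range(world_size):
    head_start = rank * num_heads
    head_end = (rank + 1) * num_heads
    shard = weight.view(3, total_num_heads, head_size, hidden_size)
    shard = shard[:, head_start:head_end, :, :]
    shard = shard.reshape(-1, hidden_size)
    # Each rank keeps only the q, k, and v rows belonging to its own heads.
    assert shard.shape == (3 * num_heads * head_size, hidden_size)

Each rank ends up with the query, key, and value rows for its own heads, which is exactly the layout the sharded c_attn in this file expects.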
cacheflow/models/gpt_neox.py
@@ -173,7 +173,7 @@ class GPTNeoXForCausalLM(nn.Module):
         self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
                                               bias=False, gather_output=False,
                                               perform_initialization=False)
-        self.sampler = Sampler()
+        self.sampler = Sampler(config.vocab_size)

     def forward(
         self,
@@ -205,8 +205,8 @@ class GPTNeoXForCausalLM(nn.Module):
             param = state_dict[name]
             if "query_key_value" in name:
                 # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
-                # [num_heads * 3 * head_size, num_heads * head_size], while the
-                # required shape is [3 * num_heads * head_size, num_heads * head_size].
+                # [num_heads * 3 * head_size, hidden_size], while the
+                # required shape is [3 * num_heads * head_size, hidden_size].
                 # Thus, we need weight conversion.
                 shard_size = param.shape[0]
                 loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
@@ -218,11 +218,11 @@ class GPTNeoXForCausalLM(nn.Module):
             if 'query_key_value.weight' in name:
                 loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
                 loaded_weight = loaded_weight.transpose(0, 1)
-                loaded_weight = loaded_weight.reshape(-1, hidden_size).contiguous()
+                loaded_weight = loaded_weight.reshape(-1, hidden_size)
             elif 'query_key_value.bias' in name:
                 loaded_weight = loaded_weight.view(-1, 3, head_size)
                 loaded_weight = loaded_weight.transpose(0, 1)
-                loaded_weight = loaded_weight.reshape(-1).contiguous()
+                loaded_weight = loaded_weight.reshape(-1)
             else:
                 raise ValueError(f"Unexpected weight name: {name}")
             load_tensor_parallel_weights(param, loaded_weight, name,
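The functional changes here are the Sampler(config.vocab_size) argument and the removal of .contiguous(); the comment edit is cosmetic, since num_heads * head_size equals hidden_size for GPT-NeoX. Dropping .contiguous() is safe because reshape already copies when its input is non-contiguous. Below is a small sketch of the QKV layout conversion the comment describes, with toy sizes chosen for illustration (not taken from the commit).

import torch

# Toy GPT-NeoX-style fused QKV weight: rows are grouped per head as
# [head0(q, k, v), head1(q, k, v), ...] and must be regrouped into
# [all q, all k, all v].
num_heads, head_size = 4, 2
hidden_size = num_heads * head_size
qkv = torch.arange(num_heads * 3 * head_size * hidden_size, dtype=torch.float32)
qkv = qkv.reshape(num_heads * 3 * head_size, hidden_size)

converted = qkv.view(-1, 3, head_size, hidden_size)   # [num_heads, 3, head, hidden]
converted = converted.transpose(0, 1)                 # [3, num_heads, head, hidden]
converted = converted.reshape(-1, hidden_size)        # copies; .contiguous() is redundant
assert converted.shape == (3 * num_heads * head_size, hidden_size)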
cacheflow/models/llama.py
@@ -192,7 +192,7 @@ class LlamaForCausalLM(nn.Module):
                                              bias=False,
                                              gather_output=False,
                                              perform_initialization=False)
-        self.sampler = Sampler()
+        self.sampler = Sampler(config.vocab_size)

     def forward(
         self,
cacheflow/models/memory_analyzer.py
@@ -72,6 +72,76 @@ class CacheFlowMemoryAnalyzer:
         return max_num_blocks


+class GPT2MemoryAnalyzer(CacheFlowMemoryAnalyzer):
+
+    def __init__(
+        self,
+        model_name: str,
+        block_size: int,
+        dtype: torch.dtype,
+        gpu_memory: int,
+        cpu_memory: int,
+        tensor_parallel_size: int,
+    ) -> None:
+        self.model_name = model_name
+        self.block_size = block_size
+        self.dtype = dtype
+        self.gpu_memory = gpu_memory
+        self.cpu_memory = cpu_memory
+        self.tensor_parallel_size = tensor_parallel_size
+
+        config = AutoConfig.from_pretrained(model_name)
+        self.num_layers = config.num_hidden_layers
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_size = config.hidden_size // self.num_heads
+        self.ffn_size = (config.n_inner if config.n_inner is not None
+                         else 4 * self.hidden_size)
+        self.vocab_size = config.vocab_size
+        self.max_position = config.max_position_embeddings
+
+    def get_param_size(self) -> int:
+        word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
+        position_embedding = self.max_position * self.hidden_size
+
+        ln1 = 2 * self.hidden_size
+        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        mha = ln1 + q + k + v + out
+
+        ln2 = 2 * self.hidden_size
+        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
+        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        ffn = ln2 + ffn1 + ffn2
+
+        total = (word_embedding + position_embedding
+                 + self.num_layers * (mha + ffn))
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def get_max_act_size(
+        self,
+        max_num_batched_tokens: int,
+    ) -> int:
+        # NOTE: We approximately calculate the maximum activation size by
+        # estimating
+        # 1) the maximum activation tensor size during inference
+        # 2) the residual tensor size during inference
+        # Here, we assume that FlashAttention is used and
+        # thus the attention maps are never materialized in GPU DRAM.
+        residual = max_num_batched_tokens * self.hidden_size
+        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
+        ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
+        # Double the activation size for input and output.
+        max_act = 2 * (max(qkv, ffn) + residual)
+        # Size of output logits.
+        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
+        max_act = max(max_act, output_logits)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * max_act
+
+
 class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):

     def __init__(
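As a sanity check on the analyzer's arithmetic, plugging GPT-2 small's configuration (hidden_size 768, 12 layers, 4x FFN, vocab 50257, 1024 positions) into get_param_size with tensor_parallel_size 1 reproduces the familiar ~124M parameter count; the final LayerNorm is the only term the estimate omits. A self-contained version of the same computation, with the config values hard-coded rather than read from AutoConfig:

# GPT-2 small, tensor_parallel_size = 1.
hidden, layers, ffn, vocab, max_pos, tp = 768, 12, 4 * 768, 50257, 1024, 1

word_embedding = vocab * hidden // tp            # 38,597,376
position_embedding = max_pos * hidden            #    786,432
ln1 = 2 * hidden
qkv_out = 4 * (hidden * hidden // tp + hidden)   # q, k, v, out projections
mha = ln1 + qkv_out                              #  2,363,904 per layer
ln2 = 2 * hidden
ffn1 = hidden * ffn // tp + ffn
ffn2 = ffn * hidden // tp + hidden
mlp = ln2 + ffn1 + ffn2                          #  4,723,968 per layer

total = word_embedding + position_embedding + layers * (mha + mlp)
print(total)   # 124,438,272 parameters, about 237 MiB at fp16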
cacheflow/models/model_utils.py
@@ -5,9 +5,11 @@ import torch.nn as nn
 from transformers import AutoConfig

 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
+from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
 from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
 from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
 from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
+from cacheflow.models.gpt2 import GPT2LMHeadModel
 from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
 from cacheflow.models.llama import LlamaForCausalLM
 from cacheflow.models.opt import OPTForCausalLM
@@ -15,6 +17,7 @@ from cacheflow.models.utils import get_torch_dtype

 _MODELS = {
+    'gpt2': GPT2LMHeadModel,
     'llama': LlamaForCausalLM,
     'opt': OPTForCausalLM,
     'stablelm': GPTNeoXForCausalLM,
@@ -22,6 +25,7 @@ _MODELS = {
 }

 _MEMORY_ANALYZERS = {
+    'gpt2': GPT2MemoryAnalyzer,
     'llama': LlamaMemoryAnalyzer,
     'opt': OPTMemoryAnalyzer,
     'stablelm': GPTNeoXMemoryAnalyzer,
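These registry entries are what actually expose the new model: any loader that dispatches on _MODELS and _MEMORY_ANALYZERS now recognizes GPT-2 checkpoints. A hypothetical sketch of such a dispatch follows; the helper name resolve_model_class and the prefix-matching rule are illustrative assumptions, not the actual model_utils API, and class names are kept as strings so the sketch runs without cacheflow installed.

# Hypothetical dispatch sketch mirroring _MODELS after this commit
# (other entries elided).
_MODELS = {
    'gpt2': 'GPT2LMHeadModel',
    'llama': 'LlamaForCausalLM',
    'opt': 'OPTForCausalLM',
    'stablelm': 'GPTNeoXForCausalLM',
}

def resolve_model_class(model_name: str) -> str:
    # Match on prefix so names like "gpt2-medium" or "gpt2-xl" resolve too.
    for prefix, cls_name in _MODELS.items():
        if model_name.startswith(prefix):
            return cls_name
    raise ValueError(f'Unsupported model name: {model_name}')

assert resolve_model_class('gpt2-xl') == 'GPT2LMHeadModel'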
cacheflow/models/opt.py
@@ -234,7 +234,7 @@ class OPTForCausalLM(nn.Module):
         # TODO(zhuohan): create a new weight after implementing pipeline
         # parallelism
         self.lm_head_weight = self.model.decoder.embed_tokens.weight
-        self.sampler = Sampler()
+        self.sampler = Sampler(config.vocab_size)

     def forward(
         self,
cacheflow/models/sample.py
@@ -11,8 +11,9 @@ from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_pa

 class Sampler(nn.Module):

-    def __init__(self) -> None:
+    def __init__(self, vocab_size: int) -> None:
         super().__init__()
+        self.vocab_size = vocab_size

     def forward(
         self,
@@ -26,6 +27,8 @@ class Sampler(nn.Module):
         # Get the logits for the next tokens.
         logits = torch.matmul(hidden_states, embedding.t())
         logits = gather_from_tensor_model_parallel_region(logits)
+        # Remove paddings in vocab.
+        logits = logits[:, :self.vocab_size]

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
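This Sampler change closes the loop on the vocab padding introduced in gpt2.py: the padded embedding yields 50304 logits per token, and the new slice drops the 47 padding columns before temperature scaling so that padding ids can never be sampled. A minimal sketch of the effect with toy sizes (illustrative, plain PyTorch):

import torch

# Toy setup: real vocab of 10 padded to 16 for the (vocab-parallel) embedding.
vocab_size, padded_vocab, hidden = 10, 16, 4
embedding = torch.randn(padded_vocab, hidden)   # wte.weight with padding rows at the end
hidden_states = torch.randn(3, hidden)          # one row per sampled token position

logits = torch.matmul(hidden_states, embedding.t())   # shape [3, 16]
logits = logits[:, :vocab_size]                       # drop padding columns -> [3, 10]
probs = torch.softmax(logits, dim=-1)
assert probs.shape == (3, vocab_size)   # padding ids can no longer be sampled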