Commit 95529e32 (unverified), authored Feb 21, 2024 by Woosuk Kwon, committed by GitHub on Feb 21, 2024
Parent: 344020c9

Use Llama RMSNorm custom op for Gemma (#2974)

Showing 1 changed file with 27 additions and 33 deletions.
vllm/model_executor/models/gemma.py
@@ -22,6 +22,7 @@ from transformers import GemmaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -40,21 +41,6 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class GemmaRMSNorm(nn.Module):
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
-
-
 class GemmaMLP(nn.Module):
 
     def __init__(
@@ -185,10 +171,10 @@ class GemmaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             linear_method=linear_method,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -196,25 +182,27 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class GemmaModel(nn.Module):
@@ -235,7 +223,7 @@ class GemmaModel(nn.Module):
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -246,17 +234,19 @@ class GemmaModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states *= self.config.hidden_size**0.5
 
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -321,6 +311,10 @@ class GemmaForCausalLM(nn.Module):
             # Skip loading extra layer for lora models.
             if "lm_head" in name:
                 continue
+            # GemmaRMSNorm is different from Llama's in that it multiplies
+            # (1 + weight) to the output, instead of just weight.
+            if "norm.weight" in name:
+                loaded_weight += 1.0
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
@@ -329,5 +323,5 @@ class GemmaForCausalLM(nn.Module):
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")
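
Note on the change: Gemma's own GemmaRMSNorm (deleted above) scales the normalized activations by (1 + weight), while the shared Llama RMSNorm op scales by weight alone, so the loader adds 1.0 to every norm weight at load time (the loaded_weight += 1.0 lines) and the outputs stay identical. A minimal PyTorch sketch of that equivalence; the gemma_rms_norm and llama_rms_norm helpers below are illustrative reference implementations, not vLLM's fused kernel:

import torch


def gemma_rms_norm(x, weight, eps=1e-6):
    # Gemma convention: normalize in float32, then scale by (1 + weight).
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (1.0 + weight.float())).type_as(x)


def llama_rms_norm(x, weight, eps=1e-6):
    # Llama convention: normalize, then scale by weight directly.
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * weight.float()).type_as(x)


x = torch.randn(2, 8, dtype=torch.float16)
w = torch.zeros(8).uniform_(-0.1, 0.1)  # Gemma checkpoints store the offset from 1.0

# Folding the +1.0 into the weight once at load time lets the Llama-style op
# reproduce Gemma's output without a Gemma-specific norm class.
assert torch.allclose(gemma_rms_norm(x, w), llama_rms_norm(x, w + 1.0))

This is also why the dedicated GemmaRMSNorm class can be removed entirely: only the checkpoint-loading path needs to know about Gemma's offset convention, and the rest of the model reuses the existing RMSNorm op (including its fused residual-add variant used in the decoder layers).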