Commit 95529e32 (unverified)
Authored Feb 21, 2024 by Woosuk Kwon; committed by GitHub on Feb 21, 2024
Parent: 344020c9

Use Llama RMSNorm custom op for Gemma (#2974)

Showing 1 changed file with 27 additions and 33 deletions:

vllm/model_executor/models/gemma.py (+27, -33)
@@ -22,6 +22,7 @@ from transformers import GemmaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -40,21 +41,6 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class GemmaRMSNorm(nn.Module):
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
-
-
 class GemmaMLP(nn.Module):
 
     def __init__(
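For context on the op being swapped in: the RMSNorm imported above from vllm.model_executor.layers.layernorm scales the normalized activations by weight directly (Llama style, with no 1 + weight offset) and accepts an optional residual argument that fuses the pending residual add with the normalization. Below is a rough pure-PyTorch sketch of those semantics, not vLLM code (the name RefRMSNorm is illustrative; the real op dispatches to a fused CUDA kernel):

import torch
from torch import nn
from typing import Optional, Tuple, Union


class RefRMSNorm(nn.Module):
    """Reference-only sketch of the custom op's interface and math."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # Fused add + norm: the caller's deferred residual add happens
            # here, and the pre-norm sum becomes the new running residual.
            x = x + residual
            residual = x
        out = x.float()
        out = out * torch.rsqrt(out.pow(2).mean(-1, keepdim=True) + self.eps)
        out = (out * self.weight.float()).type_as(x)
        return out if residual is None else (out, residual)

Because this op applies weight rather than 1 + weight, load_weights (last hunks below) shifts Gemma's checkpoint norm weights by 1.0.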
@@ -185,10 +171,10 @@ class GemmaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             linear_method=linear_method,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -196,25 +182,27 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class GemmaModel(nn.Module):
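The rewrite above drops the explicit hidden_states = residual + hidden_states additions: each add is deferred into the next norm call, and the final add is handed back to the caller. A self-contained toy check of that equivalence, with stand-in attention/MLP functions and dimensions chosen purely for illustration:

import torch

DIM, EPS = 8, 1e-6


def rms_norm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + EPS) * weight


torch.manual_seed(0)
w_in, w_post = torch.rand(DIM), torch.rand(DIM)   # layer norm weights
attn = lambda t: 0.5 * t                          # toy stand-in for self-attention
mlp = lambda t: 2.0 * t                           # toy stand-in for the MLP
x = torch.randn(4, DIM)

# Old path: explicit residual add after every sublayer.
residual = x
h = attn(rms_norm(x, w_in))
h = residual + h
residual = h
h = mlp(rms_norm(h, w_post))
old_out = residual + h

# New path: the add is fused into the following norm call, and the last add
# is left to the next layer's input_layernorm (or the model's final norm).
residual = x                                      # first layer: residual is None
h = attn(rms_norm(x, w_in))
residual = residual + h                           # done inside post_attention_layernorm
h = mlp(rms_norm(residual, w_post))
new_out = residual + h                            # done by the downstream norm call

torch.testing.assert_close(old_out, new_out)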
@@ -235,7 +223,7 @@ class GemmaModel(nn.Module):
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -246,17 +234,19 @@ class GemmaModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states *= self.config.hidden_size**0.5
 
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -321,6 +311,10 @@ class GemmaForCausalLM(nn.Module):
                 # Skip loading extra layer for lora models.
                 if "lm_head" in name:
                     continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
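The two added lines above are what keep the swap numerically faithful. A small, hedged check of the identity they rely on (names here are illustrative, not from the repo): the removed GemmaRMSNorm returned norm(x) * (1 + w), so a Llama-style RMSNorm holding the shifted weight w + 1.0 reproduces it:

import torch


def rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(2, 16)
checkpoint_w = torch.randn(16) * 0.02      # Gemma checkpoints store w (zero-centred)

gemma_out = rms(x) * (1 + checkpoint_w)    # behaviour of the removed GemmaRMSNorm
loaded_w = checkpoint_w + 1.0              # the "loaded_weight += 1.0" shift
llama_out = rms(x) * loaded_w              # behaviour of the shared RMSNorm op
torch.testing.assert_close(gemma_out, llama_out)

The real kernels handle dtypes slightly differently, so this is an algebraic check rather than a bit-exact claim about the CUDA op.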
@@ -329,5 +323,5 @@ class GemmaForCausalLM(nn.Module):
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")