Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82c540be
Unverified
Commit
82c540be
authored
Mar 27, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 27, 2024
Browse files
[Bugfix] More faithful implementation of Gemma (#3653)
parent
8f44facd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
3 deletions
+43
-3
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+43
-3
No files found.
vllm/model_executor/models/gemma.py
View file @
82c540be
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
"""Inference-only Gemma model compatible with HuggingFace weights."""
from
functools
import
lru_cache
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -22,6 +23,7 @@ from transformers import GemmaConfig
...
@@ -22,6 +23,7 @@ from transformers import GemmaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
...
@@ -40,6 +42,34 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
...
@@ -40,6 +42,34 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
logger
=
init_logger
(
__name__
)
@
lru_cache
(
maxsize
=
None
)
def
_get_gemma_act_fn
(
hidden_act
:
Optional
[
str
],
hidden_activation
:
Optional
[
str
],
)
->
nn
.
Module
:
if
hidden_activation
is
None
:
if
hidden_act
is
not
None
:
logger
.
warning
(
"Gemma's activation function was incorrectly set to exact GeLU "
"in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
f
"`
{
hidden_act
}
`, edit the config JSON to set "
f
"`hidden_activation=
{
hidden_act
}
` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 "
"for more details."
)
return
GeluAndMul
(
approximate
=
"tanh"
)
elif
hidden_activation
==
"gelu_pytorch_tanh"
:
return
GeluAndMul
(
approximate
=
"tanh"
)
elif
hidden_activation
==
"gelu"
:
return
GeluAndMul
(
approximate
=
"none"
)
else
:
raise
ValueError
(
f
"Activation function
{
hidden_act
}
is not "
"supported for Gemma models."
)
class
GemmaMLP
(
nn
.
Module
):
class
GemmaMLP
(
nn
.
Module
):
...
@@ -47,6 +77,8 @@ class GemmaMLP(nn.Module):
...
@@ -47,6 +77,8 @@ class GemmaMLP(nn.Module):
self
,
self
,
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
Optional
[
str
]
=
None
,
hidden_activation
:
Optional
[
str
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -58,7 +90,7 @@ class GemmaMLP(nn.Module):
...
@@ -58,7 +90,7 @@ class GemmaMLP(nn.Module):
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
linear_method
=
linear_method
)
self
.
act_fn
=
GeluAndMul
(
)
self
.
act_fn
=
_get_gemma_act_fn
(
hidden_act
,
hidden_activation
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
...
@@ -162,6 +194,8 @@ class GemmaDecoderLayer(nn.Module):
...
@@ -162,6 +194,8 @@ class GemmaDecoderLayer(nn.Module):
self
.
mlp
=
GemmaMLP
(
self
.
mlp
=
GemmaMLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_activation
=
getattr
(
config
,
"hidden_activation"
,
None
),
linear_method
=
linear_method
,
linear_method
=
linear_method
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -218,6 +252,13 @@ class GemmaModel(nn.Module):
...
@@ -218,6 +252,13 @@ class GemmaModel(nn.Module):
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcast to the model's
# data type (e.g. bfloat16) rather than kept as float32.
# See https://github.com/huggingface/transformers/pull/29402
normalizer
=
self
.
config
.
hidden_size
**
0.5
self
.
register_buffer
(
"normalizer"
,
torch
.
tensor
(
normalizer
))
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -226,8 +267,7 @@ class GemmaModel(nn.Module):
...
@@ -226,8 +267,7 @@ class GemmaModel(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
embed_tokens
(
input_ids
)
# Normalize the embedding by sqrt(hidden_size)
hidden_states
*=
self
.
normalizer
hidden_states
*=
self
.
config
.
hidden_size
**
0.5
residual
=
None
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment