Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
580353da
Unverified
Commit
580353da
authored
Jun 28, 2024
by
Woosuk Kwon
Committed by
GitHub
Jun 29, 2024
Browse files
[Bugfix] Fix precisions in Gemma 1 (#5913)
parent
ba499444
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
14 deletions
+12
-14
tests/models/test_models.py
tests/models/test_models.py
+1
-0
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+11
-14
No files found.
tests/models/test_models.py
View file @
580353da
...
...
@@ -17,6 +17,7 @@ MODELS = [
"stabilityai/stablelm-3b-4e1t"
,
# "allenai/OLMo-1B", # Broken
"bigcode/starcoder2-3b"
,
"google/gemma-1.1-2b-it"
,
]
...
...
vllm/model_executor/models/gemma.py
View file @
580353da
...
...
@@ -26,14 +26,14 @@ from vllm.config import CacheConfig, LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
GemmaRotaryEmbedding
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -148,12 +148,14 @@ class GemmaAttention(nn.Module):
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
# TODO(woosuk): Use the `get_rope` interface.
self
.
rotary_emb
=
GemmaRotaryEmbedding
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
base
=
self
.
rope_theta
,
is_neox_style
=
True
,
dtype
=
torch
.
get_default_dtype
(),
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
...
...
@@ -204,10 +206,10 @@ class GemmaDecoderLayer(nn.Module):
hidden_activation
=
getattr
(
config
,
"hidden_activation"
,
None
),
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
input_layernorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
...
...
@@ -257,7 +259,7 @@ class GemmaModel(nn.Module):
GemmaDecoderLayer
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcasted to the model's
...
...
@@ -331,7 +333,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -388,10 +389,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if
"norm.weight"
in
name
:
loaded_weight
+=
1.0
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment