Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
60fb4f3b
Unverified
Commit
60fb4f3b
authored
Dec 23, 2024
by
Michael Goin
Committed by
GitHub
Dec 23, 2024
Browse files
[Bugfix] Add kv cache scales to gemma2.py (#11269)
parent
63afbe92
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
1 deletion
+17
-1
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+17
-1
No files found.
vllm/model_executor/models/gemma2.py
View file @
60fb4f3b
...
@@ -31,11 +31,14 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -31,11 +31,14 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
get_compressed_tensors_cache_scale
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -326,6 +329,15 @@ class Gemma2Model(nn.Module):
...
@@ -326,6 +329,15 @@ class Gemma2Model(nn.Module):
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
scale_name
:
=
get_compressed_tensors_cache_scale
(
name
):
# Loading kv cache scales for compressed-tensors quantization
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
loaded_weight
[
0
]
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
if
shard_name
not
in
name
:
continue
continue
...
@@ -343,6 +355,10 @@ class Gemma2Model(nn.Module):
...
@@ -343,6 +355,10 @@ class Gemma2Model(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment