Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdc72e3c
Unverified
Commit
cdc72e3c
authored
Oct 08, 2024
by
Hui Liu
Committed by
GitHub
Oct 09, 2024
Browse files
[Model] Remap FP8 kv_scale in CommandR and DBRX (#9174)
parent
7627172b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
2 deletions
+14
-2
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+7
-1
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+7
-1
No files found.
vllm/model_executor/models/commandr.py
View file @
cdc72e3c
...
@@ -41,7 +41,8 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...
@@ -41,7 +41,8 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
row_parallel_weight_loader
)
default_weight_loader
,
maybe_remap_kv_scale_name
,
row_parallel_weight_loader
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -426,6 +427,11 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -426,6 +427,11 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
...
...
vllm/model_executor/models/dbrx.py
View file @
cdc72e3c
...
@@ -18,7 +18,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -18,7 +18,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
...
@@ -425,6 +426,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
...
@@ -425,6 +426,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
weight_loader
(
param
,
loaded_weight
,
weight_name
)
weight_loader
(
param
,
loaded_weight
,
weight_name
)
break
break
else
:
else
:
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment