Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
28b18cc7
Unverified
Commit
28b18cc7
authored
Aug 01, 2025
by
Jee Jee Li
Committed by
GitHub
Aug 01, 2025
Browse files
[Quantization] Enable BNB support for InternS1 (#21953)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
49314869
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
16 deletions
+43
-16
vllm/model_executor/model_loader/bitsandbytes_loader.py
vllm/model_executor/model_loader/bitsandbytes_loader.py
+25
-14
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+18
-2
No files found.
vllm/model_executor/model_loader/bitsandbytes_loader.py
View file @
28b18cc7
...
...
@@ -34,7 +34,8 @@ from vllm.model_executor.model_loader.weight_utils import (
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models
import
is_pooling_model
from
vllm.model_executor.utils
import
(
get_packed_modules_mapping
,
from
vllm.model_executor.utils
import
(
get_moe_expert_mapping
,
get_packed_modules_mapping
,
set_weight_attrs
)
from
vllm.platforms
import
current_platform
...
...
@@ -43,6 +44,12 @@ from vllm.platforms import current_platform
logger
=
init_logger
(
__name__
)
def
is_moe_model
(
model
:
torch
.
nn
.
Module
)
->
bool
:
"""Checks if the model contains FusedMoE layers."""
return
bool
(
any
(
isinstance
(
module
,
FusedMoE
)
for
module
in
model
.
modules
()))
class
BitsAndBytesModelLoader
(
BaseModelLoader
):
"""Model loader to load model weights with BitAndBytes quantization."""
...
...
@@ -61,6 +68,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Store all module names (from transformers) that support
# BNB quantization.
self
.
target_modules
:
list
[
str
]
=
[]
# Store the mapping of expert parameters for MoE models.
self
.
expert_params_mapping
:
list
[
tuple
[
str
,
str
,
int
,
str
]]
=
[]
# mapping weight names from transformers to vllm.
self
.
weight_mapper
:
Callable
=
lambda
name
:
name
self
.
pre_quant
:
bool
=
False
...
...
@@ -413,13 +422,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
self
.
target_modules
.
append
(
name
)
elif
(
isinstance
(
module
,
FusedMoE
)
and
hasattr
(
module
.
quant_method
,
"quant_config"
)):
if
not
hasattr
(
model
,
"get_expert_mapping"
):
raise
AttributeError
(
f
"MoE Model
{
type
(
model
).
__name__
}
does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method."
)
elif
isinstance
(
module
,
FusedMoE
)
and
hasattr
(
module
.
quant_method
,
"quant_config"
):
# TODO: support FusedMoE with prequant and 8bit.
if
self
.
pre_quant
:
raise
ValueError
(
...
...
@@ -430,9 +434,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
"BitsAndBytes 8bit quantization with FusedMoE is not "
"supported yet."
)
# Get the corresponding weight name using module name and
#
get_
expert_mapping.
expert_mapping
=
model
.
get_expert_mapping
()
for
exp
in
expert_mapping
:
# expert_
params_
mapping.
for
exp
in
self
.
expert_
params_
mapping
:
weight_name
=
exp
[
1
]
rep_name
=
name
.
replace
(
"experts"
,
""
)
+
weight_name
.
removesuffix
(
"."
)
...
...
@@ -464,7 +468,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
elif
isinstance
(
module
,
(
RowParallelLinear
,
)):
self
.
column_sharded_weights_modules
.
append
(
name
)
elif
isinstance
(
module
,
FusedMoE
):
expert_mapping
=
model
.
get_expert
_mapping
()
expert_mapping
=
self
.
expert_params
_mapping
for
exp
in
expert_mapping
:
if
exp
[
-
1
]
==
"w2"
:
weight_name
=
exp
[
1
]
...
...
@@ -516,6 +520,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self
.
is_pool_model
=
is_pooling_model
(
model
)
self
.
modules_mapping
=
ParamMapping
(
get_packed_modules_mapping
(
model
))
if
is_moe_model
(
model
):
self
.
expert_params_mapping
=
get_moe_expert_mapping
(
model
)
if
not
self
.
expert_params_mapping
:
raise
AttributeError
(
f
"MoE Model
{
type
(
model
).
__name__
}
does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method."
)
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if
hf_to_vllm_mapper
:
=
getattr
(
model
,
"hf_to_vllm_mapper"
,
None
):
...
...
@@ -569,10 +580,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
"""
from
bitsandbytes.functional
import
QuantState
if
not
hasattr
(
model
,
"get_expert
_mapping
"
)
:
if
not
self
.
expert_params
_mapping
:
return
dict
()
expert_mapping
=
model
.
get_expert
_mapping
()
expert_mapping
=
self
.
expert_params
_mapping
expert_qs_dict
=
{}
for
name
,
module
in
model
.
named_modules
():
if
not
isinstance
(
module
,
FusedMoE
):
...
...
vllm/model_executor/utils.py
View file @
28b18cc7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utils for model executor."""
import
copy
from
typing
import
Any
,
Optional
...
...
@@ -9,6 +10,7 @@ import torch
def
set_random_seed
(
seed
:
int
)
->
None
:
from
vllm.platforms
import
current_platform
current_platform
.
seed_everything
(
seed
)
...
...
@@ -29,7 +31,7 @@ def set_weight_attrs(
return
for
key
,
value
in
weight_attrs
.
items
():
assert
not
hasattr
(
weight
,
key
),
(
f
"Overwriting existing tensor attribute:
{
key
}
"
)
weight
,
key
),
f
"Overwriting existing tensor attribute:
{
key
}
"
# NOTE(woosuk): During weight loading, we often do something like:
# narrowed_tensor = param.data.narrow(0, offset, len)
...
...
@@ -41,6 +43,7 @@ def set_weight_attrs(
# we sync the param tensor after its weight loader is called.
# TODO(woosuk): Remove this hack once we have a better solution.
from
vllm.platforms
import
current_platform
if
current_platform
.
is_tpu
()
and
key
==
"weight_loader"
:
value
=
_make_synced_weight_loader
(
value
)
setattr
(
weight
,
key
,
value
)
...
...
@@ -78,3 +81,16 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
else
:
parent_map
.
update
(
child_map
)
return
parent_map
def
get_moe_expert_mapping
(
model
:
torch
.
nn
.
Module
,
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
if
parent_map
:
=
getattr
(
model
,
"get_expert_mapping"
,
None
):
return
parent_map
()
else
:
# We only check main components instead of whole model submodules
for
child
in
model
.
children
():
child_map
=
getattr
(
child
,
"get_expert_mapping"
,
None
)
if
child_map
is
not
None
:
return
child_map
()
return
[]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment