sglang · Commit 08effbff (unverified)
Authored Dec 26, 2024 by Sangchun Ha (Patrick); committed by GitHub on Dec 26, 2024.
Fix the error that occurs when loading the gemma model in bitsandbytes format. (#2557)
Parent: 60bd3272

2 changed files, with 41 additions and 11 deletions (+41 −11):
  python/sglang/srt/model_loader/loader.py  (+22 −11)
  python/sglang/srt/models/gemma2.py        (+19 −0)
python/sglang/srt/model_loader/loader.py

@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             quant_state_dict,
         )
 
+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax",
+            "quant_map",
+            "nested_absmax",
+            "nested_quant_map",
+            "bitsandbytes",
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(
         self, hf_weights_files, use_safetensors, quant_state_dict
     ) -> Generator:
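The two new helpers classify checkpoint tensor names by their bitsandbytes suffixes. Below is a minimal standalone sketch of the same predicates, with illustrative tensor names (not taken from a real checkpoint):

    def is_8bit_weight_name(weight_name: str) -> bool:
        quantized_suffix = {".scb", ".weight_format"}
        return any(weight_name.lower().endswith(s) for s in quantized_suffix)

    def is_4bit_weight_name(weight_name: str) -> bool:
        quantized_suffix = {"absmax", "quant_map", "nested_absmax",
                            "nested_quant_map", "bitsandbytes"}
        return any(q in weight_name.split(".")[-1] for q in quantized_suffix)

    assert is_8bit_weight_name("model.layers.0.mlp.down_proj.SCB")         # 8-bit scale tensor
    assert not is_8bit_weight_name("model.layers.0.mlp.down_proj.weight")  # plain weight
    assert is_4bit_weight_name("model.layers.0.mlp.down_proj.weight.absmax")
    assert is_4bit_weight_name(
        "model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4"
    )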
@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if not weight_name.lower().endswith(".scb"):
                 continue
 
-            weight_key = weight_name.lower().replace(".scb", ".qweight")
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor
 
         for weight_name, weight_tensor in self._hf_weight_iter(
             hf_weights_files, use_safetensors
         ):
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue
 
-            qweight_name = weight_name.replace(".weight", ".qweight")
-
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield qweight_name, weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if weight_name.endswith((".weight", ".bias")):
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
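This first pass now collects only recognized bitsandbytes metadata into `temp_state_dict`, instead of collecting everything that does not end in `.weight`/`.bias`. A self-contained sketch (the predicate is restated from the helper above; names are illustrative):

    def is_4bit_weight_name(name: str) -> bool:
        tail = name.split(".")[-1]
        return any(s in tail for s in
                   {"absmax", "quant_map", "nested_absmax", "nested_quant_map", "bitsandbytes"})

    names = [
        "model.layers.0.mlp.up_proj.weight",            # real weight -> skipped
        "model.layers.0.mlp.up_proj.weight.absmax",     # 4-bit metadata -> collected
        "model.layers.0.mlp.up_proj.weight.quant_map",  # 4-bit metadata -> collected
    ]
    print([n for n in names if is_4bit_weight_name(n)])  # only the two metadata entries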
@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             hf_weights_files, use_safetensors
         ):
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_4bit_weight_name(weight_name):
                 continue
 
             if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
                 f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
             ):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name.replace(".weight", ".qweight"), weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
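The 4-bit generator likewise yields quantized weights under their original `.weight` names (the `.qweight` renames are gone) and uses `_is_4bit_weight_name` to skip metadata tensors. A minimal sketch of the quantized-or-not decision (names and values are illustrative stand-ins):

    weight_name = "model.layers.0.self_attn.q_proj.weight"
    temp_state_dict = {
        f"{weight_name}.quant_state.bitsandbytes__nf4": "quant-state-blob",  # stand-in value
    }
    is_quantized = (
        f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
        or f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
    )
    print(is_quantized)  # True: the tensor is yielded under its original name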
python/sglang/srt/models/gemma2.py

@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):
 class Gemma2ForCausalLM(nn.Module):
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
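For context, `bitsandbytes_stacked_params_mapping` tells the loader how per-shard checkpoint weights fold into sglang's fused linear layers (`qkv_proj`, `gate_up_proj`). A simplified sketch of how such a mapping is typically applied; this is not the loader's actual code:

    stacked_params_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def resolve(param_name: str):
        # Map a checkpoint shard name onto its fused parameter and shard index.
        for shard_name, (fused_name, index) in stacked_params_mapping.items():
            if shard_name in param_name:
                return param_name.replace(shard_name, fused_name), index
        return param_name, None  # unfused weights (e.g. o_proj) pass through

    print(resolve("model.layers.0.self_attn.q_proj.weight"))
    # ('model.layers.0.self_attn.qkv_proj.weight', 0)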