Unverified commit c8b26096, authored Sep 13, 2023 by Younes Belkada, committed by GitHub on Sep 13, 2023.
[`core`] fix 4bit `num_parameters` (#26132)

* fix 4bit `num_parameters`
* stronger check
Parent: 7db1ad63
Showing 2 changed files with 35 additions and 3 deletions:

* src/transformers/modeling_utils.py (+24, -3)
* tests/quantization/bnb/test_4bit.py (+11, -0)
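
The underlying issue is how bitsandbytes stores 4-bit weights: a `bnb.nn.Params4bit` tensor keeps the quantized weight packed two 4-bit values per uint8 element, so its `numel()` reports only half of the logical parameter count. The sketch below is not part of the commit; it uses plain PyTorch only to illustrate the arithmetic behind the `* 2` correction.

import torch

# A toy fp16 weight with 32 "logical" parameters.
logical_weight = torch.randn(4, 8, dtype=torch.float16)

# 4-bit quantization packs two values per byte, so the stored uint8 buffer
# holds half as many elements as the logical weight.
packed_storage = torch.empty(logical_weight.numel() // 2, dtype=torch.uint8)

print(logical_weight.numel())      # 32 -> expected parameter count
print(packed_storage.numel())      # 16 -> what num_parameters() previously reported for 4-bit weights
print(packed_storage.numel() * 2)  # 32 -> what the fixed num_parameters() now reports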
src/transformers/modeling_utils.py
...
@@ -989,12 +989,33 @@ class ModuleUtilsMixin:
             embedding_param_names = [
                 f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
             ]
-            non_embedding_parameters = [
+            total_parameters = [
                 parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
             ]
-            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
-            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+            total_parameters = list(self.parameters())
+
+        total_numel = []
+        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+        if is_loaded_in_4bit:
+            if is_bitsandbytes_available():
+                import bitsandbytes as bnb
+            else:
+                raise ValueError(
+                    "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
+                    " make sure to install bitsandbytes with `pip install bitsandbytes`."
+                )
+
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
+                # used for the 4bit quantization (uint8 tensors are stored)
+                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+                    total_numel.append(param.numel() * 2)
+                else:
+                    total_numel.append(param.numel())
+
+        return sum(total_numel)

     def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
         """
...
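
In practice, the change means a model loaded in 4-bit now reports the same parameter count as its full-precision counterpart. A rough usage sketch, not part of the commit: it assumes a CUDA GPU with `bitsandbytes` installed, and `facebook/opt-350m` is only an example checkpoint.

import torch
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-350m"  # example checkpoint; any causal LM works

model_fp16 = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
model_4bit = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto")

# Before this fix, the 4-bit count was roughly half the fp16 count, because
# Params4bit.numel() counts packed uint8 elements rather than logical weights.
print(model_fp16.num_parameters())
print(model_4bit.num_parameters())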
tests/quantization/bnb/test_4bit.py
...
@@ -118,6 +118,17 @@ class Bnb4BitTest(Base4bitTest):
         gc.collect()
         torch.cuda.empty_cache()

+    def test_quantization_num_parameters(self):
+        r"""
+        Test if the number of returned parameters is correct
+
+        See: https://github.com/huggingface/transformers/issues/25978
+        """
+        num_params_4bit = self.model_4bit.num_parameters()
+        num_params_fp16 = self.model_fp16.num_parameters()
+
+        self.assertEqual(num_params_4bit, num_params_fp16)
+
     def test_quantization_config_json_serialization(self):
         r"""
         A simple test to check if the quantization config is correctly serialized and deserialized
...
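
The new test asserts that the 4-bit and fp16 fixture models report identical `num_parameters()` values. One possible way to run just this test locally is sketched below; it assumes a CUDA GPU with `bitsandbytes` installed, and the RUN_SLOW gating is an assumption based on how the GPU quantization suites are usually marked.

import os
import pytest

# Assumption: the bitsandbytes tests are gated as slow tests.
os.environ.setdefault("RUN_SLOW", "1")

pytest.main(["tests/quantization/bnb/test_4bit.py", "-k", "test_quantization_num_parameters", "-v"])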