chenpangpang / transformers · commit 47952192 (unverified)

[`bnb`] Fix bnb skip modules (#24043)

* fix skip modules test
* oops
* address comments

Authored Jun 07, 2023 by Younes Belkada; committed via GitHub on Jun 07, 2023.
Parent: a1160185
Showing 2 changed files with 27 additions and 3 deletions:
src/transformers/utils/bitsandbytes.py    +7 −3
tests/bitsandbytes/test_mixed_int8.py     +20 −0
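The sketch below mirrors the new test and illustrates the feature this commit repairs: loading a sequence-classification model in 8-bit while keeping the classification head in its original precision via `llm_int8_skip_modules`. It assumes a CUDA machine with `bitsandbytes` and `accelerate` installed; the exact non-int8 dtype printed at the end (e.g. `torch.float16`) depends on the load settings.

```python
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

# Load everything in 8-bit except the classifier head, which stays in its
# original (higher-precision) dtype.
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli", quantization_config=quantization_config
)

print(model.roberta.encoder.layer[0].output.dense.weight.dtype)  # torch.int8
print(model.classifier.dense.weight.dtype)  # not torch.int8, e.g. torch.float16
```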
src/transformers/utils/bitsandbytes.py @ 47952192
@@ -109,16 +109,18 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None):
         module._parameters[tensor_name] = new_value
 
 
-def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+def _replace_with_bnb_linear(
+    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
+):
     """
     Private method that wraps the recursion for module replacement.
 
     Returns the converted model and a boolean that indicates if the conversion has been successful or not.
     """
-    has_been_replaced = False
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
+        current_key_name.append(name)
 
         if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
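Two things change in this first hunk. First, `has_been_replaced` moves from a local variable (reset to `False` on every call, including recursive ones) into a keyword argument, so the flag can be threaded through the recursion. Second, each child's `name` is appended to `current_key_name`, so the full dotted path of the module currently being visited is available. The actual membership check against `modules_to_not_convert` lives in the elided lines; the following is only a hypothetical sketch of the kind of dotted-path test this enables, not the exact `transformers` logic:

```python
# Hypothetical sketch: with the accumulated dotted path, a skip entry such as
# "classifier" can also cover nested children like "classifier.dense".
current_key_name = ["classifier", "dense"]  # path built up during recursion
modules_to_not_convert = ["classifier"]

current_key_name_str = ".".join(current_key_name)  # "classifier.dense"
skip = any(
    key == current_key_name_str or current_key_name_str.startswith(key + ".")
    for key in modules_to_not_convert
)
print(skip)  # True: "classifier.dense" sits under the skipped "classifier"
```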
@@ -151,14 +153,16 @@ def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
                     has_been_replaced = True
                     # Force requires grad to False to avoid unexpected errors
                     model._modules[name].requires_grad_(False)
         if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_bnb_linear(module, modules_to_not_convert, current_key_name, quantization_config)
+            _, has_been_replaced = _replace_with_bnb_linear(
+                module,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+                has_been_replaced=has_been_replaced,
+            )
         # Remove the last key for recursion
         current_key_name.pop(-1)
     return model, has_been_replaced
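The second hunk passes the accumulated flag into the recursive call (and still pops the current name afterwards so the path stays correct). Previously, each recursive call started from `has_been_replaced = False` and its return value overwrote the caller's flag, so a replacement found in an earlier sibling could be clobbered by a later subtree containing no `nn.Linear`. Below is a standalone sketch of the corrected accumulator pattern, not the `transformers` code itself:

```python
import torch.nn as nn

def replaced_anywhere(model, has_been_replaced=False):
    # The flag is passed down and returned, so a hit in any branch of the
    # module tree survives to the top-level caller.
    for _name, module in model.named_children():
        if isinstance(module, nn.Linear):
            has_been_replaced = True
        if len(list(module.children())) > 0:
            has_been_replaced = replaced_anywhere(module, has_been_replaced)
    return has_been_replaced

model = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.ReLU()))
# True even though the *last* child contains no Linear; with a flag that was
# reset inside every call, the final recursion would have returned False.
print(replaced_anywhere(model))
```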
tests/bitsandbytes/test_mixed_int8.py @ 47952192
@@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test):
                 if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
                     self.assertTrue(module.weight.dtype == torch.int8)
 
+    def test_llm_skip(self):
+        r"""
+        A simple test to check if `llm_int8_skip_modules` works as expected
+        """
+        import bitsandbytes as bnb
+
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
+        seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
+            "roberta-large-mnli", quantization_config=quantization_config
+        )
+        self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8)
+        self.assertTrue(
+            isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt)
+        )
+
+        self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
+        self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8)
+        self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
+        self.assertTrue(seq_classification_model.classifier.out_proj.weight.dtype != torch.int8)
+
     def test_generate_quality(self):
         r"""
         Test the generation quality of the quantized model and see that we are matching the expected output.
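The new `test_llm_skip` checks both sides of the contract: encoder linear layers are converted to `bnb.nn.Linear8bitLt` with `torch.int8` weights, while the skipped `classifier` submodules remain plain `nn.Linear` in a non-int8 dtype. Since the mixed-int8 tests need a CUDA GPU and are gated as slow in this suite, one would typically run it with something like `RUN_SLOW=1 pytest tests/bitsandbytes/test_mixed_int8.py -k test_llm_skip`.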