chenpangpang/transformers · Commits

Unverified commit 47952192, authored Jun 07, 2023 by Younes Belkada, committed by GitHub on Jun 07, 2023
[`bnb`] Fix bnb skip modules (#24043)

* fix skip modules test
* oops
* address comments
Parent: a1160185

Showing 2 changed files with 27 additions and 3 deletions:

* src/transformers/utils/bitsandbytes.py (+7, -3)
* tests/bitsandbytes/test_mixed_int8.py (+20, -0)
src/transformers/utils/bitsandbytes.py

```diff
@@ -109,16 +109,18 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None):
         module._parameters[tensor_name] = new_value


-def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+def _replace_with_bnb_linear(
+    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
+):
     """
     Private method that wraps the recursion for module replacement.

     Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
     """
-    has_been_replaced = False
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
         current_key_name.append(name)

         if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
@@ -151,14 +153,16 @@ def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
                 has_been_replaced = True
                 # Force requires grad to False to avoid unexpected errors
                 model._modules[name].requires_grad_(False)
-        # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_bnb_linear(
                 module,
                 modules_to_not_convert,
                 current_key_name,
                 quantization_config,
+                has_been_replaced=has_been_replaced,
             )
+        # Remove the last key for recursion
+        current_key_name.pop(-1)
     return model, has_been_replaced
```
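The substance of this change: previously, `has_been_replaced` was reset to `False` at the top of every call, so each recursive descent started from scratch, and `_, has_been_replaced = _replace_with_bnb_linear(...)` overwrote whatever the parent had already recorded. Threading the flag through as a parameter turns it into an accumulator across the whole recursion. The commit also adds `current_key_name.pop(-1)` after the recursive call so the key path is unwound before the loop moves on to the next sibling. A minimal sketch of the accumulator-flag pattern, using hypothetical `Leaf`/`Node` stand-ins for `nn.Module` rather than the actual transformers code:

```python
# Toy stand-ins for nn.Module containers and linear layers (illustration only).
class Leaf:
    pass

class Node:
    def __init__(self, *children):
        self.children = list(children)

def convert(tree, converted=False):
    """Recursively 'convert' leaves; `converted` accumulates across calls.

    Passing the flag in (instead of resetting it to False locally) means a
    late subtree with nothing to convert cannot overwrite an earlier True.
    """
    for i, child in enumerate(tree.children):
        if isinstance(child, Leaf):
            tree.children[i] = "converted"
            converted = True
        elif isinstance(child, Node):
            _, converted = convert(child, converted=converted)
    return tree, converted

# A tree whose *last* subtree contains nothing to convert: with the old
# `converted = False` reset inside every call, the final recursive return
# would report False even though a leaf was converted earlier.
tree = Node(Leaf(), Node())
_, ok = convert(tree)
assert ok  # True with the accumulator pattern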
tests/bitsandbytes/test_mixed_int8.py

```diff
@@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test):
             if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
                 self.assertTrue(module.weight.dtype == torch.int8)

+    def test_llm_skip(self):
+        r"""
+        A simple test to check if `llm_int8_skip_modules` works as expected
+        """
+        import bitsandbytes as bnb
+
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
+        seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
+            "roberta-large-mnli", quantization_config=quantization_config
+        )
+        self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8)
+        self.assertTrue(
+            isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt)
+        )
+
+        self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
+        self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8)
+        self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
+        self.assertTrue(seq_classification_model.classifier.out_proj != torch.int8)
+
     def test_generate_quality(self):
         r"""
         Test the generation quality of the quantized model and see that we are matching the expected output.
```
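For context, this is how the option exercised by the new test is used from application code: `llm_int8_skip_modules` keeps the named submodules in their original precision while the rest of the model's linear layers are converted to 8-bit. A minimal sketch, assuming a CUDA machine with `bitsandbytes` and `accelerate` installed; the checkpoint and module paths are the ones from the test above:

```python
import torch
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

# Quantize everything to int8 except the classification head.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["classifier"],
)

model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli",
    quantization_config=quantization_config,
)

# Encoder linears are replaced with 8-bit layers...
print(model.roberta.encoder.layer[0].output.dense.weight.dtype)  # torch.int8
# ...while the skipped classifier keeps regular nn.Linear layers in higher precision.
print(type(model.classifier.dense), model.classifier.dense.weight.dtype)
```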