Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6b466771
Unverified
Commit
6b466771
authored
Oct 30, 2023
by
Younes Belkada
Committed by
GitHub
Oct 30, 2023
Browse files
[`tests` / `Quantization`] Fix bnb test (#27145)
* fix bnb test * link to GH issue
parent
57699496
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
4 deletions
+14
-4
tests/quantization/bnb/test_mixed_int8.py
tests/quantization/bnb/test_mixed_int8.py
+14
-4
No files found.
tests/quantization/bnb/test_mixed_int8.py
View file @
6b466771
...
...
@@ -124,13 +124,13 @@ class MixedInt8Test(BaseMixedInt8Test):
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
def
test_get_keys_to_not_convert
(
self
):
@
unittest
.
skip
(
"Un-skip once https://github.com/mosaicml/llm-foundry/issues/703 is resolved"
)
def
test_get_keys_to_not_convert_trust_remote_code
(
self
):
r
"""
Test the `get_keys_to_not_convert` function.
Test the `get_keys_to_not_convert` function
with `trust_remote_code` models
.
"""
from
accelerate
import
init_empty_weights
from
transformers
import
AutoModelForMaskedLM
,
Blip2ForConditionalGeneration
,
MptForCausalLM
,
OPTForCausalLM
from
transformers.integrations.bitsandbytes
import
get_keys_to_not_convert
model_id
=
"mosaicml/mpt-7b"
...
...
@@ -142,7 +142,17 @@ class MixedInt8Test(BaseMixedInt8Test):
config
,
trust_remote_code
=
True
,
code_revision
=
"72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
)
self
.
assertEqual
(
get_keys_to_not_convert
(
model
),
[
"transformer.wte"
])
# without trust_remote_code
def
test_get_keys_to_not_convert
(
self
):
r
"""
Test the `get_keys_to_not_convert` function.
"""
from
accelerate
import
init_empty_weights
from
transformers
import
AutoModelForMaskedLM
,
Blip2ForConditionalGeneration
,
MptForCausalLM
,
OPTForCausalLM
from
transformers.integrations.bitsandbytes
import
get_keys_to_not_convert
model_id
=
"mosaicml/mpt-7b"
config
=
AutoConfig
.
from_pretrained
(
model_id
,
revision
=
"72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
)
with
init_empty_weights
():
model
=
MptForCausalLM
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment