Unverified Commit 6b466771 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`tests` / `Quantization`] Fix bnb test (#27145)

* fix bnb test

* link to GH issue
parent 57699496
...@@ -124,13 +124,13 @@ class MixedInt8Test(BaseMixedInt8Test): ...@@ -124,13 +124,13 @@ class MixedInt8Test(BaseMixedInt8Test):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def test_get_keys_to_not_convert(self): @unittest.skip("Un-skip once https://github.com/mosaicml/llm-foundry/issues/703 is resolved")
def test_get_keys_to_not_convert_trust_remote_code(self):
r""" r"""
Test the `get_keys_to_not_convert` function. Test the `get_keys_to_not_convert` function with `trust_remote_code` models.
""" """
from accelerate import init_empty_weights from accelerate import init_empty_weights
from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
from transformers.integrations.bitsandbytes import get_keys_to_not_convert from transformers.integrations.bitsandbytes import get_keys_to_not_convert
model_id = "mosaicml/mpt-7b" model_id = "mosaicml/mpt-7b"
...@@ -142,7 +142,17 @@ class MixedInt8Test(BaseMixedInt8Test): ...@@ -142,7 +142,17 @@ class MixedInt8Test(BaseMixedInt8Test):
config, trust_remote_code=True, code_revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7" config, trust_remote_code=True, code_revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
) )
self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"]) self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
# without trust_remote_code
def test_get_keys_to_not_convert(self):
r"""
Test the `get_keys_to_not_convert` function.
"""
from accelerate import init_empty_weights
from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
from transformers.integrations.bitsandbytes import get_keys_to_not_convert
model_id = "mosaicml/mpt-7b"
config = AutoConfig.from_pretrained(model_id, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7") config = AutoConfig.from_pretrained(model_id, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7")
with init_empty_weights(): with init_empty_weights():
model = MptForCausalLM(config) model = MptForCausalLM(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment