Unverified Commit ec29d25d authored by Juri Ganitkevitch, committed by GitHub

Add missing None check for hf_quantizer (#28804)



* Add missing None check for hf_quantizer

* Add test, fix logic.

* make style

* Switch test model to Mistral

* Comment

* Update tests/test_modeling_utils.py

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 1efb21c7
src/transformers/modeling_utils.py
@@ -3727,10 +3727,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 if param.device == torch.device("meta"):
                     value = torch.empty(*param.size(), dtype=target_dtype)
-                    if getattr(
-                        hf_quantizer, "requires_parameters_quantization", False
-                    ) or not hf_quantizer.check_quantized_param(
-                        model, param_value=value, param_name=key, state_dict={}
-                    ):
+                    if (
+                        hf_quantizer is None
+                        or getattr(hf_quantizer, "requires_parameters_quantization", False)
+                        or not hf_quantizer.check_quantized_param(
+                            model, param_value=value, param_name=key, state_dict={}
+                        )
+                    ):
                         set_module_tensor_to_device(model, key, "cpu", value)
                     else:
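For context on the fix: with quantization_config=None, hf_quantizer is None in this code path. The old condition survived the getattr call (getattr on None simply returns the default) but then dereferenced None in check_quantized_param. A minimal sketch of the failure mode follows; hf_quantizer here is a plain variable standing in for the quantizer, not the real transformers HfQuantizer API:

# Sketch only: simplified stand-in, not the real transformers call signature.
hf_quantizer = None

def old_condition():
    # getattr(None, ..., False) quietly returns False, so evaluation falls
    # through to an attribute access on None and raises AttributeError.
    return getattr(hf_quantizer, "requires_parameters_quantization", False) or not hf_quantizer.check_quantized_param()

def new_condition():
    # The added "hf_quantizer is None" clause short-circuits before any
    # attribute access, so the un-quantized path is taken safely.
    return (
        hf_quantizer is None
        or getattr(hf_quantizer, "requires_parameters_quantization", False)
        or not hf_quantizer.check_quantized_param()
    )

try:
    old_condition()
except AttributeError as err:
    print("old condition:", err)  # 'NoneType' object has no attribute 'check_quantized_param'

print("new condition:", new_condition())  # True -> tensor is materialized on CPU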
tests/test_modeling_utils.py
@@ -34,6 +34,7 @@ from requests.exceptions import HTTPError
 from transformers import (
     AutoConfig,
     AutoModel,
+    AutoModelForSequenceClassification,
     OwlViTForObjectDetection,
     PretrainedConfig,
     is_torch_available,
@@ -201,6 +202,7 @@ if is_tf_available():

 TINY_T5 = "patrickvonplaten/t5-tiny-random"
 TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
+TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"


 def check_models_equal(model1, model2):
@@ -300,6 +302,15 @@ class ModelUtilsTest(TestCasePlus):
                 BertModel.from_pretrained(TINY_T5)
         self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)

+    @require_accelerate
+    def test_model_from_pretrained_with_none_quantization_config(self):
+        # Needs a device_map to enter the low_cpu_mem branch. We also deliberately load
+        # AutoModelForSequenceClassification to enter the missing-keys branch.
+        model = AutoModelForSequenceClassification.from_pretrained(
+            TINY_MISTRAL, device_map="auto", quantization_config=None
+        )
+        self.assertIsNotNone(model)
+
     def test_model_from_config_torch_dtype(self):
         # test that the model can be instantiated with dtype of user's choice - as long as it's a
         # float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
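The new test exercises the fixed branch end to end. Outside the test suite, the same regression can be reproduced with a snippet like the one below; the checkpoint name comes from the diff above, and device_map="auto" requires the accelerate package, which is why the test is gated with @require_accelerate:

from transformers import AutoModelForSequenceClassification

# Loading with an explicit quantization_config=None plus a device_map enters the
# low_cpu_mem_usage path; the sequence-classification head is missing from the
# causal-LM checkpoint, so its weights start on the meta device and hit the
# patched condition.
model = AutoModelForSequenceClassification.from_pretrained(
    "hf-internal-testing/tiny-random-MistralForCausalLM",
    device_map="auto",
    quantization_config=None,
)
print(model is not None)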