Unverified Commit 7dc52ea7 authored by Dhruv Nair, committed by GitHub

[Quantization] dtype fix for GGUF + fix BnB tests (#11159)

* update

* update

* update

* update
parent 739d6ec7
@@ -282,6 +282,7 @@ class FromOriginalModelMixin:
         if quantization_config is not None:
             hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
             hf_quantizer.validate_environment()
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
         else:
             hf_quantizer = None
 ...
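This is the core fix: when a single-file checkpoint is loaded with a quantization config, the requested `torch_dtype` is now routed through the quantizer before weights are materialized. A minimal sketch of what such a hook can do, assuming a GGUF-style quantizer that carries a `compute_dtype` (the class and field names below are illustrative, not the actual diffusers implementation):

```python
import torch


class ExampleGGUFQuantizer:
    """Illustrative only; not the actual diffusers GGUFQuantizer."""

    def __init__(self, compute_dtype: torch.dtype = torch.bfloat16):
        self.compute_dtype = compute_dtype

    def update_torch_dtype(self, torch_dtype):
        # If the caller passed no torch_dtype, fall back to the quantizer's
        # compute dtype so non-quantized tensors (e.g. biases) end up in a
        # consistent dtype instead of silently defaulting to float32.
        if torch_dtype is None:
            torch_dtype = self.compute_dtype
        return torch_dtype


quantizer = ExampleGGUFQuantizer(compute_dtype=torch.bfloat16)
assert quantizer.update_torch_dtype(None) == torch.bfloat16
assert quantizer.update_torch_dtype(torch.float16) == torch.float16
```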
@@ -90,13 +90,16 @@ class Base8bitTests(unittest.TestCase):
     def get_dummy_inputs(self):
         prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
+            map_location="cpu",
         )
         pooled_prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
+            map_location="cpu",
         )
         latent_model_input = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
+            map_location="cpu",
         )
         input_dict_for_transformer = {
 ...
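The only change in the BnB tests is threading `map_location="cpu"` into `load_pt`, so tensors that were serialized from a CUDA device can be deserialized on CPU-only CI machines. A plausible sketch of such a helper (the real `load_pt` in the test utilities may differ in its details):

```python
from io import BytesIO

import requests
import torch


def load_pt(url: str, map_location: str = "cpu"):
    # Download the serialized tensor and deserialize it onto the requested
    # device. map_location="cpu" lets artifacts that were saved from CUDA
    # load on machines without a GPU instead of raising a device error.
    response = requests.get(url)
    response.raise_for_status()
    return torch.load(BytesIO(response.content), map_location=map_location)
```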
@@ -57,7 +57,7 @@ class GGUFSingleFileTesterMixin:
         if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
             assert module.weight.dtype == torch.uint8
             if module.bias is not None:
-                assert module.bias.dtype == torch.float32
+                assert module.bias.dtype == self.torch_dtype

     def test_gguf_memory_usage(self):
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
 ...
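With the loader change above, GGUF biases follow the configured compute dtype rather than always landing in `float32`, so the assertion now compares against `self.torch_dtype`. A minimal, self-contained sketch of the invariant the updated test enforces (the helper name and the idea of iterating a generic model are illustrative only):

```python
import torch


def assert_gguf_linear_dtypes(model: torch.nn.Module, torch_dtype: torch.dtype):
    for module in model.modules():
        if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
            # GGUF weights stay packed in their quantized uint8 blocks...
            assert module.weight.dtype == torch.uint8
            # ...while biases follow the requested compute dtype instead of
            # the previously hard-coded float32.
            if module.bias is not None:
                assert module.bias.dtype == torch_dtype
```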