Unverified Commit 754fe85c authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] add compile + offload tests for GGUF. (#11740)



* add compile + offload tests for GGUF.

* quality

* add init.

* prop.

* change to flux.

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent cc1f9a2c
......@@ -8,6 +8,7 @@ import torch.nn as nn
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
DiffusionPipeline,
FluxControlPipeline,
FluxPipeline,
FluxTransformer2DModel,
......@@ -32,9 +33,12 @@ from diffusers.utils.testing_utils import (
require_big_accelerator,
require_gguf_version_greater_or_equal,
require_peft_backend,
require_torch_version_greater,
torch_device,
)
from ..test_torch_compile_utils import QuantCompileTests
if is_gguf_available():
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
......@@ -647,3 +651,31 @@ class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests):
torch_dtype = torch.bfloat16
gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
@property
def quantization_config(self):
return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
def _init_pipeline(self, *args, **kwargs):
transformer = FluxTransformer2DModel.from_single_file(
self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
)
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
)
return pipe
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment