Unverified commit 7567adfc authored by Yao Matrix, committed by GitHub

enable 28 GGUF test cases on XPU (#11404)



* enable gguf test cases on XPU
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* make SD35LargeGGUFSingleFileTests::test_pipeline_inference pass
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* make FluxControlLoRAGGUFTests::test_lora_loading pass
Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* polish code
Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* Apply style fixes

---------
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Co-authored-by: root <root@a4bf01945cfe.jf.intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 3da98e7e
@@ -91,18 +91,19 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         )
     weight_on_cpu = False
-    if not module.weight.is_cuda:
+    if module.weight.device.type == "cpu":
         weight_on_cpu = True
 
+    device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
     if is_bnb_4bit_quantized:
         module_weight = dequantize_bnb_weight(
-            module.weight.cuda() if weight_on_cpu else module.weight,
+            module.weight.to(device) if weight_on_cpu else module.weight,
             state=module.weight.quant_state,
             dtype=model.dtype,
         ).data
     elif is_gguf_quantized:
         module_weight = dequantize_gguf_tensor(
-            module.weight.cuda() if weight_on_cpu else module.weight,
+            module.weight.to(device) if weight_on_cpu else module.weight,
         )
         module_weight = module_weight.to(model.dtype)
     else:
...
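Both this hunk and the quantizer change below replace hard-coded `.cuda()` calls with a device resolved at runtime, so the same path serves CUDA and XPU. A minimal standalone sketch of that resolution pattern, assuming PyTorch 2.6+ where the `torch.accelerator` API is available; the helper name below is illustrative, not part of diffusers:

import torch


def resolve_accelerator_device() -> str:
    # Illustrative helper mirroring the pattern in the diff above: prefer the
    # torch.accelerator API (PyTorch >= 2.6), which reports "cuda", "xpu", etc.,
    # and fall back to "cuda" on older builds that lack it.
    if hasattr(torch, "accelerator"):
        accelerator = torch.accelerator.current_accelerator()
        if accelerator is not None:
            return accelerator.type
    return "cuda"


# Example: move a weight onto whichever accelerator is present before dequantizing.
weight = torch.randn(8, 8)
device = resolve_accelerator_device()
if torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available()):
    weight = weight.to(device)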
@@ -150,9 +150,14 @@ class GGUFQuantizer(DiffusersQuantizer):
         is_model_on_cpu = model.device.type == "cpu"
         if is_model_on_cpu:
             logger.info(
-                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to accelerator. After dequantization, will move the model back to CPU again to preserve the previous device."
             )
-            model.to(torch.cuda.current_device())
+            device = (
+                torch.accelerator.current_accelerator()
+                if hasattr(torch, "accelerator")
+                else torch.cuda.current_device()
+            )
+            model.to(device)
 
         model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
         if is_model_on_cpu:
...
@@ -17,11 +17,16 @@ from diffusers import (
 )
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    Expectations,
+    backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    enable_full_determinism,
     is_gguf_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     require_gguf_version_greater_or_equal,
     require_peft_backend,
     torch_device,
@@ -31,9 +36,11 @@ from diffusers.utils.testing_utils import (
 if is_gguf_available():
     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
 
+enable_full_determinism()
+
 
 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 @require_gguf_version_greater_or_equal("0.10.0")
 class GGUFSingleFileTesterMixin:
@@ -68,15 +75,15 @@ class GGUFSingleFileTesterMixin:
         model = self.model_cls.from_single_file(
             self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
         )
-        model.to("cuda")
+        model.to(torch_device)
         assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
         inputs = self.get_dummy_inputs()
 
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         with torch.no_grad():
             model(**inputs)
-        max_memory = torch.cuda.max_memory_allocated()
+        max_memory = backend_max_memory_allocated(torch_device)
         assert (max_memory / 1024**3) < self.expected_memory_use_in_gb
 
     def test_keep_modules_in_fp32(self):
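The memory-tracking hunk above swaps direct torch.cuda.* calls for the backend_* helpers imported from diffusers.utils.testing_utils, which route each call to whichever backend torch_device names. A rough sketch of that dispatch idea, covering only CUDA and XPU and assuming the torch.xpu memory APIs mirror their CUDA counterparts; the real helpers may be implemented differently:

import torch


def empty_cache_for(device: str) -> None:
    # Illustration only: clear the allocator cache on the backend matching the
    # device string; on CPU there is nothing to clear.
    if device.startswith("cuda"):
        torch.cuda.empty_cache()
    elif device.startswith("xpu"):
        torch.xpu.empty_cache()


def max_memory_allocated_for(device: str) -> int:
    # Illustration only: report peak allocated bytes for the matching backend.
    if device.startswith("cuda"):
        return torch.cuda.max_memory_allocated()
    if device.startswith("xpu"):
        return torch.xpu.max_memory_allocated()
    return 0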
...@@ -106,7 +113,8 @@ class GGUFSingleFileTesterMixin: ...@@ -106,7 +113,8 @@ class GGUFSingleFileTesterMixin:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# Tries with a `device` and `dtype` # Tries with a `device` and `dtype`
model.to(device="cuda:0", dtype=torch.float16) device_0 = f"{torch_device}:0"
model.to(device=device_0, dtype=torch.float16)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# Tries with a cast # Tries with a cast
@@ -117,7 +125,7 @@ class GGUFSingleFileTesterMixin:
             model.half()
 
         # This should work
-        model.to("cuda")
+        model.to(torch_device)
 
     def test_dequantize_model(self):
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
@@ -146,11 +154,11 @@ class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_dummy_inputs(self):
         return {
@@ -233,11 +241,11 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_dummy_inputs(self):
         return {
@@ -267,40 +275,79 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
         prompt = "a cat holding a sign that says hello"
         output = pipe(
-            prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+            prompt=prompt,
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
         ).images[0]
 
         output_slice = output[:3, :3, :].flatten()
-        expected_slice = np.array(
-            [
-                0.17578125,
-                0.27539062,
-                0.27734375,
-                0.11914062,
-                0.26953125,
-                0.25390625,
-                0.109375,
-                0.25390625,
-                0.25,
-                0.15039062,
-                0.26171875,
-                0.28515625,
-                0.13671875,
-                0.27734375,
-                0.28515625,
-                0.12109375,
-                0.26757812,
-                0.265625,
-                0.16210938,
-                0.29882812,
-                0.28515625,
-                0.15625,
-                0.30664062,
-                0.27734375,
-                0.14648438,
-                0.29296875,
-                0.26953125,
-            ]
-        )
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array(
+                    [
+                        0.19335938,
+                        0.3125,
+                        0.3203125,
+                        0.1328125,
+                        0.3046875,
+                        0.296875,
+                        0.11914062,
+                        0.2890625,
+                        0.2890625,
+                        0.16796875,
+                        0.30273438,
+                        0.33203125,
+                        0.14648438,
+                        0.31640625,
+                        0.33007812,
+                        0.12890625,
+                        0.3046875,
+                        0.30859375,
+                        0.17773438,
+                        0.33789062,
+                        0.33203125,
+                        0.16796875,
+                        0.34570312,
+                        0.32421875,
+                        0.15625,
+                        0.33203125,
+                        0.31445312,
+                    ]
+                ),
+                ("cuda", 7): np.array(
+                    [
+                        0.17578125,
+                        0.27539062,
+                        0.27734375,
+                        0.11914062,
+                        0.26953125,
+                        0.25390625,
+                        0.109375,
+                        0.25390625,
+                        0.25,
+                        0.15039062,
+                        0.26171875,
+                        0.28515625,
+                        0.13671875,
+                        0.27734375,
+                        0.28515625,
+                        0.12109375,
+                        0.26757812,
+                        0.265625,
+                        0.16210938,
+                        0.29882812,
+                        0.28515625,
+                        0.15625,
+                        0.30664062,
+                        0.27734375,
+                        0.14648438,
+                        0.29296875,
+                        0.26953125,
+                    ]
+                ),
+            }
+        )
+        expected_slice = expected_slices.get_expectation()
 
         max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
         assert max_diff < 1e-4
@@ -313,11 +360,11 @@ class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_dummy_inputs(self):
         return {
@@ -393,11 +440,11 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_dummy_inputs(self):
         return {
...@@ -463,7 +510,7 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): ...@@ -463,7 +510,7 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
@require_peft_backend @require_peft_backend
@nightly @nightly
@require_big_gpu_with_torch_cuda @require_big_accelerator
@require_accelerate @require_accelerate
@require_gguf_version_greater_or_equal("0.10.0") @require_gguf_version_greater_or_equal("0.10.0")
class FluxControlLoRAGGUFTests(unittest.TestCase): class FluxControlLoRAGGUFTests(unittest.TestCase):
@@ -478,7 +525,7 @@ class FluxControlLoRAGGUFTests(unittest.TestCase):
             "black-forest-labs/FLUX.1-dev",
             transformer=transformer,
             torch_dtype=torch.bfloat16,
-        ).to("cuda")
+        ).to(torch_device)
 
         pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
         prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
...
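The SD3.5 Large slice comparison further up keys its reference values on the backend via the Expectations helper, so CUDA and XPU can carry different golden slices. A simplified analogue of that lookup, assuming keys shaped as (device type, hardware generation) as in the test; this is a sketch for illustration, not the diffusers implementation:

from typing import Dict, Tuple

import numpy as np


def pick_expectation(expectations: Dict[Tuple[str, int], np.ndarray], device: str) -> np.ndarray:
    # Simplified analogue: return the entry whose device type matches, ignoring
    # the generation component that the real helper also takes into account.
    for (device_type, _generation), value in expectations.items():
        if device_type == device:
            return value
    raise KeyError(f"no expectation registered for device '{device}'")


# Example usage with the same shape of table as the test above (values truncated).
expected = {
    ("xpu", 3): np.array([0.19335938, 0.3125, 0.3203125]),
    ("cuda", 7): np.array([0.17578125, 0.27539062, 0.27734375]),
}
reference = pick_expectation(expected, "cuda")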