"llama/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "369de832cdca7680c8f50ba196d39172a895fcad"
Unverified Commit 9ce89e2e authored by Yao Matrix, committed by GitHub

enable group_offload cases and quanto cases on XPU (#11405)

* enable group_offload cases and quanto cases on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* use backend APIs

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
parent aa5f5d41
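The change swaps CUDA-only calls (torch.cuda.reset_peak_memory_stats / torch.cuda.empty_cache, hard-coded "cuda" device strings, and the require_torch_gpu / require_big_gpu_with_torch_cuda decorators) for the device-agnostic helpers in diffusers.utils.testing_utils, so the same tests run on CUDA and XPU. Below is a minimal sketch of that pattern, using only the helpers that the hunks below import; the test class itself is illustrative and not part of this commit.

import gc
import unittest

import torch

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    require_torch_accelerator,
    torch_device,
)


@require_torch_accelerator  # skips unless an accelerator (CUDA, XPU, ...) is available
class DeviceAgnosticSmokeTest(unittest.TestCase):
    def setUp(self):
        # The backend_* helpers dispatch to torch.cuda / torch.xpu based on the
        # device string, replacing direct torch.cuda.* calls.
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def test_tensor_lands_on_accelerator(self):
        # torch_device resolves to "cuda", "xpu", etc. at import time,
        # replacing hard-coded .to("cuda") calls.
        x = torch.ones(2, 2, device=torch_device)
        self.assertEqual(x.device.type, torch.device(torch_device).type)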
@@ -53,7 +53,7 @@ from diffusers.utils.testing_utils import (
     require_accelerator,
     require_hf_hub_version_greater,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     torch_device,
@@ -2212,7 +2212,7 @@ class PipelineTesterMixin:
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)[0]

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading_inference(self):
         if not self.test_group_offloading:
             return
@@ -6,10 +6,13 @@ from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
 from diffusers.models.attention_processor import Attention
 from diffusers.utils import is_optimum_quanto_available, is_torch_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_reset_peak_memory_stats,
+    enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     require_torch_cuda_compatibility,
     torch_device,
 )
@@ -23,9 +26,11 @@ if is_torch_available():

     from ..utils import LoRALayer, get_memory_consumption_stat


+enable_full_determinism()
+
 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 class QuantoBaseTesterMixin:
     model_id = None
@@ -39,13 +44,13 @@ class QuantoBaseTesterMixin:
     _test_torch_compile = False

     def setUp(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         gc.collect()

     def tearDown(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         gc.collect()

     def get_dummy_init_kwargs(self):
@@ -89,7 +94,7 @@ class QuantoBaseTesterMixin:
         self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
-        model.to("cuda")
+        model.to(torch_device)

         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -107,7 +112,7 @@ class QuantoBaseTesterMixin:
         init_kwargs.update({"quantization_config": quantization_config})

         model = self.model_cls.from_pretrained(**init_kwargs)
-        model.to("cuda")
+        model.to(torch_device)

         for name, module in model.named_modules():
             if name in self.modules_to_not_convert:
@@ -122,7 +127,8 @@ class QuantoBaseTesterMixin:

         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            model.to(device="cuda:0", dtype=torch.float16)
+            device_0 = f"{torch_device}:0"
+            model.to(device=device_0, dtype=torch.float16)

         with self.assertRaises(ValueError):
             # Tries with a cast
@@ -133,7 +139,7 @@ class QuantoBaseTesterMixin:
             model.half()

         # This should work
-        model.to("cuda")
+        model.to(torch_device)

     def test_serialization(self):
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())