Unverified Commit 33e636ce authored by Yao Matrix, committed by GitHub

enable torchao test cases on XPU and switch to device agnostic APIs for test cases (#11654)



* enable torchao cases on XPU
Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* device agnostic APIs
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* more
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* enable test_torch_compile_recompilation_and_graph_break on XPU
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* resolve comments
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
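The change is mechanical but worth spelling out: hard-coded `torch.cuda.*` calls and CUDA-only decorators are replaced by the device-agnostic helpers in `diffusers.utils.testing_utils`, which dispatch to CUDA, XPU, or whatever `torch_device` resolves to. A minimal sketch of the resulting setUp/tearDown pattern, assuming only the helper names that appear in this diff (the test class itself is illustrative):

import gc
import unittest

from diffusers.utils.testing_utils import (
    backend_empty_cache,  # dispatches to the active backend's empty_cache()
    require_torch_accelerator,
    torch_device,  # "cuda", "xpu", ... depending on the test environment
)


@require_torch_accelerator
class ExampleIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        # Device-agnostic replacement for torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)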
parent e27142ac
@@ -233,7 +233,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):

 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class StableDiffusion3PipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3Pipeline
     repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
...
@@ -168,7 +168,7 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unittest.TestCase):

 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3Img2ImgPipeline
     repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
...
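Both SD3 slow-test classes above make the same swap: the CUDA-specific `@pytest.mark.big_gpu_with_torch_cuda` marker becomes the device-neutral `@pytest.mark.big_accelerator`, stacked with the decorators that were already there. A sketch of the resulting decorator stack, with an illustrative class name:

import unittest

import pytest

from diffusers.utils.testing_utils import require_big_accelerator, slow


@slow
@require_big_accelerator
@pytest.mark.big_accelerator  # device-neutral marker for big-memory accelerator runners
class ExampleBigAcceleratorTests(unittest.TestCase):
    pass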
@@ -35,6 +35,7 @@ from diffusers import (
     UniPCMultistepScheduler,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     load_image,
     numpy_cosine_similarity_distance,
@@ -940,12 +941,12 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_stable_diffusion_lcm(self):
         torch.manual_seed(0)
...
@@ -39,6 +39,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_image,
@@ -670,12 +671,12 @@ class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_stable_diffusion_xl_img2img_playground(self):
         torch.manual_seed(0)
...
@@ -1218,13 +1218,13 @@ class PipelineFastTests(unittest.TestCase):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def dummy_image(self):
         batch_size = 1
...
@@ -21,9 +21,11 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     require_torch_accelerator,
     slow,
+    torch_device,
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -144,12 +146,12 @@ class WanPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     @unittest.skip("TODO: test needs to be implemented")
     def test_Wanx(self):
...
@@ -30,13 +30,15 @@ from diffusers import (
 )
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_synchronize,
     enable_full_determinism,
     is_torch_available,
     is_torchao_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torchao_version_greater_or_equal,
     slow,
     torch_device,
@@ -61,7 +63,7 @@ if is_torchao_available():

 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
@@ -79,7 +81,7 @@ class TorchAoConfigTest(unittest.TestCase):
         Test kwargs validations in TorchAoConfig
         """
         _ = TorchAoConfig("int4_weight_only")
-        with self.assertRaisesRegex(ValueError, "is not supported yet"):
+        with self.assertRaisesRegex(ValueError, "is not supported"):
            _ = TorchAoConfig("uint8")

         with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
@@ -119,12 +121,12 @@ class TorchAoConfigTest(unittest.TestCase):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(
         self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
@@ -269,6 +271,7 @@ class TorchAoTest(unittest.TestCase):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=f"{torch_device}:0",
         )
         weight = quantized_model.transformer_blocks[0].ff.net[2].weight
@@ -338,7 +341,7 @@ class TorchAoTest(unittest.TestCase):
             output = quantized_model(**inputs)[0]
             output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

         with tempfile.TemporaryDirectory() as offload_folder:
             quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
@@ -359,7 +362,7 @@ class TorchAoTest(unittest.TestCase):
             output = quantized_model(**inputs)[0]
             output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
@@ -518,14 +521,14 @@ class TorchAoTest(unittest.TestCase):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
         quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
@@ -593,17 +596,17 @@ class TorchAoSerializationTest(unittest.TestCase):
         )
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

-    def test_int_a8w8_cuda(self):
+    def test_int_a8w8_accelerator(self):
         quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
         expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

-    def test_int_a16w8_cuda(self):
+    def test_int_a16w8_accelerator(self):
         quant_method, quant_method_kwargs = "int8_weight_only", {}
         expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@@ -624,14 +627,14 @@ class TorchAoSerializationTest(unittest.TestCase):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(self, quantization_config: TorchAoConfig):
         # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
@@ -713,8 +716,8 @@ class SlowTorchAoTests(unittest.TestCase):
             quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
             self._test_quant_type(quantization_config, expected_slice)
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            backend_empty_cache(torch_device)
+            backend_synchronize(torch_device)

     def test_serialization_int8wo(self):
         quantization_config = TorchAoConfig("int8wo")
@@ -733,8 +736,8 @@ class SlowTorchAoTests(unittest.TestCase):
         pipe.remove_all_hooks()
         del pipe.transformer
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+        backend_empty_cache(torch_device)
+        backend_synchronize(torch_device)

         transformer = FluxTransformer2DModel.from_pretrained(
             tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
         )
@@ -783,14 +786,14 @@ class SlowTorchAoTests(unittest.TestCase):

 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoPreserializedModelTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_inputs(self, device: torch.device, seed: int = 0):
         if str(device).startswith("mps"):
...
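In the slow torchao tests, cleanup now both empties the cache and synchronizes the device through the same dispatch layer. A hedged sketch of that cleanup sequence as a standalone helper (the function name is ours, not the diff's):

import gc

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_synchronize,
    torch_device,
)


def cleanup_accelerator_memory():
    # Drop Python references first so the backend allocator can actually free memory
    gc.collect()
    # Device-agnostic equivalents of torch.cuda.empty_cache() / torch.cuda.synchronize()
    backend_empty_cache(torch_device)
    backend_synchronize(torch_device)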
@@ -16,8 +16,6 @@
 import gc
 import unittest

-import torch
-
 from diffusers import (
     Lumina2Transformer2DModel,
 )
@@ -66,9 +64,9 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths:
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
             model = self.model_class.from_single_file(ckpt_path)

             del model
             gc.collect()
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
...
@@ -16,8 +16,6 @@
 import gc
 import unittest

-import torch
-
 from diffusers import (
     FluxTransformer2DModel,
 )
@@ -64,9 +62,9 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths:
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
             model = self.model_class.from_single_file(ckpt_path)

             del model
             gc.collect()
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
...
 import gc
 import unittest

-import torch
-
 from diffusers import (
     SanaTransformer2DModel,
 )
@@ -53,9 +51,9 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths:
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
             model = self.model_class.from_single_file(ckpt_path)

             del model
             gc.collect()
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
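The three single-file loader tests (Lumina2, Flux, Sana) now share an identical, CUDA-free loop; with `import torch` gone, `backend_empty_cache` is the only device call left. A sketch of the shared pattern, with placeholder names for the model class and checkpoint list:

import gc

from diffusers.utils.testing_utils import backend_empty_cache, torch_device


def check_alternate_checkpoints(model_class, ckpt_paths):
    for ckpt_path in ckpt_paths:
        # Start each load with a clean allocator on whichever backend is active
        backend_empty_cache(torch_device)
        model = model_class.from_single_file(ckpt_path)

        del model
        gc.collect()
        backend_empty_cache(torch_device)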