Unverified Commit a7e9f85e authored by Yao Matrix, committed by GitHub

enable test_layerwise_casting_memory cases on XPU (#11406)



* enable test_layerwise_casting_memory cases on XPU
Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style
Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
parent 9ce89e2e
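
In substance, the change swaps the test's hard-coded `torch.cuda.*` memory bookkeeping for the device-agnostic `backend_*` helpers in `diffusers.utils.testing_utils`, so the same test body runs on CUDA and XPU. The enabling fact is that recent PyTorch builds expose the same bookkeeping surface under `torch.xpu.*`; availability varies by build, which is why the diff below guards the lookup with `getattr`. A hedged sketch of that correspondence:

```python
import torch

# CUDA -> XPU counterparts needed by the memory test. Attribute availability
# depends on the PyTorch build, so look the XPU side up defensively instead of
# assuming it exists (the diff below does the same for torch.xpu.synchronize).
_xpu = getattr(torch, "xpu", None)
XPU_COUNTERPARTS = {
    "synchronize": getattr(_xpu, "synchronize", None),
    "empty_cache": getattr(_xpu, "empty_cache", None),
    "reset_peak_memory_stats": getattr(_xpu, "reset_peak_memory_stats", None),
    "max_memory_allocated": getattr(_xpu, "max_memory_allocated", None),
}
```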
@@ -1186,6 +1186,13 @@ if is_torch_available():
         "mps": 0,
         "default": 0,
     }
+    BACKEND_SYNCHRONIZE = {
+        "cuda": torch.cuda.synchronize,
+        "xpu": getattr(torch.xpu, "synchronize", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
 
 
 # This dispatches a defined function according to the accelerator from the function definitions.
@@ -1208,6 +1215,10 @@ def backend_manual_seed(device: str, seed: int):
     return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
 
 
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
 def backend_empty_cache(device: str):
     return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
......
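
The new `BACKEND_SYNCHRONIZE` table and `backend_synchronize` wrapper follow the same dispatch pattern as the existing `backend_empty_cache` and friends: look up the active device string in a per-backend dict and call the mapped function, treating `None` entries as no-ops. A minimal sketch of that idea; the `_dispatch` helper below is illustrative, not the actual `_device_agnostic_dispatch` implementation:

```python
from typing import Any, Callable, Dict, Optional

import torch


def _dispatch(device: str, table: Dict[str, Optional[Callable]], *args, **kwargs) -> Any:
    # Illustrative stand-in for diffusers' _device_agnostic_dispatch: unknown
    # devices fall back to "default", and None entries mean "nothing to do".
    fn = table.get(device, table["default"])
    if fn is None:
        return None
    return fn(*args, **kwargs)


BACKEND_SYNCHRONIZE = {
    "cuda": torch.cuda.synchronize,
    "xpu": getattr(getattr(torch, "xpu", None), "synchronize", None),
    "cpu": None,
    "mps": None,
    "default": None,
}

# _dispatch("xpu", BACKEND_SYNCHRONIZE) calls torch.xpu.synchronize() when it
# exists; on "cpu" or "mps" the lookup yields None and the call is a no-op.
```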
@@ -59,6 +59,9 @@ from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    backend_synchronize,
     floats_tensor,
     get_python_version,
     is_torch_compile,
@@ -68,7 +71,6 @@ from diffusers.utils.testing_utils import (
     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     run_test_in_subprocess,
     slow,
@@ -341,7 +343,7 @@ class ModelUtilsTest(unittest.TestCase):
         assert model.config.in_channels == 9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_keep_modules_in_fp32(self):
         r"""
         A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
@@ -1480,16 +1482,16 @@ class ModelTesterMixin:
         test_layerwise_casting(torch.float8_e5m2, torch.float32)
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0
 
         def reset_memory_stats():
             gc.collect()
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_synchronize(torch_device)
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
 
         def get_memory_usage(storage_dtype, compute_dtype):
             torch.manual_seed(0)
@@ -1502,7 +1504,7 @@ class ModelTesterMixin:
             reset_memory_stats()
             model(**inputs_dict)
             model_memory_footprint = model.get_memory_footprint()
-            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+            peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
 
             return model_memory_footprint, peak_inference_memory_allocated_mb
@@ -1512,7 +1514,7 @@ class ModelTesterMixin:
             torch.float8_e4m3fn, torch.bfloat16
         )
 
-        compute_capability = get_torch_cuda_device_capability()
+        compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
         # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
@@ -1527,7 +1529,7 @@ class ModelTesterMixin:
         )
 
     @parameterized.expand([False, True])
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading(self, record_stream):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         torch.manual_seed(0)
......
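
Put together, the measurement loop in `test_layerwise_casting_memory` now looks roughly like the helper below on any accelerator that diffusers' test utilities know about; the `model` / `inputs_dict` arguments and the `torch.no_grad()` guard are illustrative additions, not part of the test itself:

```python
import gc

import torch

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_peak_memory_stats,
    backend_synchronize,
    torch_device,
)


def peak_inference_memory_mb(model, inputs_dict):
    """Run one forward pass and return peak allocated memory (MB) on the active device."""
    gc.collect()
    backend_synchronize(torch_device)            # wait for pending kernels (cuda/xpu)
    backend_empty_cache(torch_device)            # release cached allocator blocks
    backend_reset_peak_memory_stats(torch_device)

    with torch.no_grad():                        # inference-only, mirroring the test's forward pass
        model(**inputs_dict)

    return backend_max_memory_allocated(torch_device) / 1024**2
```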