Unverified Commit 06beecaf authored by Yao Matrix, committed by GitHub

make autoencoders, controlnet_flux and wan_transformer3d_single_file pass on xpu (#11461)



* make autoencoders, controlnet_flux and wan_transformer3d_single_file pass on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* Apply style fixes

---------
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
parent daf0a239
@@ -13,7 +13,7 @@
 # limitations under the License.
 from contextlib import contextmanager, nullcontext
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union

 import torch
@@ -55,7 +55,7 @@ class ModuleGroup:
         parameters: Optional[List[torch.nn.Parameter]] = None,
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
-        stream: Optional[torch.cuda.Stream] = None,
+        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
@@ -115,8 +115,13 @@ class ModuleGroup:
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
-        context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
-        current_stream = torch.cuda.current_stream() if self.record_stream else None
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
+        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
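The `torch_accelerator_module` dispatch above is what lets the same stream code run on CUDA and XPU: `torch.accelerator.current_accelerator().type` yields the backend name (`"cuda"`, `"xpu"`, ...) and `getattr(torch, ...)` returns the matching submodule. A minimal sketch of that behaviour, assuming PyTorch >= 2.6 for `torch.accelerator` (the helper name below is illustrative, not part of diffusers):

```python
import torch


def get_torch_accelerator_module():
    # Prefer the generic torch.accelerator API (PyTorch >= 2.6); fall back to torch.cuda otherwise.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return getattr(torch, torch.accelerator.current_accelerator().type)
    return torch.cuda


# Only exercise the stream APIs when an accelerator is actually present.
if torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available()):
    accel = get_torch_accelerator_module()  # torch.cuda on NVIDIA, torch.xpu on Intel
    stream = accel.Stream()                 # backend-specific stream object
    with accel.stream(stream):              # same context-manager API on both backends
        pass
    accel.current_stream().synchronize()    # wait for work queued on the default stream
```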
@@ -162,9 +167,15 @@ class ModuleGroup:
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
         if self.stream is not None:
             if not self.record_stream:
-                torch.cuda.current_stream().synchronize()
+                torch_accelerator_module.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -429,8 +440,10 @@ def apply_group_offloading(
     if use_stream:
         if torch.cuda.is_available():
             stream = torch.cuda.Stream()
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            stream = torch.Stream()
         else:
-            raise ValueError("Using streams for data transfer requires a CUDA device.")
+            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
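With the extra `elif` branch, stream-based group offloading no longer raises on Intel GPUs. A hedged usage sketch (parameter names follow the `apply_group_offloading` signature at the time of this commit; verify against the installed diffusers version before relying on them):

```python
import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

if hasattr(torch, "xpu") and torch.xpu.is_available():
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    apply_group_offloading(
        vae,
        onload_device=torch.device("xpu"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=True,  # now creates a torch.Stream() on XPU instead of raising ValueError
    )
```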
@@ -468,7 +481,7 @@ def _apply_group_offloading_block_level(
     offload_device: torch.device,
     onload_device: torch.device,
     non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -486,7 +499,7 @@ def _apply_group_offloading_block_level(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
@@ -572,7 +585,7 @@ def _apply_group_offloading_leaf_level(
     offload_device: torch.device,
     onload_device: torch.device,
     non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -592,7 +605,7 @@ def _apply_group_offloading_leaf_level(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
...
@@ -22,6 +22,7 @@ from parameterized import parameterized
 from diffusers import AsymmetricAutoencoderKL
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    Expectations,
     backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
@@ -134,18 +135,32 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
             # fmt: off
             [
                 33,
-                [-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
-                [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
+                Expectations(
+                    {
+                        ("xpu", 3): torch.tensor([-0.0343, 0.2873, 0.1680, -0.0140, -0.3459, 0.3522, -0.1336, 0.1075]),
+                        ("cuda", 7): torch.tensor([-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205]),
+                        ("mps", None): torch.tensor(
+                            [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]
+                        ),
+                    }
+                ),
             ],
             [
                 47,
-                [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
-                [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
+                Expectations(
+                    {
+                        ("xpu", 3): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
+                        ("cuda", 7): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
+                        ("mps", None): torch.tensor(
+                            [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]
+                        ),
+                    }
+                ),
             ],
             # fmt: on
         ]
     )
-    def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
+    def test_stable_diffusion(self, seed, expected_slices):
         model = self.get_sd_vae_model()
         image = self.get_sd_image(seed)
         generator = self.get_generator(seed)
@@ -156,9 +171,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
         assert sample.shape == image.shape

         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
-        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
+        expected_slice = expected_slices.get_expectation()
+        assert torch_all_close(output_slice, expected_slice, atol=5e-3)

     @parameterized.expand(
         [
...
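The test now stores one expected slice per backend and lets `Expectations.get_expectation()` pick the value for the device the suite runs on. A simplified stand-in that illustrates the lookup idea (the real class in `diffusers.utils.testing_utils` also takes device properties such as the CUDA compute capability into account, so treat this as an approximation):

```python
import torch


class SimpleExpectations:
    """Illustrative approximation of Expectations: keys are (device_type, version) tuples."""

    def __init__(self, data):
        self.data = data

    def get_expectation(self, device_type):
        # Return the first value registered for the current device type.
        for (dev, _version), value in self.data.items():
            if dev == device_type:
                return value
        raise KeyError(f"no expectation registered for {device_type!r}")


expected = SimpleExpectations(
    {
        ("xpu", 3): torch.tensor([-0.0343, 0.2873]),
        ("cuda", 7): torch.tensor([-0.0336, 0.3011]),
        ("mps", None): torch.tensor([-0.1603, 0.9878]),
    }
)
print(expected.get_expectation("cuda"))  # tensor([-0.0336,  0.3011])
```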
@@ -35,7 +35,7 @@ from diffusers.utils.testing_utils import (
     enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     torch_device,
 )
 from diffusers.utils.torch_utils import randn_tensor
@@ -210,8 +210,8 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
 @nightly
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
+@pytest.mark.big_accelerator
 class FluxControlNetPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxControlNetPipeline
...
@@ -24,7 +24,7 @@ from diffusers import (
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     require_torch_accelerator,
     torch_device,
 )
@@ -62,7 +62,7 @@ class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
 )


-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_torch_accelerator
 class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
     model_class = WanTransformer3DModel
...
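`require_big_accelerator` generalises the old `require_big_gpu_with_torch_cuda` marker so that memory-hungry slow tests can also be collected on Intel XPU machines. Roughly, such a decorator skips the test unless a large-memory CUDA or XPU device is present; the sketch below is an approximation of that behaviour, not the diffusers implementation (the actual threshold and checks live in `diffusers.utils.testing_utils`):

```python
import unittest

import torch


def require_big_accelerator_sketch(test_case, min_memory_gb=24):
    # Skip unless a CUDA or XPU device with enough memory is available (illustrative only).
    has_cuda = torch.cuda.is_available()
    has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
    if not (has_cuda or has_xpu):
        return unittest.skip("test requires a CUDA or XPU device")(test_case)
    props = torch.cuda.get_device_properties(0) if has_cuda else torch.xpu.get_device_properties(0)
    if props.total_memory < min_memory_gb * 1024**3:
        return unittest.skip(f"test requires >= {min_memory_gb} GB of device memory")(test_case)
    return test_case
```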