Unverified Commit 7b10e4ae authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] device placement for non-denoiser components in group offloading LoRA tests (#12103)

up
parent 3c0531bc
...@@ -2400,7 +2400,6 @@ class PeftLoraLoaderMixinTests: ...@@ -2400,7 +2400,6 @@ class PeftLoraLoaderMixinTests:
components, _, _ = self.get_dummy_components(self.scheduler_classes[0]) components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
...@@ -2416,6 +2415,10 @@ class PeftLoraLoaderMixinTests: ...@@ -2416,6 +2415,10 @@ class PeftLoraLoaderMixinTests:
num_blocks_per_group=1, num_blocks_per_group=1,
use_stream=use_stream, use_stream=use_stream,
) )
# Place other model-level components on `torch_device`.
for _, component in pipe.components.items():
if isinstance(component, torch.nn.Module):
component.to(torch_device)
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_1 is not None) self.assertTrue(group_offload_hook_1 is not None)
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment