Unverified Commit 2340ed62 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Test] Reduce CPU memory (#4897)

* [Test] Reduce CPU memory

* [Test] Reduce CPU memory
parent cfdfcf20
...@@ -107,7 +107,7 @@ def state_dicts_almost_equal(sd1, sd2): ...@@ -107,7 +107,7 @@ def state_dicts_almost_equal(sd1, sd2):
models_are_equal = True models_are_equal = True
for ten1, ten2 in zip(sd1.values(), sd2.values()): for ten1, ten2 in zip(sd1.values(), sd2.values()):
if (ten1 - ten2).abs().sum() > 1e-3: if (ten1 - ten2).abs().max() > 1e-3:
models_are_equal = False models_are_equal = False
return models_are_equal return models_are_equal
...@@ -1432,23 +1432,21 @@ class LoraIntegrationTests(unittest.TestCase): ...@@ -1432,23 +1432,21 @@ class LoraIntegrationTests(unittest.TestCase):
self.assertTrue(np.allclose(images, expected, atol=1e-3)) self.assertTrue(np.allclose(images, expected, atol=1e-3))
def test_sdxl_1_0_fuse_unfuse_all(self): def test_sdxl_1_0_fuse_unfuse_all(self):
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
unet_sd = copy.deepcopy(pipe.unet.state_dict()) unet_sd = copy.deepcopy(pipe.unet.state_dict())
pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors") pipe.load_lora_weights(
"davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors", torch_dtype=torch.float16
)
pipe.fuse_lora() pipe.fuse_lora()
pipe.unload_lora_weights() pipe.unload_lora_weights()
pipe.unfuse_lora() pipe.unfuse_lora()
new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict())
new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict())
new_unet_sd = copy.deepcopy(pipe.unet.state_dict()) assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict())
assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)
assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)
assert state_dicts_almost_equal(unet_sd, new_unet_sd)
def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self): def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
generator = torch.Generator().manual_seed(0) generator = torch.Generator().manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment