Unverified Commit d66d554d authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Add tearDown method to LoRA tests. (#6660)

* update

* update
parent c7df846d
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy import copy
import gc
import os import os
import random import random
import tempfile import tempfile
...@@ -1662,6 +1663,11 @@ class UNet3DConditionLoRAModelTests(unittest.TestCase): ...@@ -1662,6 +1663,11 @@ class UNet3DConditionLoRAModelTests(unittest.TestCase):
@deprecate_after_peft_backend @deprecate_after_peft_backend
@require_torch_gpu @require_torch_gpu
class LoraIntegrationTests(unittest.TestCase): class LoraIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_dreambooth_old_format(self): def test_dreambooth_old_format(self):
generator = torch.Generator("cpu").manual_seed(0) generator = torch.Generator("cpu").manual_seed(0)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy import copy
import gc
import importlib import importlib
import os import os
import tempfile import tempfile
...@@ -1205,6 +1206,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -1205,6 +1206,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"latent_channels": 4, "latent_channels": 4,
} }
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@slow @slow
@require_torch_gpu @require_torch_gpu
def test_integration_move_lora_cpu(self): def test_integration_move_lora_cpu(self):
...@@ -1434,6 +1440,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -1434,6 +1440,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"sample_size": 128, "sample_size": 128,
} }
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@slow @slow
@require_torch_gpu @require_torch_gpu
...@@ -1468,11 +1479,9 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -1468,11 +1479,9 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
} }
def tearDown(self): def tearDown(self):
import gc super().tearDown()
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect()
def test_dreambooth_old_format(self): def test_dreambooth_old_format(self):
generator = torch.Generator("cpu").manual_seed(0) generator = torch.Generator("cpu").manual_seed(0)
...@@ -1757,11 +1766,9 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -1757,11 +1766,9 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
} }
def tearDown(self): def tearDown(self):
import gc super().tearDown()
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect()
def test_sdxl_0_9_lora_one(self): def test_sdxl_0_9_lora_one(self):
generator = torch.Generator().manual_seed(0) generator = torch.Generator().manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment