"cuda/vscode:/vscode.git/clone" did not exist on "2a77b7723f267c1663b53ebacf62da0707f59cd1"
Unverified Commit d66d554d authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Add tearDown method to LoRA tests. (#6660)

* update

* update
parent c7df846d
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import os
import random
import tempfile
......@@ -1662,6 +1663,11 @@ class UNet3DConditionLoRAModelTests(unittest.TestCase):
@deprecate_after_peft_backend
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_dreambooth_old_format(self):
generator = torch.Generator("cpu").manual_seed(0)
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import importlib
import os
import tempfile
......@@ -1205,6 +1206,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"latent_channels": 4,
}
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@slow
@require_torch_gpu
def test_integration_move_lora_cpu(self):
......@@ -1434,6 +1440,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"sample_size": 128,
}
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@slow
@require_torch_gpu
......@@ -1468,11 +1479,9 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
}
def tearDown(self):
import gc
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_dreambooth_old_format(self):
generator = torch.Generator("cpu").manual_seed(0)
......@@ -1757,11 +1766,9 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
}
def tearDown(self):
import gc
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_sdxl_0_9_lora_one(self):
generator = torch.Generator().manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment