Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
d66d554d
Unverified
Commit
d66d554d
authored
Jan 22, 2024
by
Dhruv Nair
Committed by
GitHub
Jan 22, 2024
Browse files
Add tearDown method to LoRA tests. (#6660)
* update * update
parent
c7df846d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
6 deletions
+19
-6
tests/lora/test_lora_layers_old_backend.py
tests/lora/test_lora_layers_old_backend.py
+6
-0
tests/lora/test_lora_layers_peft.py
tests/lora/test_lora_layers_peft.py
+13
-6
No files found.
tests/lora/test_lora_layers_old_backend.py
View file @
d66d554d
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
copy
import
copy
import
gc
import
os
import
os
import
random
import
random
import
tempfile
import
tempfile
...
@@ -1662,6 +1663,11 @@ class UNet3DConditionLoRAModelTests(unittest.TestCase):
...
@@ -1662,6 +1663,11 @@ class UNet3DConditionLoRAModelTests(unittest.TestCase):
@
deprecate_after_peft_backend
@
deprecate_after_peft_backend
@
require_torch_gpu
@
require_torch_gpu
class
LoraIntegrationTests
(
unittest
.
TestCase
):
class
LoraIntegrationTests
(
unittest
.
TestCase
):
def
tearDown
(
self
):
super
().
tearDown
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
def
test_dreambooth_old_format
(
self
):
def
test_dreambooth_old_format
(
self
):
generator
=
torch
.
Generator
(
"cpu"
).
manual_seed
(
0
)
generator
=
torch
.
Generator
(
"cpu"
).
manual_seed
(
0
)
...
...
tests/lora/test_lora_layers_peft.py
View file @
d66d554d
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
copy
import
copy
import
gc
import
importlib
import
importlib
import
os
import
os
import
tempfile
import
tempfile
...
@@ -1205,6 +1206,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
...
@@ -1205,6 +1206,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"latent_channels"
:
4
,
"latent_channels"
:
4
,
}
}
def
tearDown
(
self
):
super
().
tearDown
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
@
slow
@
slow
@
require_torch_gpu
@
require_torch_gpu
def
test_integration_move_lora_cpu
(
self
):
def
test_integration_move_lora_cpu
(
self
):
...
@@ -1434,6 +1440,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
...
@@ -1434,6 +1440,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"sample_size"
:
128
,
"sample_size"
:
128
,
}
}
def
tearDown
(
self
):
super
().
tearDown
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
@
slow
@
slow
@
require_torch_gpu
@
require_torch_gpu
...
@@ -1468,11 +1479,9 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
...
@@ -1468,11 +1479,9 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
}
}
def
tearDown
(
self
):
def
tearDown
(
self
):
import
gc
super
().
tearDown
()
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
def
test_dreambooth_old_format
(
self
):
def
test_dreambooth_old_format
(
self
):
generator
=
torch
.
Generator
(
"cpu"
).
manual_seed
(
0
)
generator
=
torch
.
Generator
(
"cpu"
).
manual_seed
(
0
)
...
@@ -1757,11 +1766,9 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
...
@@ -1757,11 +1766,9 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
}
}
def
tearDown
(
self
):
def
tearDown
(
self
):
import
gc
super
().
tearDown
()
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
def
test_sdxl_0_9_lora_one
(
self
):
def
test_sdxl_0_9_lora_one
(
self
):
generator
=
torch
.
Generator
().
manual_seed
(
0
)
generator
=
torch
.
Generator
().
manual_seed
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment