Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
764d7ed4
Unverified
Commit
764d7ed4
authored
Feb 26, 2025
by
Sayak Paul
Committed by
GitHub
Feb 26, 2025
Browse files
[Tests] fix: lumina2 lora fuse_nan test (#10911)
fix: lumina2 lora fuse_nan test
parent
3fab6624
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
2 deletions
+42
-2
tests/lora/test_lora_layers_lumina2.py
tests/lora/test_lora_layers_lumina2.py
+42
-2
No files found.
tests/lora/test_lora_layers_lumina2.py
View file @
764d7ed4
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
import
sys
import
sys
import
unittest
import
unittest
import
numpy
as
np
import
pytest
import
torch
import
torch
from
transformers
import
AutoTokenizer
,
GemmaForCausalLM
from
transformers
import
AutoTokenizer
,
GemmaForCausalLM
...
@@ -24,12 +26,12 @@ from diffusers import (
...
@@ -24,12 +26,12 @@ from diffusers import (
Lumina2Text2ImgPipeline
,
Lumina2Text2ImgPipeline
,
Lumina2Transformer2DModel
,
Lumina2Transformer2DModel
,
)
)
from
diffusers.utils.testing_utils
import
floats_tensor
,
require_peft_backend
from
diffusers.utils.testing_utils
import
floats_tensor
,
is_torch_version
,
require_peft_backend
,
skip_mps
,
torch_device
sys
.
path
.
append
(
"."
)
sys
.
path
.
append
(
"."
)
from
utils
import
PeftLoraLoaderMixinTests
# noqa: E402
from
utils
import
PeftLoraLoaderMixinTests
,
check_if_lora_correctly_set
# noqa: E402
@
require_peft_backend
@
require_peft_backend
...
@@ -130,3 +132,41 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
...
@@ -130,3 +132,41 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora_save_load(self):
        # Intentionally disabled: the inherited PeftLoraLoaderMixinTests case
        # exercises text-encoder LoRA, which Lumina2 does not support, so the
        # test is skipped rather than overridden with a weaker check.
        pass
    @skip_mps
    @pytest.mark.xfail(
        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
        strict=False,
    )
    def test_lora_fuse_nan(self):
        """Check `fuse_lora` behavior when a LoRA weight is corrupted with `inf`.

        For every scheduler class: attach a LoRA adapter to the denoiser (and to
        the text encoder when the pipeline supports it), poison one LoRA-A weight
        with `inf`, then verify that
          * `fuse_lora(..., safe_fusing=True)` raises ``ValueError``, and
          * `fuse_lora(..., safe_fusing=False)` fuses anyway and the pipeline
            output is entirely NaN.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Text-encoder LoRA is only attached when the pipeline declares it
            # as a loadable module (Lumina2 does not — see the skip above).
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # NOTE(review): `unet_kwargs is None` appears to signal a
            # transformer-based (non-UNet) pipeline — confirm against the mixin.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            # corrupt one LoRA weight with `inf` values
            # (Lumina2-specific module path: layers[0].attn.to_q, unlike the
            # generic mixin version of this test — presumably the reason this
            # override exists.)
            with torch.no_grad():
                pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

            # with `safe_fusing=True` we should see an Error
            with self.assertRaises(ValueError):
                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

            # without we should not see an error, but every image will be black
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)

            out = pipe(**inputs)[0]

            # inf-poisoned weights propagate through the fused forward pass,
            # so every output element must be NaN.
            self.assertTrue(np.isnan(out).all())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment