renzhc / diffusers_dcu · Commits

Unverified commit 6f1d6694, authored Jul 02, 2025 by Sayak Paul, committed by GitHub on Jul 02, 2025
Parent: 0e95aa85

[lora] tests for `exclude_modules` with Wan VACE (#11843)

* wan vace.
* update
* update
* import problem

Showing 1 changed file with 217 additions and 0 deletions (+217 / -0)
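For context, a minimal illustrative sketch (not part of the commit) of the feature the new test exercises: peft's `exclude_modules` option on `LoraConfig`, which skips LoRA injection for specific module paths even when their names match `target_modules`. It assumes a peft release that ships `exclude_modules` (newer than 0.13.2, mirroring the test's `require_peft_version_greater` guard), and the `r`/`lora_alpha` values are arbitrary; the tiny transformer configuration is copied from the test below.

# Illustrative sketch only; assumes peft > 0.13.2 with `exclude_modules` support.
from peft import LoraConfig, inject_adapter_in_model
from peft.utils import get_peft_model_state_dict

from diffusers import WanVACETransformer3DModel

# Tiny transformer built from the same `transformer_kwargs` used in the test below.
transformer = WanVACETransformer3DModel(
    patch_size=(1, 2, 2),
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    text_dim=32,
    freq_dim=16,
    ffn_dim=16,
    num_layers=2,
    cross_attn_norm=True,
    qk_norm="rms_norm_across_heads",
    rope_max_seq_len=16,
    vace_layers=[0],
    vace_in_channels=72,
)

# Target every `proj_out` layer, but keep `vace_blocks.0.proj_out` LoRA-free.
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["proj_out"],
    exclude_modules=["vace_blocks.0.proj_out"],
)
transformer = inject_adapter_in_model(lora_config, transformer, adapter_name="default")

state_dict = get_peft_model_state_dict(transformer, adapter_name="default")
assert not any("vace_blocks.0.proj_out" in k for k in state_dict)  # excluded module absent
assert any("proj_out" in k for k in state_dict)  # other proj_out layers still adapted

The new test asserts exactly these two properties, both on the in-memory adapter state dict and on the state dict saved to and reloaded from disk.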
tests/lora/test_lora_layers_wanvace.py (new file, 0 → 100644, +217 / -0)
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
import unittest

import numpy as np
import pytest
import safetensors.torch
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
    floats_tensor,
    require_peft_backend,
    require_peft_version_greater,
    skip_mps,
    torch_device,
)


if is_peft_available():
    from peft.utils import get_peft_model_state_dict

sys.path.append(".")

from utils import PeftLoraLoaderMixinTests  # noqa: E402


@require_peft_backend
@skip_mps
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    pipeline_class = WanVACEPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
    scheduler_kwargs = {}

    transformer_kwargs = {
        "patch_size": (1, 2, 2),
        "num_attention_heads": 2,
        "attention_head_dim": 8,
        "in_channels": 4,
        "out_channels": 4,
        "text_dim": 32,
        "freq_dim": 16,
        "ffn_dim": 16,
        "num_layers": 2,
        "cross_attn_norm": True,
        "qk_norm": "rms_norm_across_heads",
        "rope_max_seq_len": 16,
        "vace_layers": [0],
        "vace_in_channels": 72,
    }
    transformer_cls = WanVACETransformer3DModel
    vae_kwargs = {
        "base_dim": 3,
        "z_dim": 4,
        "dim_mult": [1, 1, 1, 1],
        "latents_mean": torch.randn(4).numpy().tolist(),
        "latents_std": torch.randn(4).numpy().tolist(),
        "num_res_blocks": 1,
        "temperal_downsample": [False, True, True],
    }
    vae_cls = AutoencoderKLWan
    has_two_text_encoders = True
    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
    text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"

    text_encoder_target_modules = ["q", "k", "v", "o"]

    @property
    def output_shape(self):
        return (1, 9, 16, 16, 3)

    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 16
        num_channels = 4
        num_frames = 9
        num_latent_frames = 3  # (num_frames - 1) // temporal_compression_ratio + 1
        sizes = (4, 4)
        height, width = 16, 16

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        video = [Image.new("RGB", (height, width))] * num_frames
        mask = [Image.new("L", (height, width), 0)] * num_frames

        pipeline_inputs = {
            "video": video,
            "mask": mask,
            "prompt": "",
            "num_frames": num_frames,
            "num_inference_steps": 1,
            "guidance_scale": 6.0,
            "height": height,
            "width": height,
            "max_sequence_length": sequence_length,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    @unittest.skip("Not supported in Wan VACE.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in Wan VACE.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in Wan VACE.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @pytest.mark.xfail(
        condition=True,
        reason="RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same",
        strict=True,
    )
    def test_layerwise_casting_inference_denoiser(self):
        super().test_layerwise_casting_inference_denoiser()

    @require_peft_version_greater("0.13.2")
    def test_lora_exclude_modules_wanvace(self):
        scheduler_cls = self.scheduler_classes[0]
        exclude_module_name = "vace_blocks.0.proj_out"
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components).to(torch_device)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        # only supported for `denoiser` now
        denoiser_lora_config.target_modules = ["proj_out"]
        denoiser_lora_config.exclude_modules = [exclude_module_name]
        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
        )
        # The state dict shouldn't contain the modules to be excluded from LoRA.
        state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
        self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
        self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
            pipe.unload_lora_weights()

            # Check in the loaded state dict.
            loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
            self.assertTrue(any("proj_out" in k for k in loaded_state_dict))

            # Check in the state dict obtained after loading LoRA.
            pipe.load_lora_weights(tmpdir)
            state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
            self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
            self.assertTrue(any("proj_out" in k for k in state_dict_from_model))

            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue(
            not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
            "LoRA should change outputs.",
        )
        self.assertTrue(
            np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
            "Lora outputs should match.",
        )
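Usage note (illustrative, not part of the commit): the file produced by the save/load round trip in the test is a standard diffusers LoRA checkpoint, so it can be reloaded into a WanVACE pipeline through the usual LoRA API. The checkpoint id and directory below are placeholders.

import torch
from diffusers import WanVACEPipeline

# Placeholder model id and LoRA directory; substitute real paths.
pipe = WanVACEPipeline.from_pretrained("<wan-vace-checkpoint>", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("<lora_dir>")  # directory containing pytorch_lora_weights.safetensors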