Commit ecfbc8f9 (unverified), authored by Dhruv Nair on Oct 28, 2025, committed by GitHub on Oct 28, 2025
[Pipelines] Enable Wan VACE to run with a single transformer (#12428)
* update
* update
* update
* update
* update
parent df0e2a4f
Showing 2 changed files with 137 additions and 25 deletions:

  src/diffusers/pipelines/wan/pipeline_wan_vace.py  +50 -23
  tests/pipelines/wan/test_wan_vace.py              +87 -2
src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanVACETransformer3DModel`]):
-            Conditional Transformer to denoise the input latents.
-        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
-            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
-            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
-            `transformer` is used.
-        scheduler ([`UniPCMultistepScheduler`]):
-            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        scheduler ([`UniPCMultistepScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        transformer ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
         boundary_ratio (`float`, *optional*, defaults to `None`):
             Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
             The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
             `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
-            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+            boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
     """
-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2"]
+    _optional_components = ["transformer", "transformer_2"]

     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: WanVACETransformer3DModel = None,
         transformer_2: WanVACETransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
     ):
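With `transformer` now optional, the pipeline can be assembled around a single denoiser. Below is a minimal sketch, not part of the commit, that mirrors the new tests at the bottom of this diff; `build_dummy_components()` is a hypothetical helper standing in for whatever constructs the tokenizer, text encoder, VAE, and transformer:

# Minimal sketch (assumptions: `build_dummy_components()` is hypothetical, and the
# scheduler settings are copied from the tests below, not mandated by the pipeline).
from diffusers import UniPCMultistepScheduler, WanVACEPipeline

components = build_dummy_components()          # hypothetical: tokenizer, text_encoder, vae, transformer
components["transformer_2"] = components.pop("transformer")
components["transformer"] = None               # legal now that `transformer` is optional
components["scheduler"] = UniPCMultistepScheduler(
    prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
)
# boundary_timestep = boundary_ratio * num_train_timesteps, so with a ratio of 1.0
# and the usual 1000 training timesteps, every t < 1000 is routed to `transformer_2`.
components["boundary_ratio"] = 1.0

pipe = WanVACEPipeline(**components)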
@@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         reference_images=None,
         guidance_scale_2=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        if self.transformer is not None:
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        elif self.transformer_2 is not None:
+            base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
+        else:
+            raise ValueError(
+                "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
+            )
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
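To make the divisibility requirement concrete, here is a small worked example; the spatial VAE scale factor of 8 and `patch_size[1] == 2` are assumed values for illustration, not read from this diff:

# Assumed values for illustration only.
vae_scale_factor_spatial = 8
patch_size_h = 2                                  # stands in for transformer.config.patch_size[1]
base = vae_scale_factor_spatial * patch_size_h    # 16

height, width = 480, 832                          # both multiples of 16, so the check passes
assert height % base == 0 and width % base == 0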
@@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            base = self.vae_scale_factor_spatial * (
+                self.transformer.config.patch_size[1]
+                if self.transformer is not None
+                else self.transformer_2.config.patch_size[1]
+            )
             video_height, video_width = self.video_processor.get_default_height_width(video[0])
             if video_height * video_width > height * width:
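The same `transformer if ... else transformer_2` fallback recurs at each call site touched by this commit. A hypothetical helper, not present in the commit, captures the intent in one place:

def _active_transformer(pipe):
    # Hypothetical helper (not in the commit): return whichever transformer is set.
    # check_inputs already raises if both are None, so this is safe at inference time.
    return pipe.transformer if pipe.transformer is not None else pipe.transformer_2

# e.g. base = pipe.vae_scale_factor_spatial * _active_transformer(pipe).config.patch_size[1]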
@@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
"Generating with more than one video is not yet supported. This may be supported in the future."
)
transformer_patch_size
=
self
.
transformer
.
config
.
patch_size
[
1
]
transformer_patch_size
=
(
self
.
transformer
.
config
.
patch_size
[
1
]
if
self
.
transformer
is
not
None
else
self
.
transformer_2
.
config
.
patch_size
[
1
]
)
mask_list
=
[]
for
mask_
,
reference_images_batch
in
zip
(
mask
,
reference_images
):
@@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         batch_size = prompt_embeds.shape[0]

         vae_dtype = self.vae.dtype
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+        vace_layers = (
+            self.transformer.config.vace_layers if self.transformer is not None else self.transformer_2.config.vace_layers
+        )

         if isinstance(conditioning_scale, (int, float)):
-            conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
+            conditioning_scale = [conditioning_scale] * len(vace_layers)

         if isinstance(conditioning_scale, list):
-            if len(conditioning_scale) != len(self.transformer.config.vace_layers):
+            if len(conditioning_scale) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = torch.tensor(conditioning_scale)

         if isinstance(conditioning_scale, torch.Tensor):
-            if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
+            if conditioning_scale.size(0) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
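A short walk-through of the `conditioning_scale` normalization above, with an assumed count of eight VACE layers (illustrative only):

import torch

vace_layers = list(range(8))    # assumed: the resolved config lists 8 VACE layers

conditioning_scale = 1.0        # scalar input
conditioning_scale = [conditioning_scale] * len(vace_layers)    # -> [1.0] * 8
conditioning_scale = torch.tensor(conditioning_scale)           # -> tensor of shape (8,)
assert conditioning_scale.size(0) == len(vace_layers)           # the length check above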
@@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)

-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels if self.transformer is not None else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
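The rename to `current_guidance_scale` suggests the guidance strength is now selected per stage, with `guidance_scale` for the high-noise transformer and `guidance_scale_2` (added to `check_inputs` above) for the low-noise one. A hedged sketch of that selection plus standard classifier-free guidance; the selection logic is inferred from the naming, not copied from these hunks:

import torch

def apply_cfg(noise_uncond: torch.Tensor, noise_cond: torch.Tensor,
              t: float, boundary_timestep: float,
              guidance_scale: float, guidance_scale_2: float) -> torch.Tensor:
    # Inferred: high-noise steps (t >= boundary_timestep) use `guidance_scale`,
    # low-noise steps use `guidance_scale_2`.
    current_guidance_scale = guidance_scale if t >= boundary_timestep else guidance_scale_2
    return noise_uncond + current_guidance_scale * (noise_cond - noise_uncond)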
tests/pipelines/wan/test_wan_vace.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import tempfile
 import unittest

 import numpy as np
@@ -19,9 +20,15 @@ import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers import (
+    AutoencoderKLWan,
+    FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
+)

-from ...testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism, torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     def test_save_load_float16(self):
         pass

+    def test_inference_with_only_transformer(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = None
+        components["boundary_ratio"] = 0.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_inference_with_only_transformer_2(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        components["transformer"] = None
+        # FlowMatchEulerDiscreteScheduler doesn't support running a low-noise-only schedule
+        # because the starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+        components["boundary_ratio"] = 1.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
+        optional_component = ["transformer"]
+
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        # FlowMatchEulerDiscreteScheduler doesn't support running a low-noise-only schedule
+        # because the starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+        for component in optional_component:
+            components[component] = None
+        components["boundary_ratio"] = 1.0
+
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            for component in pipe_loaded.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for component in optional_component:
+            assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
+
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        assert max_diff < expected_max_difference, "Outputs exceed expected maximum difference"