renzhc/diffusers_dcu · Commit ecfbc8f9 (Unverified)
Authored Oct 28, 2025 by Dhruv Nair; committed by GitHub on Oct 28, 2025
[Pipelines] Enable Wan VACE to run with single transformer (#12428)
* update
* update
* update
* update
* update
Parent: df0e2a4f

Showing 2 changed files with 137 additions and 25 deletions:

src/diffusers/pipelines/wan/pipeline_wan_vace.py (+50, -23)
tests/pipelines/wan/test_wan_vace.py (+87, -2)
src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
     text_encoder ([`T5EncoderModel`]):
         [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
         the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-    transformer ([`WanVACETransformer3DModel`]):
-        Conditional Transformer to denoise the input latents.
-    transformer_2 ([`WanVACETransformer3DModel`], *optional*):
-        Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
-        `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
-        `transformer` is used.
-    scheduler ([`UniPCMultistepScheduler`]):
-        A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
     vae ([`AutoencoderKLWan`]):
         Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+    scheduler ([`UniPCMultistepScheduler`]):
+        A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+    transformer ([`WanVACETransformer3DModel`], *optional*):
+        Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
+        `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+        `transformer` or `transformer_2` must be provided.
+    transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+        Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+        `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+        `transformer` or `transformer_2` must be provided.
     boundary_ratio (`float`, *optional*, defaults to `None`):
         Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
         The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
         `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
-        boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+        boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
     """

-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2"]
+    _optional_components = ["transformer", "transformer_2"]

     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: WanVACETransformer3DModel = None,
         transformer_2: WanVACETransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
     ):
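Since `transformer` is now optional (and `model_cpu_offload_seq` includes `transformer_2`), the pipeline can be assembled with only one of the two denoising models. A minimal usage sketch, assuming a two-stage checkpoint; the repo id below is a placeholder, not something this commit ships:

    import torch
    from diffusers import WanVACEPipeline

    # Hypothetical: load a two-stage VACE checkpoint but skip the high-noise
    # transformer, keeping only `transformer_2`. Passing `transformer=None` to
    # `from_pretrained` is the usual diffusers way to drop an optional component.
    pipe = WanVACEPipeline.from_pretrained(
        "some-org/wan-vace-checkpoint",  # placeholder repo id
        transformer=None,
        torch_dtype=torch.bfloat16,
    )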
@@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         reference_images=None,
         guidance_scale_2=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        if self.transformer is not None:
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        elif self.transformer_2 is not None:
+            base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
+        else:
+            raise ValueError(
+                "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
+            )
+
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
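The `self.transformer if self.transformer is not None else self.transformer_2` fallback introduced above recurs in the hunks below. A hypothetical helper that would centralize the pattern (a sketch, not part of this commit):

    def _active_transformer(pipe):
        # Return whichever transformer component is loaded, mirroring the
        # fallback logic repeated throughout the pipeline.
        transformer = pipe.transformer if pipe.transformer is not None else pipe.transformer_2
        if transformer is None:
            raise ValueError(
                "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
            )
        return transformer

    # e.g. base = pipe.vae_scale_factor_spatial * _active_transformer(pipe).config.patch_size[1]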
@@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            base = self.vae_scale_factor_spatial * (
+                self.transformer.config.patch_size[1]
+                if self.transformer is not None
+                else self.transformer_2.config.patch_size[1]
+            )
             video_height, video_width = self.video_processor.get_default_height_width(video[0])
             if video_height * video_width > height * width:
@@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
"Generating with more than one video is not yet supported. This may be supported in the future."
"Generating with more than one video is not yet supported. This may be supported in the future."
)
)
transformer_patch_size
=
self
.
transformer
.
config
.
patch_size
[
1
]
transformer_patch_size
=
(
self
.
transformer
.
config
.
patch_size
[
1
]
if
self
.
transformer
is
not
None
else
self
.
transformer_2
.
config
.
patch_size
[
1
]
)
mask_list
=
[]
mask_list
=
[]
for
mask_
,
reference_images_batch
in
zip
(
mask
,
reference_images
):
for
mask_
,
reference_images_batch
in
zip
(
mask
,
reference_images
):
@@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             batch_size = prompt_embeds.shape[0]

         vae_dtype = self.vae.dtype
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+        vace_layers = (
+            self.transformer.config.vace_layers if self.transformer is not None else self.transformer_2.config.vace_layers
+        )

         if isinstance(conditioning_scale, (int, float)):
-            conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
+            conditioning_scale = [conditioning_scale] * len(vace_layers)

         if isinstance(conditioning_scale, list):
-            if len(conditioning_scale) != len(self.transformer.config.vace_layers):
+            if len(conditioning_scale) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = torch.tensor(conditioning_scale)

         if isinstance(conditioning_scale, torch.Tensor):
-            if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
+            if conditioning_scale.size(0) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
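To make the new `vace_layers` local concrete: a scalar `conditioning_scale` is broadcast to one entry per VACE layer, whichever transformer supplied the config. A self-contained sketch with invented layer indices:

    import torch

    vace_layers = [0, 5, 10, 15]  # hypothetical <transformer>.config.vace_layers
    conditioning_scale = 0.8      # scalar passed by the caller

    # Broadcast the scalar to one scale per VACE layer, then validate the
    # length, as the hunk above does against `vace_layers`.
    if isinstance(conditioning_scale, (int, float)):
        conditioning_scale = [conditioning_scale] * len(vace_layers)
    conditioning_scale = torch.tensor(conditioning_scale)
    assert conditioning_scale.size(0) == len(vace_layers)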
@@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)

-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
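For context on the switch to `current_guidance_scale`: with two-stage denoising, each timestep is served by one transformer, and classifier-free guidance uses the scale belonging to that stage. A schematic of the routing rule implied by the docstring above (the helper name is made up; `num_train_timesteps=1000` is an assumed scheduler default):

    def pick_stage(t, transformer, transformer_2, boundary_ratio,
                   guidance_scale, guidance_scale_2, num_train_timesteps=1000):
        # Sketch of the docstring's rule, not the pipeline's exact code.
        if boundary_ratio is None:
            # Single-transformer mode: use whichever component is loaded.
            model = transformer if transformer is not None else transformer_2
            return model, guidance_scale
        boundary_timestep = boundary_ratio * num_train_timesteps
        if t >= boundary_timestep:
            return transformer, guidance_scale  # high-noise stage
        return transformer_2, guidance_scale_2  # low-noise stage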
tests/pipelines/wan/test_wan_vace.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import tempfile
 import unittest

 import numpy as np
@@ -19,9 +20,15 @@ import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers import (
+    AutoencoderKLWan,
+    FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
+)

-from ...testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism, torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         )

     def test_save_load_float16(self):
         pass
+
+    def test_inference_with_only_transformer(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = None
+        components["boundary_ratio"] = 0.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_inference_with_only_transformer_2(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        components["transformer"] = None
+
+        # FlowMatchEulerDiscreteScheduler doesn't support running the low-noise stage alone,
+        # because its starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+        components["boundary_ratio"] = 1.0
+
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
+        optional_component = ["transformer"]
+
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        # FlowMatchEulerDiscreteScheduler doesn't support running the low-noise stage alone,
+        # because its starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+        for component in optional_component:
+            components[component] = None
+        components["boundary_ratio"] = 1.0
+
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            for component in pipe_loaded.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for component in optional_component:
+            assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
+
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        assert max_diff < expected_max_difference, "Outputs exceed expected maximum difference"