renzhc / diffusers_dcu · Commit c97b709a (unverified)
Authored Apr 02, 2025 by Dhruv Nair; committed via GitHub on Apr 02, 2025
Parent: b0ff822e

Add CacheMixin to Wan and LTX Transformers (#11187)

* update
* update
* update
Showing 5 changed files with 25 additions and 2 deletions (+25 −2)
src/diffusers/models/transformers/transformer_ltx.py    +2 −1
src/diffusers/models/transformers/transformer_wan.py    +2 −1
src/diffusers/pipelines/ltx/pipeline_ltx.py             +7 −0
src/diffusers/pipelines/ltx/pipeline_ltx_condition.py   +7 −0
src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +7 −0
src/diffusers/models/transformers/transformer_ltx.py

@@ -26,6 +26,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin

@@ -298,7 +299,7 @@ class LTXVideoTransformerBlock(nn.Module):
 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
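Note: with CacheMixin in its base classes, the LTX transformer picks up enable_cache()/disable_cache(). A minimal sketch of enabling Pyramid Attention Broadcast on the LTX text-to-video pipeline, assuming the Lightricks/LTX-Video checkpoint and the PyramidAttentionBroadcastConfig API from diffusers; the skip values are illustrative, not tuned:

import torch
from diffusers import LTXPipeline, PyramidAttentionBroadcastConfig

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# enable_cache() comes from the CacheMixin added in this commit.
# current_timestep_callback polls the pipeline property added further below.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,              # illustrative value
    spatial_attention_timestep_skip_range=(100, 800),  # illustrative range
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)

video = pipe(prompt="A woman walking along a beach at sunset", num_frames=65).frames[0]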
src/diffusers/models/transformers/transformer_wan.py

@@ -24,6 +24,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
 from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin

@@ -288,7 +289,7 @@ class WanTransformerBlock(nn.Module):
         return hidden_states


-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in the Wan model.
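The Wan transformer gains the same hooks. A sketch under the assumption that WanPipeline already exposes a matching current_timestep property (only the LTX pipelines needed it added in this commit) and that the Wan-AI/Wan2.1-T2V-1.3B-Diffusers checkpoint id is available:

import torch
from diffusers import WanPipeline, PyramidAttentionBroadcastConfig

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

pipe.transformer.enable_cache(
    PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,  # illustrative value
        current_timestep_callback=lambda: pipe.current_timestep,
    )
)
video = pipe(prompt="A cat surfing a breaking wave").frames[0]

# disable_cache() restores exact attention when caching is no longer wanted.
pipe.transformer.disable_cache()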
src/diffusers/pipelines/ltx/pipeline_ltx.py

@@ -489,6 +489,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def attention_kwargs(self):
         return self._attention_kwargs

@@ -622,6 +626,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
+        self._current_timestep = None

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

@@ -706,6 +711,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
                 if self.interrupt:
                     continue

+                self._current_timestep = t
+
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)
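The new current_timestep property is what cache helpers poll through current_timestep_callback: it is None outside the denoising loop and mirrors t inside it. It also works for plain progress monitoring; a hedged sketch, assuming LTXPipeline accepts the standard diffusers callback_on_step_end hook:

def log_step(pipeline, step, timestep, callback_kwargs):
    # current_timestep holds the timestep of the denoising step being executed
    print(f"step {step}: current_timestep={pipeline.current_timestep}")
    return callback_kwargs

video = pipe(
    prompt="A timelapse of clouds rolling over mountains",
    callback_on_step_end=log_step,
).frames[0]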
src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

@@ -774,6 +774,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def attention_kwargs(self):
         return self._attention_kwargs

@@ -933,6 +937,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
+        self._current_timestep = None

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

@@ -1066,6 +1071,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
                 if self.interrupt:
                     continue

+                self._current_timestep = t
+
                 if image_cond_noise_scale > 0:
                     # Add timestep-dependent noise to the hard-conditioning latents
                     # This helps with motion continuity, especially when conditioned on a single frame
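For reference, a sketch of calling the condition pipeline whose loop is touched above; it assumes the Lightricks/LTX-Video-0.9.5 checkpoint, the LTXVideoCondition helper, and a placeholder image URL. image_cond_noise_scale controls the timestep-dependent noise on the conditioning latents mentioned in the diff:

import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import load_image

pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition generation on a single frame at index 0 (placeholder URL).
condition = LTXVideoCondition(image=load_image("https://example.com/first_frame.png"), frame_index=0)
video = pipe(
    conditions=[condition],
    prompt="A slow pan across a city street at night",
    image_cond_noise_scale=0.15,  # noise on the hard-conditioning latents, per the comment in the diff
).frames[0]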
src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py

@@ -550,6 +550,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def attention_kwargs(self):
         return self._attention_kwargs

@@ -686,6 +690,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
+        self._current_timestep = None

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

@@ -778,6 +783,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
                 if self.interrupt:
                     continue

+                self._current_timestep = t
+
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)
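End to end, the image-to-video pipeline can combine the new cache support with the usual export helpers. A sketch assuming load_image/export_to_video from diffusers.utils and a placeholder input URL:

import torch
from diffusers import LTXImageToVideoPipeline, PyramidAttentionBroadcastConfig
from diffusers.utils import export_to_video, load_image

pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.transformer.enable_cache(
    PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,  # illustrative value
        current_timestep_callback=lambda: pipe.current_timestep,
    )
)

image = load_image("https://example.com/first_frame.png")  # placeholder URL
video = pipe(image=image, prompt="The scene slowly comes to life", num_frames=65).frames[0]
export_to_video(video, "output.mp4", fps=24)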