renzhc / diffusers_dcu · Commits

Unverified commit f781b8c3, authored Dec 19, 2024 by Aryan, committed by GitHub on Dec 19, 2024. Parent: 9c0e20de
Hunyuan VAE tiling fixes and transformer docs (#10295)

* update
* update
* fix test
Showing 3 changed files with 69 additions and 4 deletions:

- src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py (+4, -4)
- src/diffusers/models/transformers/transformer_hunyuan_video.py (+40, -0)
- tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py (+25, -0)
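The commit lowers the temporal tiling thresholds of the Hunyuan video VAE and fixes two bugs in its spatially tiled encode path (detailed per file below). As a hedged usage sketch of the code path these changes affect, assuming the diffusers API of this era; the checkpoint id is an assumption, not something this diff establishes:

```python
# Hedged sketch: exercising the tiled encode path this commit fixes.
# The checkpoint id and weights layout are assumptions, not part of the diff.
import torch
from diffusers import AutoencoderKLHunyuanVideo

vae = AutoencoderKLHunyuanVideo.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="vae"
)
vae.enable_tiling()  # turns on the spatial/temporal tiling configured below

# 4k + 1 frames matches the VAE's temporal compression; 512 px is above the
# 256 px minimum tile size, so spatial tiling actually engages.
video = torch.randn(1, 3, 17, 512, 512)
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
print(latents.shape)
```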
src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

```diff
@@ -792,12 +792,12 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
         self.tile_sample_min_width = 256
-        self.tile_sample_min_num_frames = 64
+        self.tile_sample_min_num_frames = 16

         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192
-        self.tile_sample_stride_num_frames = 48
+        self.tile_sample_stride_num_frames = 12

     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
@@ -1003,7 +1003,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                 tile = self.encoder(tile)
                 tile = self.quant_conv(tile)
                 row.append(tile)
@@ -1020,7 +1020,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
                 if j > 0:
                     tile = self.blend_h(row[j - 1], tile, blend_width)
                 result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
-            result_rows.append(torch.cat(result_row, dim=-1))
+            result_rows.append(torch.cat(result_row, dim=4))

         enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
         return enc
```
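The `@@ -1003` hunk is the substantive fix: the old slice used `self.tile_sample_min_size`, an attribute this class does not appear to define (its `__init__` sets separate height and width minima), and it also forced both axes square. A standalone sketch of the corrected geometry using the defaults above (dummy tensor, illustrative only):

```python
# Standalone sketch of the corrected tile geometry (dummy data; the real
# method also encodes each tile and cross-fades the overlaps).
import torch

tile_min_h, tile_min_w = 256, 256  # tile_sample_min_height / _width
stride_h, stride_w = 192, 192      # tile_sample_stride_height / _width

x = torch.randn(1, 3, 17, 512, 512)  # (batch, channels, frames, height, width)
height, width = x.shape[-2:]

rows = []
for i in range(0, height, stride_h):
    row = []
    for j in range(0, width, stride_w):
        # Each axis is sliced with its own minimum, so rectangular tile
        # configurations index correctly.
        row.append(x[:, :, :, i : i + tile_min_h, j : j + tile_min_w])
    rows.append(row)

# Neighbouring tiles overlap by tile_min - stride = 64 px, the band that
# blend_h / blend_v later cross-fade.
print(len(rows), len(rows[0]), rows[0][0].shape)  # 3 3 torch.Size([1, 3, 17, 256, 256])
```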
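The `@@ -1020` hunk swaps `dim=-1` for `dim=4`; on the 5D `(batch, channels, frames, height, width)` tensors used here the two are equivalent, so this only makes the width axis explicit. For context, a sketch of the horizontal cross-fade `blend_h` applies to each overlap before that concatenation, modelled on the typical diffusers implementation (an approximation, not the exact code from this file):

```python
import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Linearly fade the right edge of tile `a` into the left edge of tile `b`."""
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent  # 0.0 at a's side of the seam, near 1.0 at b's side
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, :, x] * w
    return b
```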
src/diffusers/models/transformers/transformer_hunyuan_video.py

```diff
@@ -497,6 +497,46 @@ class HunyuanVideoTransformerBlock(nn.Module):
 class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
+    r"""
+    A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
+
+    Args:
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        num_layers (`int`, defaults to `20`):
+            The number of layers of dual-stream blocks to use.
+        num_single_layers (`int`, defaults to `40`):
+            The number of layers of single-stream blocks to use.
+        num_refiner_layers (`int`, defaults to `2`):
+            The number of layers of refiner blocks to use.
+        mlp_ratio (`float`, defaults to `4.0`):
+            The ratio of the hidden layer size to the input size in the feedforward network.
+        patch_size (`int`, defaults to `2`):
+            The size of the spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of the temporal patches to use in the patch embedding layer.
+        qk_norm (`str`, defaults to `rms_norm`):
+            The normalization to use for the query and key projections in the attention layers.
+        guidance_embeds (`bool`, defaults to `True`):
+            Whether to use guidance embeddings in the model.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        pooled_projection_dim (`int`, defaults to `768`):
+            The dimension of the pooled projection of the text embeddings.
+        rope_theta (`float`, defaults to `256.0`):
+            The value of theta to use in the RoPE layer.
+        rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions of the axes to use in the RoPE layer.
+    """
+
     _supports_gradient_checkpointing = True

     @register_to_config
     def __init__(
         self,
```
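The new docstring documents the constructor's registered defaults. A hedged sketch that spells them out explicitly (this builds the full-size architecture with random weights, so it is shown for illustration; pretrained use would go through `from_pretrained`):

```python
from diffusers import HunyuanVideoTransformer3DModel

# Arguments mirror the docstring defaults above; listing them is illustrative.
# Note that rope_axes_dim sums to attention_head_dim: 16 + 56 + 56 = 128.
transformer = HunyuanVideoTransformer3DModel(
    in_channels=16,
    out_channels=16,
    num_attention_heads=24,
    attention_head_dim=128,
    num_layers=20,          # dual-stream blocks
    num_single_layers=40,   # single-stream blocks
    num_refiner_layers=2,
    mlp_ratio=4.0,
    patch_size=2,
    patch_size_t=1,
    qk_norm="rms_norm",
    guidance_embeds=True,
    text_embed_dim=4096,
    pooled_projection_dim=768,
    rope_theta=256.0,
    rope_axes_dim=(16, 56, 56),
)
```

At these defaults the inner dimension is num_attention_heads * attention_head_dim = 3072, so instantiation allocates a multi-billion-parameter model; for smoke tests, shrink num_layers / num_single_layers while keeping rope_axes_dim summing to attention_head_dim.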
tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py

```diff
@@ -43,10 +43,14 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
             "down_block_types": (
                 "HunyuanVideoDownBlock3D",
                 "HunyuanVideoDownBlock3D",
+                "HunyuanVideoDownBlock3D",
+                "HunyuanVideoDownBlock3D",
             ),
             "up_block_types": (
                 "HunyuanVideoUpBlock3D",
                 "HunyuanVideoUpBlock3D",
+                "HunyuanVideoUpBlock3D",
+                "HunyuanVideoUpBlock3D",
             ),
             "block_out_channels": (8, 8, 8, 8),
             "layers_per_block": 1,
@@ -154,6 +158,27 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
         }
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

+    # We need to overwrite this test because the base test does not account for the length of down_block_types
+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 16, 16, 16)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
     @unittest.skip("Unsupported test.")
     def test_outputs_equivalence(self):
         pass
```
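The override pins `block_out_channels` to `(16, 16, 16, 16)` because the base test's channel override does not account for the four down blocks used here, and because GroupNorm requires each channel count to be divisible by `norm_num_groups`. A minimal illustration of that constraint in plain PyTorch (assumed rationale, not part of the test suite):

```python
import torch
import torch.nn as nn

# nn.GroupNorm requires num_channels % num_groups == 0, which is why setting
# norm_num_groups = 16 forces every block_out_channels entry up to 16.
ok = nn.GroupNorm(num_groups=16, num_channels=16)
ok(torch.randn(1, 16, 4, 4))  # fine: 16 % 16 == 0

try:
    nn.GroupNorm(num_groups=16, num_channels=8)  # 8 % 16 != 0
except ValueError as err:
    print(err)  # num_channels must be divisible by num_groups
```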