Unverified Commit f781b8c3 authored by Aryan, committed by GitHub

Hunyuan VAE tiling fixes and transformer docs (#10295)

* update

* update

* fix test
parent 9c0e20de
@@ -792,12 +792,12 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
         self.tile_sample_min_width = 256
-        self.tile_sample_min_num_frames = 64
+        self.tile_sample_min_num_frames = 16

         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192
-        self.tile_sample_stride_num_frames = 48
+        self.tile_sample_stride_num_frames = 12

     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
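For context on the two frame-count attributes changed above: they appear to set the size of each temporal tiling window and the step between consecutive window origins, so their difference is the overlap that gets blended away. A rough sketch (not the library's code, assuming the frame loop mirrors the spatial one) of how a clip would be split:

```python
# Rough sketch (not the library code): split a clip into temporal windows,
# stepping by the stride and capping each window at the minimum tile length.
# Values mirror the new defaults above.
def temporal_windows(num_frames: int, min_frames: int = 16, stride_frames: int = 12):
    if num_frames <= min_frames:
        return [(0, num_frames)]  # short clips are encoded in one shot
    return [(i, min(i + min_frames, num_frames)) for i in range(0, num_frames, stride_frames)]

# A 61-frame clip -> windows starting every 12 frames, each up to 16 frames long,
# with a 4-frame overlap (16 - 12) blended away when the tiles are stitched back.
print(temporal_windows(61))
```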
@@ -1003,7 +1003,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                 tile = self.encoder(tile)
                 tile = self.quant_conv(tile)
                 row.append(tile)
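The one-line change in this hunk replaces `tile_sample_min_size`, a single square tile size that this class does not define (it only sets the height/width variants shown in the previous hunk), with the separate `tile_sample_min_height` / `tile_sample_min_width` attributes, so the spatial slice no longer depends on a stale attribute and the two axes can differ. A self-contained illustration of the slicing, with arbitrary values rather than the VAE's internals:

```python
# Illustration only: spatial tiling over a (B, C, T, H, W) tensor with separate
# height/width windows, as in the fixed line above. Shapes are arbitrary.
import torch

x = torch.randn(1, 3, 5, 480, 640)   # batch, channels, frames, height, width
min_h = min_w = 256                   # tile window size
stride_h = stride_w = 192             # tile origin step (so 64 px of overlap)

tiles = []
for i in range(0, x.shape[3], stride_h):
    for j in range(0, x.shape[4], stride_w):
        tiles.append(x[:, :, :, i : i + min_h, j : j + min_w])

# 3 rows x 4 columns of tiles; edge tiles are simply smaller.
print(len(tiles), tiles[0].shape)     # 12 torch.Size([1, 3, 5, 256, 256])
```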
@@ -1020,7 +1020,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
                 if j > 0:
                     tile = self.blend_h(row[j - 1], tile, blend_width)
                 result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
-            result_rows.append(torch.cat(result_row, dim=-1))
+            result_rows.append(torch.cat(result_row, dim=4))

         enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
         return enc
......
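A note on the `dim=4` change above: the latents are 5-D tensors laid out as `(batch, channels, frames, height, width)`, so `dim=4` is the width axis (the same axis `dim=-1` addressed before) and `dim=3` is the height axis used on the following line. Each row of tiles is joined along width, then the rows are stacked along height; the explicit index presumably just mirrors the `dim=3` concat. A quick check of the axis arithmetic:

```python
# Quick sanity check (illustrative): for (B, C, T, H, W) tensors, dim=4 is width
# and dim=3 is height, so dim=4 and dim=-1 address the same axis here.
import torch

left = torch.zeros(1, 4, 3, 8, 8)
right = torch.ones(1, 4, 3, 8, 8)

row = torch.cat([left, right], dim=4)    # tiles of one row joined along width -> (1, 4, 3, 8, 16)
grid = torch.cat([row, row], dim=3)      # rows stacked along height           -> (1, 4, 3, 16, 16)
assert torch.equal(torch.cat([left, right], dim=-1), row)
print(row.shape, grid.shape)
```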
@@ -497,6 +497,46 @@ class HunyuanVideoTransformerBlock(nn.Module):
 class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
+    r"""
+    A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
+
+    Args:
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        num_layers (`int`, defaults to `20`):
+            The number of layers of dual-stream blocks to use.
+        num_single_layers (`int`, defaults to `40`):
+            The number of layers of single-stream blocks to use.
+        num_refiner_layers (`int`, defaults to `2`):
+            The number of layers of refiner blocks to use.
+        mlp_ratio (`float`, defaults to `4.0`):
+            The ratio of the hidden layer size to the input size in the feedforward network.
+        patch_size (`int`, defaults to `2`):
+            The size of the spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of the temporal patches to use in the patch embedding layer.
+        qk_norm (`str`, defaults to `rms_norm`):
+            The normalization to use for the query and key projections in the attention layers.
+        guidance_embeds (`bool`, defaults to `True`):
+            Whether to use guidance embeddings in the model.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        pooled_projection_dim (`int`, defaults to `768`):
+            The dimension of the pooled projection of the text embeddings.
+        rope_theta (`float`, defaults to `256.0`):
+            The value of theta to use in the RoPE layer.
+        rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions of the axes to use in the RoPE layer.
+    """
+
     _supports_gradient_checkpointing = True

     @register_to_config
     def __init__(
         self,
......
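The docstring added above documents the config arguments and their production defaults. As a hedged sketch of how they fit together, here is a tiny, test-style instantiation; the small values are illustrative assumptions rather than a reference configuration, and they keep `attention_head_dim` equal to the sum of `rope_axes_dim`, matching the documented defaults (128 = 16 + 56 + 56):

```python
# Hedged sketch: a tiny HunyuanVideoTransformer3DModel using the argument names
# from the docstring above. Values are illustrative assumptions for a quick
# smoke test, not the model's real configuration.
from diffusers import HunyuanVideoTransformer3DModel

model = HunyuanVideoTransformer3DModel(
    in_channels=4,
    out_channels=4,
    num_attention_heads=2,
    attention_head_dim=10,   # kept equal to sum(rope_axes_dim) below
    num_layers=1,            # dual-stream blocks
    num_single_layers=1,     # single-stream blocks
    num_refiner_layers=1,
    patch_size=1,
    patch_size_t=1,
    guidance_embeds=True,
    text_embed_dim=16,
    pooled_projection_dim=8,
    rope_axes_dim=(2, 4, 4),
)
print(sum(p.numel() for p in model.parameters()))
```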
@@ -43,10 +43,14 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
             "down_block_types": (
                 "HunyuanVideoDownBlock3D",
                 "HunyuanVideoDownBlock3D",
+                "HunyuanVideoDownBlock3D",
+                "HunyuanVideoDownBlock3D",
             ),
             "up_block_types": (
                 "HunyuanVideoUpBlock3D",
                 "HunyuanVideoUpBlock3D",
+                "HunyuanVideoUpBlock3D",
+                "HunyuanVideoUpBlock3D",
             ),
             "block_out_channels": (8, 8, 8, 8),
             "layers_per_block": 1,
@@ -154,6 +158,27 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
         }
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

+    # We need to overwrite this test because the base test does not account for the length of down_block_types
+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 16, 16, 16)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
     @unittest.skip("Unsupported test.")
     def test_outputs_equivalence(self):
         pass
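To exercise just the new override locally, something like the following should work; the test file path is an assumption about the diffusers repository layout:

```python
# Run only the new norm-groups test; the path below is an assumed location of
# the Hunyuan video VAE test module in the diffusers repository.
import pytest

pytest.main([
    "tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py",
    "-k", "test_forward_with_norm_groups",
])
```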