Unverified Commit 8ae8008b authored by Mathias Parger, committed by GitHub

speedup hunyuan encoder causal mask generation (#10764)

* speedup causal mask generation

* fixing hunyuan attn mask test case
parent c80eda9d
@@ -36,11 +36,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def prepare_causal_attention_mask(
     num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
 ) -> torch.Tensor:
-    seq_len = num_frames * height_width
-    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
-    for i in range(seq_len):
-        i_frame = i // height_width
-        mask[i, : (i_frame + 1) * height_width] = 0
+    indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
+    indices_blocks = indices.repeat_interleave(height_width)
+    x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
+    mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
+
     if batch_size is not None:
         mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
     return mask
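
Note: for reference, here is a minimal standalone sketch of the vectorized construction on toy sizes (3 frames, 2 tokens per frame; these sizes are illustrative, not from the model). It shows why the meshgrid comparison reproduces the block-causal pattern the old per-row loop built: every token in frame f may attend to all tokens in frames 1..f.

    import torch

    # Toy sizes for illustration only: 3 frames x 2 tokens per frame -> 6x6 mask.
    num_frames, height_width = 3, 2
    device, dtype = torch.device("cpu"), torch.float32

    # 1-based frame index of every token in the flattened sequence: [1, 1, 2, 2, 3, 3].
    indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
    indices_blocks = indices.repeat_interleave(height_width)

    # Pairwise frame comparison: entry (i, j) is 0 (visible) iff key token j
    # belongs to a frame no later than query token i's frame, else -inf.
    x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
    mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
    print(mask)  # block lower-triangular in 2x2 frame blocks

This replaces seq_len Python-level loop iterations (3441 rows for the 31x111 shape used in the test) with a handful of tensor ops.
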
@@ -18,6 +18,7 @@ import unittest
 import torch

 from diffusers import AutoencoderKLHunyuanVideo
+from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
@@ -182,3 +183,28 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     @unittest.skip("Unsupported test.")
     def test_outputs_equivalence(self):
         pass
+
+    def test_prepare_causal_attention_mask(self):
+        def prepare_causal_attention_mask_orig(
+            num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
+        ) -> torch.Tensor:
+            seq_len = num_frames * height_width
+            mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+            for i in range(seq_len):
+                i_frame = i // height_width
+                mask[i, : (i_frame + 1) * height_width] = 0
+            if batch_size is not None:
+                mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+            return mask
+
+        # test with some odd shapes
+        original_mask = prepare_causal_attention_mask_orig(
+            num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+        )
+        new_mask = prepare_causal_attention_mask(
+            num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+        )
+        self.assertTrue(
+            torch.allclose(original_mask, new_mask),
+            "Causal attention mask should be the same",
+        )
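
Note: the test pins down equivalence; the speedup itself can be eyeballed with a rough micro-benchmark along the following lines (this harness is a sketch, not part of the PR; timings are single-run and machine-dependent). The loop-based reference is the same one the test keeps around.

    import time

    import torch

    from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask


    def prepare_causal_attention_mask_orig(num_frames, height_width, dtype, device, batch_size=None):
        # Loop-based reference, copied from the test above.
        seq_len = num_frames * height_width
        mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
        for i in range(seq_len):
            i_frame = i // height_width
            mask[i, : (i_frame + 1) * height_width] = 0
        if batch_size is not None:
            mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
        return mask


    # Time both versions on the same odd shape the test uses.
    for fn in (prepare_causal_attention_mask_orig, prepare_causal_attention_mask):
        start = time.perf_counter()
        fn(num_frames=31, height_width=111, dtype=torch.float32, device=torch.device("cpu"))
        print(f"{fn.__name__}: {time.perf_counter() - start:.4f} s")
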