Unverified Commit d70f8ee1 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[WAN] fix recompilation issues (#11475)



* [tests] Add torch.compile() test for WanTransformer3DModel

* fix wan recompilation issues.

* style

---------
Co-authored-by: default avatartongyu0924 <winnie920924@gmail.com>
parent 06beecaf
......@@ -202,8 +202,8 @@ class WanRotaryPosEmbed(nn.Module):
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
self.freqs = self.freqs.to(hidden_states.device)
freqs = self.freqs.split_with_sizes(
freqs = self.freqs.to(hidden_states.device)
freqs = freqs.split_with_sizes(
[
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
self.attention_head_dim // 6,
......
......@@ -17,7 +17,14 @@ import unittest
import torch
from diffusers import WanTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.testing_utils import (
enable_full_determinism,
is_torch_compile,
require_torch_2,
require_torch_gpu,
slow,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
......@@ -79,3 +86,18 @@ class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment