Commit 6a0137eb (unverified), authored by C, committed by GitHub
Browse files

Fix Graph Breaks When Compiling CogView4 (#10959)



* Fix Graph Breaks When Compiling CogView4

Eliminate the following recompilations:

```
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] Recompiling function forward in /home/zeyi/repos/diffusers/src/diffusers/models/transformers/transformer_cogview4.py:374
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     triggered by the following guard failure(s):
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/3: ___check_obj_id(L['self'].rope.freqs_h, 139976127328032)    
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/2: ___check_obj_id(L['self'].rope.freqs_h, 139976107780960)    
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/1: ___check_obj_id(L['self'].rope.freqs_h, 140022511848960)    
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/0: ___check_obj_id(L['self'].rope.freqs_h, 140024081342416)   
```

* Update transformer_cogview4.py

* fix cogview4 rotary pos embed

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 2e5203be
...@@ -244,30 +244,34 @@ class CogView4RotaryPosEmbed(nn.Module): ...@@ -244,30 +244,34 @@ class CogView4RotaryPosEmbed(nn.Module):
def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None: def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
super().__init__() super().__init__()
self.dim = dim
self.patch_size = patch_size self.patch_size = patch_size
self.rope_axes_dim = rope_axes_dim self.rope_axes_dim = rope_axes_dim
self.theta = theta
dim_h, dim_w = dim // 2, dim // 2
h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
h_seq = torch.arange(self.rope_axes_dim[0])
w_seq = torch.arange(self.rope_axes_dim[1])
self.freqs_h = torch.outer(h_seq, h_inv_freq)
self.freqs_w = torch.outer(w_seq, w_inv_freq)
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_channels, height, width = hidden_states.shape batch_size, num_channels, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size height, width = height // self.patch_size, width // self.patch_size
h_idx = torch.arange(height) dim_h, dim_w = self.dim // 2, self.dim // 2
w_idx = torch.arange(width) h_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
)
w_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
)
h_seq = torch.arange(self.rope_axes_dim[0])
w_seq = torch.arange(self.rope_axes_dim[1])
freqs_h = torch.outer(h_seq, h_inv_freq)
freqs_w = torch.outer(w_seq, w_inv_freq)
h_idx = torch.arange(height, device=freqs_h.device)
w_idx = torch.arange(width, device=freqs_w.device)
inner_h_idx = h_idx * self.rope_axes_dim[0] // height inner_h_idx = h_idx * self.rope_axes_dim[0] // height
inner_w_idx = w_idx * self.rope_axes_dim[1] // width inner_w_idx = w_idx * self.rope_axes_dim[1] // width
self.freqs_h = self.freqs_h.to(hidden_states.device) freqs_h = freqs_h[inner_h_idx]
self.freqs_w = self.freqs_w.to(hidden_states.device) freqs_w = freqs_w[inner_w_idx]
freqs_h = self.freqs_h[inner_h_idx]
freqs_w = self.freqs_w[inner_w_idx]
# Create position matrices for height and width # Create position matrices for height and width
# [height, 1, dim//4] and [1, width, dim//4] # [height, 1, dim//4] and [1, width, dim//4]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment