Unverified commit 62e847db authored by Mikko Tukiainen, committed by GitHub

Use real-valued instead of complex tensors in Wan2.1 RoPE (#11649)



* use real instead of complex tensors in Wan2.1 RoPE

* remove the redundant type conversion

* unpack rotary_emb

* register rotary embedding frequencies as non-persistent buffers

* Apply style fixes

---------
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 47045862
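
Note (not part of the commit): the change rests on a standard identity. Multiplying a complex number by cos t + i*sin t is the same 2D rotation that the new code writes out explicitly on real-valued pairs. A minimal sketch of that equivalence:

```python
import torch

# (x1 + i*x2) * (cos t + i*sin t) == (x1*cos t - x2*sin t) + i*(x1*sin t + x2*cos t)
x1, x2 = torch.randn(()), torch.randn(())
t = torch.rand(())

z = torch.complex(x1, x2) * torch.polar(torch.ones(()), t)  # complex-tensor path
assert torch.allclose(z.real, x1 * t.cos() - x2 * t.sin())  # real-valued path
assert torch.allclose(z.imag, x1 * t.sin() + x2 * t.cos())
```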
@@ -71,14 +71,22 @@ class WanAttnProcessor2_0:
         if rotary_emb is not None:
 
-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
-                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
+                x1, x2 = x[..., 0], x[..., 1]
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
 
-            query = apply_rotary_emb(query, rotary_emb)
-            key = apply_rotary_emb(key, rotary_emb)
+            query = apply_rotary_emb(query, *rotary_emb)
+            key = apply_rotary_emb(key, *rotary_emb)
 
         # I2V task
         hidden_states_img = None
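
Note (illustrative only): a quick numerical cross-check of the old and new code paths in this hunk, using assumed toy shapes. The new apply_rotary_emb is restated standalone here, since the diff defines it as a closure inside the processor:

```python
import torch

# Restated from the hunk above so the check is self-contained.
def apply_rotary_emb(hidden_states, freqs_cos, freqs_sin):
    x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
    x1, x2 = x[..., 0], x[..., 1]
    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out.type_as(hidden_states)

B, H, S, D = 1, 2, 4, 8  # assumed toy sizes: batch, heads, seq, head_dim
angles = torch.rand(S, D // 2, dtype=torch.float64)

# Complex frequencies (old path) and their interleaved real form (new path).
freqs = torch.polar(torch.ones_like(angles), angles).reshape(1, 1, S, D // 2)
freqs_cos = angles.cos().repeat_interleave(2, dim=-1).reshape(1, 1, S, D)
freqs_sin = angles.sin().repeat_interleave(2, dim=-1).reshape(1, 1, S, D)

x = torch.randn(B, H, S, D)

# Old path: rotate pairs via complex multiplication in float64, cast back.
x_rot = torch.view_as_complex(x.to(torch.float64).unflatten(3, (-1, 2)))
old = torch.view_as_real(x_rot * freqs).flatten(3, 4).type_as(x)

# New path: real-valued arithmetic only.
new = apply_rotary_emb(x, freqs_cos, freqs_sin)

torch.testing.assert_close(old, new, rtol=1e-5, atol=1e-6)
```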
@@ -179,7 +187,11 @@ class WanTimeTextImageEmbedding(nn.Module):
 
 class WanRotaryPosEmbed(nn.Module):
     def __init__(
-        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+        self,
+        attention_head_dim: int,
+        patch_size: Tuple[int, int, int],
+        max_seq_len: int,
+        theta: float = 10000.0,
     ):
         super().__init__()
@@ -189,36 +201,52 @@ class WanRotaryPosEmbed(nn.Module):
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
-        freqs = []
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+        freqs_cos = []
+        freqs_sin = []
+
         for dim in [t_dim, h_dim, w_dim]:
-            freq = get_1d_rotary_pos_embed(
-                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+            freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                dim,
+                max_seq_len,
+                theta,
+                use_real=True,
+                repeat_interleave_real=True,
+                freqs_dtype=freqs_dtype,
             )
-            freqs.append(freq)
-        self.freqs = torch.cat(freqs, dim=1)
+            freqs_cos.append(freq_cos)
+            freqs_sin.append(freq_sin)
+
+        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        freqs = self.freqs.to(hidden_states.device)
-        freqs = freqs.split_with_sizes(
-            [
-                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
-                self.attention_head_dim // 6,
-                self.attention_head_dim // 6,
-            ],
-            dim=1,
-        )
+        split_sizes = [
+            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+            self.attention_head_dim // 3,
+            self.attention_head_dim // 3,
+        ]
+
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
-        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+
+        return freqs_cos, freqs_sin
 
 
 class WanTransformerBlock(nn.Module):
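
Note (a sketch, not part of the commit): with this change, forward returns a (cos, sin) pair instead of one complex tensor, and the frequency tables stay out of checkpoints because they are registered with persistent=False. A sanity check under assumed Wan2.1-style sizes and an assumed import path:

```python
import torch
from diffusers.models.transformers.transformer_wan import WanRotaryPosEmbed  # import path assumed

rope = WanRotaryPosEmbed(attention_head_dim=128, patch_size=(1, 2, 2), max_seq_len=1024)

video_latents = torch.randn(1, 16, 4, 32, 32)  # (batch, channels, frames, height, width)
freqs_cos, freqs_sin = rope(video_latents)     # forward now returns a (cos, sin) pair

# (4 // 1) * (32 // 2) * (32 // 2) = 1024 patch tokens, head_dim 128
print(freqs_cos.shape)  # torch.Size([1, 1, 1024, 128])
print(freqs_sin.shape)  # torch.Size([1, 1, 1024, 128])

# Non-persistent buffers are excluded from the state dict, so existing
# checkpoints remain loadable without these keys.
assert "freqs_cos" not in rope.state_dict()
assert "freqs_sin" not in rope.state_dict()
```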