Unverified Commit 3a31b291 authored by Roy Hvaara's avatar Roy Hvaara Committed by GitHub
Browse files

Use float32 RoPE freqs in Wan with MPS backends (#11643)

Use float32 for RoPE on MPS in Wan
parent b975bcef
......@@ -72,7 +72,8 @@ class WanAttnProcessor2_0:
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
......@@ -190,9 +191,10 @@ class WanRotaryPosEmbed(nn.Module):
t_dim = attention_head_dim - h_dim - w_dim
freqs = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
for dim in [t_dim, h_dim, w_dim]:
freq = get_1d_rotary_pos_embed(
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
)
freqs.append(freq)
self.freqs = torch.cat(freqs, dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment