Unverified Commit f4fdb3a0 authored by G.O.D's avatar G.O.D Committed by GitHub
Browse files

fix bug for ascend npu (#10429)

parent 7ab7c121
...@@ -1248,7 +1248,8 @@ class FluxPosEmbed(nn.Module): ...@@ -1248,7 +1248,8 @@ class FluxPosEmbed(nn.Module):
sin_out = [] sin_out = []
pos = ids.float() pos = ids.float()
is_mps = ids.device.type == "mps" is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64 is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes): for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed( cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i], self.axes_dim[i],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment