Unverified Commit b9e2f886 authored by asfiyab-nvidia's avatar asfiyab-nvidia Committed by GitHub
Browse files

FluxPosEmbed: Remove Squeeze No-op (#9409)



Remove Squeeze op
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent b19827f6
...@@ -690,7 +690,7 @@ class FluxPosEmbed(nn.Module): ...@@ -690,7 +690,7 @@ class FluxPosEmbed(nn.Module):
n_axes = ids.shape[-1] n_axes = ids.shape[-1]
cos_out = [] cos_out = []
sin_out = [] sin_out = []
pos = ids.squeeze().float() pos = ids.float()
is_mps = ids.device.type == "mps" is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64 freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes): for i in range(n_axes):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment