Unverified Commit ccc83216 authored by ZhengKai91's avatar ZhengKai91 Committed by GitHub
Browse files

Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820)



* get_1d_rotary_pos_embed support npu

* Update src/diffusers/models/embeddings.py

---------
Co-authored-by: default avatarKai zheng <kaizheng@KaideMacBook-Pro.local>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 5e48cd27
......@@ -1154,6 +1154,9 @@ def get_1d_rotary_pos_embed(
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
is_npu = freqs.device.type == "npu"
if is_npu:
freqs = freqs.float()
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment