Unverified Commit 303efd2b authored by G.O.D, committed by GitHub

Improve pos embed for Flux.1 inference on Ascend NPU (#12534)



Improve the positional embedding path for Ascend NPU: compute the rotary frequencies on CPU and move the resulting cos/sin tensors to the NPU.
Co-authored-by: felix01.yu <felix01.yu@vipshop.com>
parent 5afbcce1
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,7 +717,11 @@ class FluxTransformer2DModel(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        image_rotary_emb = self.pos_embed(ids)
+        if is_torch_npu_available():
+            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+        else:
+            image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
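For reference, a minimal sketch of the device-fallback pattern this hunk introduces, written as a standalone helper. `compute_rotary_emb` is a hypothetical name (not part of diffusers), the `ImportError` shim for `is_torch_npu_available` is an assumption for environments where the helper is missing, and `pos_embed` is assumed to behave like `FluxPosEmbed`, returning a `(freqs_cos, freqs_sin)` tuple for a tensor of position ids.

```python
import torch

try:
    from diffusers.utils import is_torch_npu_available
except ImportError:
    # Fallback shim (assumption): detect torch_npu via the torch.npu namespace.
    def is_torch_npu_available() -> bool:
        return hasattr(torch, "npu") and torch.npu.is_available()


def compute_rotary_emb(pos_embed, ids: torch.Tensor):
    """Return (freqs_cos, freqs_sin) from a FluxPosEmbed-like module.

    On Ascend NPU, compute the frequencies on CPU and move the results to the
    NPU, mirroring the change in this commit; otherwise call pos_embed directly.
    """
    if is_torch_npu_available():
        freqs_cos, freqs_sin = pos_embed(ids.cpu())
        return freqs_cos.npu(), freqs_sin.npu()
    return pos_embed(ids)
```

The branch only changes where the frequencies are computed; the tensors handed to attention end up on the NPU in both paths, so downstream code is unaffected.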