Unverified Commit 896fb6d8 authored by ちくわぶ, committed by GitHub

Fix duplicate variable assignments in SD3's JointAttnProcessor (#8516)

* Fix duplicate variable assignments.

* Fix duplicate variable assignments.
parent 7f51f286
@@ -1132,9 +1132,7 @@ class JointAttnProcessor2_0:
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-       hidden_states = hidden_states = F.scaled_dot_product_attention(
-           query, key, value, dropout_p=0.0, is_causal=False
-       )
+       hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
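
For context on the change above: `hidden_states = hidden_states = F.scaled_dot_product_attention(...)` is a chained assignment, so Python evaluated the right-hand side once and bound it to the same name twice; the duplicate target was harmless but confusing, and the fix collapses it into a single assignment. Below is a minimal, self-contained sketch of the same multi-head attention pattern; the tensor names mirror the diff, but the shapes and sizes are toy values for illustration, not taken from SD3:

```python
import torch
import torch.nn.functional as F

# Toy dimensions, for illustration only.
batch_size, seq_len, heads, head_dim = 2, 8, 4, 16
inner_dim = heads * head_dim

query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# Split channels into heads: (batch, heads, seq, head_dim).
query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# Single assignment, matching the fixed line in the diff.
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# Merge heads back: (batch, seq, heads * head_dim).
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
hidden_states = hidden_states.to(query.dtype)
```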
@@ -1406,7 +1404,6 @@ class XFormersAttnProcessor:
class AttnProcessorNPU:
    r"""
    Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
    fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
    not significant.
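
The AttnProcessorNPU docstring above describes a dtype-based dispatch: the fused NPU kernel accepts only fp16/bf16, and fp32 inputs fall through to plain F.scaled_dot_product_attention. A rough sketch of that dispatch shape follows; this is not the diffusers implementation, and the NPU branch is left as a placeholder because the torch_npu call and its arguments are not shown in this diff:

```python
import torch
import torch.nn.functional as F

def attention_with_npu_fallback(query, key, value):
    # Sketch of the dispatch described in the docstring. fp16/bf16
    # would go to the fused NPU kernel; everything else uses the
    # standard PyTorch path.
    if query.dtype in (torch.float16, torch.bfloat16):
        # Placeholder: on an NPU build, the torch_npu fused
        # flash-attention kernel would be invoked here.
        raise NotImplementedError("torch_npu fused attention not available in this sketch")
    # fp32 path: plain scaled dot-product attention (no NPU speedup).
    return F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
```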
@@ -24,6 +24,7 @@ python utils/update_metadata.py
Script modified from:
https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py
"""
import argparse
import os
import tempfile
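
The hunk header above records one added line in utils/update_metadata.py; given the surrounding context, it is presumably the argparse import. As a purely hypothetical illustration of the kind of CLI entry point such a script might gain (the flag name below is invented for this sketch and is not taken from the actual script):

```python
import argparse

# Hypothetical sketch only: the flag is illustrative and is not the
# real interface of utils/update_metadata.py.
def parse_args():
    parser = argparse.ArgumentParser(description="Update pipeline metadata.")
    parser.add_argument("--commit_sha", type=str, default=None,
                        help="Commit SHA to tag the update with (illustrative).")
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    print(f"Updating metadata (commit_sha={args.commit_sha})")
```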