Unverified Commit 896fb6d8 authored by ちくわぶ's avatar ちくわぶ Committed by GitHub
Browse files

Fix duplicate variable assignments in SD3's JointAttnProcessor (#8516)

* Fix duplicate variable assignments.

* Fix duplicate variable assignments.
parent 7f51f286
...@@ -1132,9 +1132,7 @@ class JointAttnProcessor2_0: ...@@ -1132,9 +1132,7 @@ class JointAttnProcessor2_0:
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = hidden_states = F.scaled_dot_product_attention( hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
query, key, value, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype) hidden_states = hidden_states.to(query.dtype)
...@@ -1406,7 +1404,6 @@ class XFormersAttnProcessor: ...@@ -1406,7 +1404,6 @@ class XFormersAttnProcessor:
class AttnProcessorNPU: class AttnProcessorNPU:
r""" r"""
Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
......
...@@ -24,6 +24,7 @@ python utils/update_metadata.py ...@@ -24,6 +24,7 @@ python utils/update_metadata.py
Script modified from: Script modified from:
https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py
""" """
import argparse import argparse
import os import os
import tempfile import tempfile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment