Commit a9ae9148 authored by lijian6's avatar lijian6
Browse files

Modify env variable


Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent 29c06f85
......@@ -170,8 +170,8 @@ class JointTransformerBlock(nn.Module):
)
# Attention.
use_xformers = os.getenv('USE_XFORMERS', '0')
if use_xformers == '1':
sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
if sd3_use_xformers == '1':
self.attn.set_processor(self.processor)
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
......
......@@ -1129,8 +1129,8 @@ class JointAttnProcessor2_0:
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
use_xformers = os.getenv('USE_XFORMERS', '0')
if use_xformers == '1':
sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
if sd3_use_xformers == '1':
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
......@@ -1352,6 +1352,8 @@ class XFormersAttnProcessor:
*args,
**kwargs,
) -> torch.Tensor:
hy_use_xformers = os.getenv('HY_USE_XFORMERS', '0')
if hy_use_xformers == '1':
from .embeddings import apply_rotary_emb
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
......@@ -1396,6 +1398,7 @@ class XFormersAttnProcessor:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if hy_use_xformers == '1':
### Add for HunYuanDit xformers ###
q_seq_len = query.shape[1]
k_seq_len = key.shape[1]
......
......@@ -868,8 +868,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
if XLA_AVAILABLE:
xm.mark_step()
use_xformers = os.getenv('USE_XFORMERS', '0')
if use_xformers == '1':
sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
if sd3_use_xformers == '1':
self.disable_xformers_memory_efficient_attention()
if output_type == "latent":
image = latents
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment