Commit a9ae9148 authored by lijian6

Modify env variable names


Signed-off-by: lijian <lijian6@sugon.com>
parent 29c06f85
@@ -170,8 +170,8 @@ class JointTransformerBlock(nn.Module):
             )
         # Attention.
-        use_xformers = os.getenv('USE_XFORMERS', '0')
-        if use_xformers == '1':
+        sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
+        if sd3_use_xformers == '1':
             self.attn.set_processor(self.processor)
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
...
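The gate above re-reads the environment on each forward pass and swaps the attention processor in place. A minimal sketch of the same pattern, assuming a diffusers-style Attention module (the helper name `maybe_enable_xformers_processor` is mine, not from this repo):

```python
import os


def maybe_enable_xformers_processor(attn, processor, env_var='SD3_USE_XFORMERS'):
    """Swap in the xformers-backed processor only when the toggle is '1'.

    `attn` is assumed to expose diffusers' Attention.set_processor();
    `processor` stands in for the self.processor the block constructed.
    """
    if os.getenv(env_var, '0') == '1':
        attn.set_processor(processor)
```

Exporting the variable at launch (`SD3_USE_XFORMERS=1 python your_script.py`) flips every gated block at once, with no constructor changes.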
@@ -1129,8 +1129,8 @@ class JointAttnProcessor2_0:
         key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
         value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-        use_xformers = os.getenv('USE_XFORMERS', '0')
-        if use_xformers == '1':
+        sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
+        if sd3_use_xformers == '1':
             query = attn.head_to_batch_dim(query).contiguous()
             key = attn.head_to_batch_dim(key).contiguous()
             value = attn.head_to_batch_dim(value).contiguous()
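The `.contiguous()` calls are the load-bearing part of this hunk: `head_to_batch_dim` folds the head axis into the batch and returns a reshaped view, while xformers wants contiguous `(batch * heads, seq_len, head_dim)` tensors. A standalone sketch of the downstream call, assuming xformers is installed and a CUDA device is available (sizes are arbitrary):

```python
import torch
import xformers.ops

# Shapes mirror the gated branch: (batch * heads, seq_len, head_dim).
batch, heads, seq_len, head_dim = 2, 8, 77, 64
q = torch.randn(batch * heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn(batch * heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
v = torch.randn(batch * heads, seq_len, head_dim, device='cuda', dtype=torch.float16)

# xformers requires contiguous inputs, hence the .contiguous() calls above.
out = xformers.ops.memory_efficient_attention(q.contiguous(), k.contiguous(), v.contiguous())
print(out.shape)  # torch.Size([16, 77, 64])
```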
@@ -1352,6 +1352,8 @@ class XFormersAttnProcessor:
         *args,
         **kwargs,
     ) -> torch.Tensor:
+        hy_use_xformers = os.getenv('HY_USE_XFORMERS', '0')
+        if hy_use_xformers == '1':
             from .embeddings import apply_rotary_emb
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
@@ -1396,6 +1398,7 @@ class XFormersAttnProcessor:
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+        if hy_use_xformers == '1':
             ### Add for HunYuanDit xformers ###
             q_seq_len = query.shape[1]
             k_seq_len = key.shape[1]
...
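Both hunks in this file keep the HunYuanDiT path strictly opt-in: the relative import of `apply_rotary_emb` and the sequence-length bookkeeping only run under `HY_USE_XFORMERS=1`, so the stock `XFormersAttnProcessor` behavior is untouched by default. A hedged sketch of the gated rotary step (the helper name, and the assumption that `image_rotary_emb` is the `(cos, sin)` pair HunYuanDiT forwards, are mine):

```python
import os


def maybe_apply_rotary(query, key, image_rotary_emb):
    """Apply rotary position embeddings only on the opt-in path."""
    if os.getenv('HY_USE_XFORMERS', '0') == '1' and image_rotary_emb is not None:
        # Same helper the diff imports relatively as `.embeddings`.
        from diffusers.models.embeddings import apply_rotary_emb
        query = apply_rotary_emb(query, image_rotary_emb)
        key = apply_rotary_emb(key, image_rotary_emb)
    return query, key
```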
@@ -868,8 +868,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
-        use_xformers = os.getenv('USE_XFORMERS', '0')
-        if use_xformers == '1':
+        sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
+        if sd3_use_xformers == '1':
             self.disable_xformers_memory_efficient_attention()
         if output_type == "latent":
             image = latents
...
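Here the pipeline switches the stock hook off because, with `SD3_USE_XFORMERS=1`, the custom processor already routes attention through xformers; leaving `enable_xformers_memory_efficient_attention` active as well would stack the two mechanisms. An end-to-end usage sketch (model id and prompt are illustrative; the toggle is this fork's, not a stock diffusers option):

```python
import os

import torch
from diffusers import StableDiffusion3Pipeline

# Opt in before inference; the gated branches read the variable at call time.
os.environ['SD3_USE_XFORMERS'] = '1'

pipe = StableDiffusion3Pipeline.from_pretrained(
    'stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16
).to('cuda')
image = pipe('a photo of an astronaut riding a horse').images[0]
```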