@@ -1352,6 +1352,8 @@ class XFormersAttnProcessor:
...
@@ -1352,6 +1352,8 @@ class XFormersAttnProcessor:
*args,
*args,
**kwargs,
**kwargs,
)->torch.Tensor:
)->torch.Tensor:
hy_use_xformers=os.getenv('HY_USE_XFORMERS','0')
ifhy_use_xformers=='1':
from.embeddingsimportapply_rotary_emb
from.embeddingsimportapply_rotary_emb
iflen(args)>0orkwargs.get("scale",None)isnotNone:
iflen(args)>0orkwargs.get("scale",None)isnotNone:
deprecation_message="The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecation_message="The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
...
@@ -1396,6 +1398,7 @@ class XFormersAttnProcessor:
...
@@ -1396,6 +1398,7 @@ class XFormersAttnProcessor: