@@ -1352,7 +1352,9 @@ class XFormersAttnProcessor:
...
@@ -1352,7 +1352,9 @@ class XFormersAttnProcessor:
*args,
*args,
**kwargs,
**kwargs,
)->torch.Tensor:
)->torch.Tensor:
from.embeddingsimportapply_rotary_emb
hy_use_xformers=os.getenv('HY_USE_XFORMERS','0')
ifhy_use_xformers=='1':
from.embeddingsimportapply_rotary_emb
iflen(args)>0orkwargs.get("scale",None)isnotNone:
iflen(args)>0orkwargs.get("scale",None)isnotNone:
deprecation_message="The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecation_message="The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale","1.0.0",deprecation_message)
deprecate("scale","1.0.0",deprecation_message)
...
@@ -1396,24 +1398,25 @@ class XFormersAttnProcessor:
...
@@ -1396,24 +1398,25 @@ class XFormersAttnProcessor: