Unverified Commit 94a0c644 authored by Sayak Paul, committed by GitHub

add: a warning message when using xformers in a PT 2.0 env. (#3365)



* add: a warning message when using xformers in a PT 2.0 env.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 26832aa5
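
A minimal sketch of how the new behavior surfaces to users of the Attention class in diffusers' attention_processor module (assumes torch >= 2.0, xformers, and a CUDA GPU are available; the checkpoint id is only illustrative):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# With this commit, requesting xFormers while PyTorch 2.0 is installed emits the
# warning added below in Attention.set_use_memory_efficient_attention_xformers.
pipe.enable_xformers_memory_efficient_attention()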
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import Callable, Optional, Union
 
 import torch
@@ -72,7 +73,8 @@ class Attention(nn.Module):
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
 
-        self.scale = dim_head**-0.5 if scale_qk else 1.0
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
 
         self.heads = heads
         # for slice_size > 0 the attention score computation
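
The refactor above keeps scale_qk on the module and derives the softmax scaling from it; a standalone illustration with the common dim_head=64 default:

dim_head = 64    # head dimension; 64 is a common default for Stable Diffusion UNets
scale_qk = True  # the constructor default; False leaves queries/keys unscaled
scale = dim_head**-0.5 if scale_qk else 1.0
print(scale)     # 0.125, i.e. 1/sqrt(dim_head)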
@@ -140,7 +142,7 @@ class Attention(nn.Module):
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
             processor = (
-                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
             )
         self.set_processor(processor)
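
The gate above can be reproduced standalone; a sketch of how the default processor is chosen (class names as exported by diffusers.models.attention_processor):

import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

scale_qk = True  # illustrative; mirrors the constructor argument
# AttnProcessor2_0 wraps F.scaled_dot_product_attention and is only selected on
# torch >= 2.0 when the default query/key scaling is in effect.
processor = (
    AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
)
print(type(processor).__name__)  # "AttnProcessor2_0" on torch >= 2.0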
@@ -176,6 +178,11 @@ class Attention(nn.Module):
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
+            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
+                warnings.warn(
+                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
+                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
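
A sketch that exercises the new branch directly on a bare Attention block; it assumes torch >= 2.0, xformers, and CUDA are available (otherwise one of the earlier checks raises first), and the layer sizes are illustrative:

import warnings
from diffusers.models.attention_processor import Attention

attn = Attention(query_dim=320, heads=8, dim_head=64)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    attn.set_use_memory_efficient_attention_xformers(True)
print([str(w.message) for w in caught])  # should include the PyTorch 2.0 notice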
@@ -229,7 +236,15 @@ class Attention(nn.Module):
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
-                processor = AttnProcessor()
+                # set attention processor
+                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                processor = (
+                    AttnProcessor2_0()
+                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                    else AttnProcessor()
+                )
 
         self.set_processor(processor)
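
The replacement above also changes what happens when xFormers is turned back off: on torch >= 2.0 the block returns to AttnProcessor2_0 instead of the vanilla AttnProcessor. Continuing the pipeline sketch from the top (pipe and the attn_processors inspection are illustrative):

pipe.disable_xformers_memory_efficient_attention()
print({type(p).__name__ for p in pipe.unet.attn_processors.values()})
# expected to show "AttnProcessor2_0" on torch >= 2.0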
@@ -244,7 +259,13 @@ class Attention(nn.Module):
         elif self.added_kv_proj_dim is not None:
             processor = AttnAddedKVProcessor()
         else:
-            processor = AttnProcessor()
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
 
         self.set_processor(processor)
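
set_attention_slice gets the same fallback: clearing the slice size on torch >= 2.0 now restores AttnProcessor2_0 rather than AttnProcessor. Continuing the same illustrative session:

pipe.enable_attention_slicing()   # installs a sliced attention processor
pipe.disable_attention_slicing()  # goes through set_attention_slice(None), which on
                                  # torch >= 2.0 now restores AttnProcessor2_0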