Unverified Commit 8e4733b3 authored by Ilmari Heikkinen, committed by GitHub

Only test for xformers when enabling them #1773 (#1776)

* only check for xformers when xformers are enabled

* only test for xformers when enabling them
parent 847daf25
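
Before this change, set_use_memory_efficient_attention_xformers ran the xformers and CUDA checks unconditionally, so even turning the feature *off* failed on machines without xformers installed. The diff below moves those checks behind the use_memory_efficient_attention_xformers flag. A minimal, self-contained sketch of the resulting behavior (a toy class written for illustration, not the diffusers code) looks like this:

```python
import importlib.util


def is_xformers_available() -> bool:
    # Availability probe: True only if the xformers package can be found.
    return importlib.util.find_spec("xformers") is not None


class ToyAttentionBlock:
    def __init__(self):
        self._use_memory_efficient_attention_xformers = False

    def set_use_memory_efficient_attention_xformers(self, use_xformers: bool) -> None:
        # The fix: run the environment check only when the caller is enabling
        # the feature, so disabling never demands that xformers be installed.
        if use_xformers and not is_xformers_available():
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for how to install xformers",
                name="xformers",
            )
        self._use_memory_efficient_attention_xformers = use_xformers


block = ToyAttentionBlock()
block.set_use_memory_efficient_attention_xformers(False)  # fine even without xformers installed
```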
@@ -288,28 +288,29 @@ class AttentionBlock(nn.Module):
         self._use_memory_efficient_attention_xformers = False
 
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        if not is_xformers_available():
-            raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
-                name="xformers",
-            )
-        elif not torch.cuda.is_available():
-            raise ValueError(
-                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
-                " available for GPU "
-            )
-        else:
-            try:
-                # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                )
-            except Exception as e:
-                raise e
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers",
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                    " available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    _ = xformers.ops.memory_efficient_attention(
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                    )
+                except Exception as e:
+                    raise e
         self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
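
The try/except probe in the hunk above confirms that memory-efficient attention can actually execute on the current GPU before the flag is flipped. A standalone version of that probe (a sketch only, assuming torch and xformers are installed and reusing the small tensor shapes from the diff) could look like:

```python
import torch
import xformers.ops


def xformers_attention_works() -> bool:
    # Return True if a tiny memory-efficient attention call succeeds on CUDA.
    if not torch.cuda.is_available():
        return False
    try:
        q = torch.randn((1, 2, 40), device="cuda")
        _ = xformers.ops.memory_efficient_attention(q, q, q)
        return True
    except Exception:
        return False


print(xformers_attention_works())
```

Returning False instead of re-raising is a deliberate difference from the library code above, which propagates the original exception so the caller can see exactly why the probe failed.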
@@ -450,31 +451,32 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim)
 
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        if not is_xformers_available():
-            print("Here is how to install it")
-            raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
-                name="xformers",
-            )
-        elif not torch.cuda.is_available():
-            raise ValueError(
-                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
-                " available for GPU "
-            )
-        else:
-            try:
-                # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                )
-            except Exception as e:
-                raise e
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                print("Here is how to install it")
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers",
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                    " available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    _ = xformers.ops.memory_efficient_attention(
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                    )
+                except Exception as e:
+                    raise e
         self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
         if self.attn2 is not None:
             self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
         # 1. Self-Attention
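
In practice these setters are rarely called directly; they are reached through the pipeline-level toggles, which fan out to every attention module. A hedged usage sketch (assuming diffusers' StableDiffusionPipeline and its enable/disable helpers, which forward to set_use_memory_efficient_attention_xformers; the model id is illustrative):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Enabling still runs the checks from the diff: xformers must be importable
# and a CUDA device must be available for the smoke test.
pipe.enable_xformers_memory_efficient_attention()

# After this commit, disabling no longer requires xformers to be installed.
pipe.disable_xformers_memory_efficient_attention()
```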