"examples/pytorch/vscode:/vscode.git/clone" did not exist on "9488bace208911b1bc1efb17774d8000fed1acac"
Unverified commit 425192fe, authored by Patrick von Platen and committed by GitHub

Make sure VAE attention works with Torch 2_0 (#3200)

* Make sure attention works with Torch 2_0

* make style

* Fix more
parent 9965cb50
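In short, the commit makes AttentionBlock (the attention used in the VAE) pick PyTorch 2.0's F.scaled_dot_product_attention whenever it is available and xformers is not enabled, keeping query/key/value in (batch, num_heads, seq_len, head_dim) layout for that path. The following standalone sketch shows the dispatch pattern; it is not the diffusers source — the function name attend and the reduced manual fallback are illustrative only.

# Sketch only: the dispatch pattern the diff below adds, reduced to a free function.
import math

import torch
import torch.nn.functional as F


def attend(query, key, value):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    if hasattr(F, "scaled_dot_product_attention"):
        # Torch 2.0 path: heads stay a separate dimension; output keeps the same layout.
        return F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

    # Pre-2.0 fallback: merge batch and heads, then matmul -> softmax -> matmul by hand.
    b, h, s, d = query.shape
    query, key, value = (t.reshape(b * h, s, d) for t in (query, key, value))
    scale = 1 / math.sqrt(d)  # same value as 1 / sqrt(channels / num_heads) in the module
    attn = torch.baddbmm(
        torch.empty(b * h, s, s, dtype=query.dtype, device=query.device),
        query,
        key.transpose(-1, -2),
        beta=0,
        alpha=scale,
    ).softmax(dim=-1)
    return torch.bmm(attn, value).reshape(b, h, s, d)


q = k = v = torch.randn(2, 1, 64, 32)  # arbitrary demo shapes; the VAE block defaults to one head
print(attend(q, k, v).shape)           # torch.Size([2, 1, 64, 32])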
@@ -60,7 +60,6 @@ class AttentionBlock(nn.Module):
         self.channels = channels
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.num_head_size = num_head_channels
         self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
 
         # define q,k,v as linear layers
@@ -74,18 +73,25 @@ class AttentionBlock(nn.Module):
         self._use_memory_efficient_attention_xformers = False
         self._attention_op = None
 
-    def reshape_heads_to_batch_dim(self, tensor):
+    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if merge_head_and_batch:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
         return tensor
 
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+
+        if unmerge_head_and_batch:
+            batch_size, seq_len, dim = tensor.shape
+            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        else:
+            batch_size, _, seq_len, dim = tensor.shape
+
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
         return tensor
 
     def set_use_memory_efficient_attention_xformers(
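Why the new merge_head_and_batch / unmerge_head_and_batch flags: torch.bmm and xformers expect (batch * heads, seq_len, head_dim), while F.scaled_dot_product_attention expects the heads as their own dimension, (batch, heads, seq_len, head_dim). A quick shape check with made-up sizes (nothing here is taken from the diff):

import torch

batch, heads, seq_len, head_dim = 2, 4, 64, 32
x = torch.randn(batch, seq_len, heads * head_dim)  # output of the q/k/v linear projection

split = x.reshape(batch, seq_len, heads, head_dim).permute(0, 2, 1, 3)
print(split.shape)   # torch.Size([2, 4, 64, 32])  -> layout for scaled_dot_product_attention

merged = split.reshape(batch * heads, seq_len, head_dim)
print(merged.shape)  # torch.Size([8, 64, 32])     -> layout for torch.bmm / xformers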
@@ -134,14 +140,25 @@ class AttentionBlock(nn.Module):
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        use_torch_2_0_attn = (
+            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
+        )
+
+        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
             hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
             )
             hidden_states = hidden_states.to(query_proj.dtype)
+        elif use_torch_2_0_attn:
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
+            )
+            hidden_states = hidden_states.to(query_proj.dtype)
         else:
@@ -162,7 +179,7 @@ class AttentionBlock(nn.Module):
             hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
......
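One subtlety in the hunks above: the xformers call now receives scale=scale explicitly, while the Torch 2.0 branch relies on scaled_dot_product_attention's built-in default of 1 / sqrt(head_dim) (per the TODO, an explicit scale argument is left for Torch 2.1). Because head_dim is channels // num_heads, that default equals the module's own scale, so omitting it changes nothing numerically. A tiny check with made-up sizes:

import math

channels, num_heads = 512, 1                          # illustrative values; the VAE block defaults to one head
head_dim = channels // num_heads

module_scale = 1 / math.sqrt(channels / num_heads)    # what the module computes
sdpa_default = 1 / math.sqrt(head_dim)                # what scaled_dot_product_attention uses by default

assert math.isclose(module_scale, sdpa_default)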
@@ -319,6 +319,40 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
 
+    @parameterized.expand([13, 16, 27])
+    @require_torch_gpu
+    def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
+        model = self.get_sd_vae_model(fp16=True)
+        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
+
+        with torch.no_grad():
+            sample = model.decode(encoding).sample
+
+        model.enable_xformers_memory_efficient_attention()
+
+        with torch.no_grad():
+            sample_2 = model.decode(encoding).sample
+
+        assert list(sample.shape) == [3, 3, 512, 512]
+        assert torch_all_close(sample, sample_2, atol=1e-1)
+
+    @parameterized.expand([13, 16, 37])
+    @require_torch_gpu
+    def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
+        model = self.get_sd_vae_model()
+        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
+
+        with torch.no_grad():
+            sample = model.decode(encoding).sample
+
+        model.enable_xformers_memory_efficient_attention()
+
+        with torch.no_grad():
+            sample_2 = model.decode(encoding).sample
+
+        assert list(sample.shape) == [3, 3, 512, 512]
+        assert torch_all_close(sample, sample_2, atol=1e-2)
+
     @parameterized.expand(
         [
             # fmt: off
......
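Outside the test suite, the two decode paths the new tests compare can be exercised roughly as follows. This is a sketch only: it assumes a CUDA machine with torch>=2.0 and xformers installed, and the checkpoint name is an illustrative choice, not necessarily what get_sd_vae_model loads.

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to("cuda")  # illustrative checkpoint
latents = torch.randn(1, 4, 64, 64, device="cuda")

with torch.no_grad():
    sample_sdpa = vae.decode(latents).sample        # default path: F.scaled_dot_product_attention

vae.enable_xformers_memory_efficient_attention()    # switch the same model to xformers
with torch.no_grad():
    sample_xformers = vae.decode(latents).sample

print(sample_sdpa.shape)                            # torch.Size([1, 3, 512, 512])
print((sample_sdpa - sample_xformers).abs().max())  # expected to be small, as in the tests above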