Unverified commit 2bbf8b67 authored by Suraj Patil, committed by GitHub

simplify AttentionBlock (#1492)

parent 5a5bf7ef
@@ -290,11 +290,19 @@ class AttentionBlock(nn.Module):
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)
 
-    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
-        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
-        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
-        return new_projection
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
 
     def forward(self, hidden_states):
         residual = hidden_states
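For orientation, here is a minimal standalone sketch (outside the diff, with made-up sizes: batch 2, 16 tokens, 8 channels, 4 heads) of what the two new helpers above do to tensor shapes. They fold the head dimension into the batch dimension and back, and the second is the exact inverse of the first; the hunk below relies on this folding to replace the 4-D per-head matmul branch with plain 3-D batched matmuls.

```python
# Standalone sketch of the reshape helpers added above; sizes are illustrative only.
import torch

num_heads = 4


def reshape_heads_to_batch_dim(tensor, head_size=num_heads):
    # (B, T, H * D) -> (B * H, T, D)
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)


def reshape_batch_dim_to_heads(tensor, head_size=num_heads):
    # (B * H, T, D) -> (B, T, H * D)
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)


x = torch.randn(2, 16, 8)                                  # (batch, seq_len, channels)
folded = reshape_heads_to_batch_dim(x)
print(folded.shape)                                        # torch.Size([8, 16, 2])
print(torch.equal(reshape_batch_dim_to_heads(folded), x))  # True: exact round trip
```

With the heads folded into the batch dimension, every attention matmul in forward operates on 3-D tensors, which is why the `if self.num_heads > 1` branches disappear in the next hunk.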
@@ -312,50 +320,28 @@ class AttentionBlock(nn.Module):
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        # get scores
-        if self.num_heads > 1:
-            query_states = self.transpose_for_scores(query_proj)
-            key_states = self.transpose_for_scores(key_proj)
-            value_states = self.transpose_for_scores(value_proj)
-
-            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
-        else:
-            query_states, key_states, value_states = query_proj, key_proj, value_proj
-
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_states.shape[0],
-                    query_states.shape[1],
-                    key_states.shape[1],
-                    dtype=query_states.dtype,
-                    device=query_states.device,
-                ),
-                query_states,
-                key_states.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
+        query_proj = self.reshape_heads_to_batch_dim(query_proj)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+
+        attention_scores = torch.baddbmm(
+            torch.empty(
+                query_proj.shape[0],
+                query_proj.shape[1],
+                key_proj.shape[1],
+                dtype=query_proj.dtype,
+                device=query_proj.device,
+            ),
+            query_proj,
+            key_proj.transpose(-1, -2),
+            beta=0,
+            alpha=scale,
+        )
 
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-
-        # compute attention output
-        if self.num_heads > 1:
-            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            hidden_states = torch.matmul(attention_probs, value_states)
-            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-            hidden_states = hidden_states.view(new_hidden_states_shape)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
+        hidden_states = torch.bmm(attention_probs, value_proj)
+
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
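As a sanity check on the second hunk, a small sketch (again outside the diff, with made-up sizes) showing that `torch.baddbmm` with `beta=0` and `alpha=scale` over the head-folded 3-D tensors reproduces the old `matmul(q, k^T) * scale` scores, so collapsing the two branches does not change the math.

```python
# Numerical check that the baddbmm formulation matches the old scaled matmul;
# batch_heads / seq_len / head_dim are illustrative sizes only.
import math

import torch

batch_heads, seq_len, head_dim = 8, 16, 2      # (batch * num_heads, tokens, channels // num_heads)
scale = 1 / math.sqrt(head_dim)

query = torch.randn(batch_heads, seq_len, head_dim)
key = torch.randn(batch_heads, seq_len, head_dim)
value = torch.randn(batch_heads, seq_len, head_dim)

# New path: beta=0 tells baddbmm to ignore the (uninitialized) input tensor,
# and alpha applies the 1/sqrt(d) scaling inside the fused batched matmul.
scores_new = torch.baddbmm(
    torch.empty(batch_heads, seq_len, seq_len, dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

# Old path (multi-head branch): plain matmul followed by explicit scaling.
scores_old = torch.matmul(query, key.transpose(-1, -2)) * scale

probs = torch.softmax(scores_new.float(), dim=-1).type(scores_new.dtype)
out = torch.bmm(probs, value)                  # (batch * num_heads, tokens, head_dim)

print(torch.allclose(scores_new, scores_old, atol=1e-6))  # True
print(out.shape)                                          # torch.Size([8, 16, 2])
```

The only remaining step in the new code is `reshape_batch_dim_to_heads`, which folds the heads back to `(batch, tokens, channels)` before `proj_attn`, replacing the old inline permute/view in the multi-head branch.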