Commit 96709588 authored by Gustaf Ahdritz

Add low-memory attention (still needs to be incorporated)

parent c4d9f57f
@@ -14,7 +14,7 @@
 # limitations under the License.
 import math

-from typing import Optional, Callable, List
+from typing import Optional, Callable, List, Tuple, Sequence

 import numpy as np
 import torch
@@ -24,6 +24,7 @@ from scipy.stats import truncnorm
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
+    _chunk_slice,
 )
@@ -217,7 +218,7 @@ class Attention(nn.Module):
             self.c_hidden * self.no_heads, self.c_q, init="final"
         )

-        if self.gating is not None:
+        if self.gating:
             self.linear_g = Linear(
                 self.c_q, self.c_hidden * self.no_heads, init="gating"
             )
@@ -370,3 +371,176 @@ class GlobalAttention(nn.Module):
         m = self.linear_o(o)

         return m
@torch.jit.script
def _lma(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
    q_chunk_size: int,
    kv_chunk_size: int
):
    no_q, no_kv = q.shape[-3], k.shape[-3]

    # [*, Q, H, C_hidden]
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
        q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
        big_bias_chunks = [
            b[..., q_s: q_s + q_chunk_size, :] for b in biases
        ]

        maxes = []
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
            small_bias_chunks = [
                b[..., kv_s: kv_s + kv_chunk_size] for b in big_bias_chunks
            ]

            # [*, H, Q_chunk, K_chunk]
            a = torch.einsum(
                "...qhd,...khd->...hqk", q_chunk, k_chunk
            )

            for b in small_bias_chunks:
                a += b

            # [*, Q_chunk, H, K_chunk]
            a = a.transpose(-2, -3)

            # Per-chunk softmax statistics: running max, unnormalized
            # weights, and exp-weighted values
            max_a = torch.max(a, dim=-1, keepdim=True)[0].detach()
            exp_a = torch.exp(a - max_a)
            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)

            maxes.append(max_a.squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
            values.append(exp_v)

        chunk_max = torch.stack(maxes, dim=-3)
        chunk_weights = torch.stack(weights, dim=-3)
        chunk_values = torch.stack(values, dim=-4)

        # Rescale each chunk's statistics to a shared global max so they
        # can be summed into a single softmax
        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
        max_diffs = torch.exp(chunk_max - global_max)
        chunk_values *= max_diffs.unsqueeze(-1)
        chunk_weights *= max_diffs

        all_values = torch.sum(chunk_values, dim=-4)
        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

        q_chunk_out = all_values / all_weights

        o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out

    return o
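Within each query chunk, only a [*, H, Q_chunk, K_chunk] tile of the attention matrix ever exists in memory; the per-tile softmax statistics merge correctly because exp(a - m_global) = exp(a - m_chunk) * exp(m_chunk - m_global), which is exactly what the max_diffs rescaling computes. A minimal CPU sanity check against naive attention (a sketch, not part of the commit; shapes and sizes are illustrative and follow the [*, N, H, C_hidden] convention above):

import torch

B, H, N, C = 2, 4, 64, 8
q = torch.rand(B, N, H, C)
k = torch.rand(B, N, H, C)
v = torch.rand(B, N, H, C)
bias = torch.rand(B, H, N, N)

# Naive reference: materialize the full [*, H, Q, K] attention matrix
a = torch.einsum("...qhd,...khd->...hqk", q, k) + bias
ref = torch.einsum("...hqk,...khd->...qhd", torch.softmax(a, dim=-1), v)

out = _lma(q, k, v, [bias], 16, 16)
assert torch.allclose(out, ref, atol=1e-5)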
class LowMemoryAttention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors. Implements Rabe and
    Staats' low-memory self-attention algorithm.
    """
    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super().__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        self.linear_q = Linear(
            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_v = Linear(
            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_o = Linear(
            self.c_hidden * self.no_heads, self.c_q, init="final"
        )

        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )

        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)
    def forward(self,
        q_x: torch.Tensor,
        k_x: torch.Tensor,
        v_x: torch.Tensor,
        q_chunk_size: int,
        kv_chunk_size: int,
        biases: Optional[List[torch.Tensor]] = None,
    ):
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            k_x:
                [*, K, C_k] key data
            v_x:
                [*, V, C_v] value data
            q_chunk_size:
                Query chunk size. Trades memory for parallelization; a
                low value corresponds to lower memory usage
            kv_chunk_size:
                Key/value chunk size
            biases:
                Bias tensors broadcastable to [*, H, Q, K]
        """
        if biases is None:
            biases = []
        else:
            # Expand the bias Q/K dims in full so they can be sliced
            # chunkwise inside _lma
            biases = [
                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
                for b in biases
            ]

        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(k_x)
        v = self.linear_v(v_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        q = q / math.sqrt(q.shape[-1])

        o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)

        if self.gating:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o
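Usage mirrors Attention, except that the chunk sizes are supplied per call rather than at construction, so the memory/speed trade-off can be tuned at inference time. A hypothetical example (all dimensions are illustrative, not from the commit):

import torch

lma = LowMemoryAttention(c_q=64, c_k=64, c_v=64, c_hidden=16, no_heads=4)
x = torch.rand(1, 2048, 64)
# Bias broadcastable to [*, H, Q, K]; it is expanded inside forward
bias = torch.rand(1, 4, 1, 2048)
out = lma(x, x, x, q_chunk_size=256, kv_chunk_size=1024, biases=[bias])
assert out.shape == x.shape  # [1, 2048, 64]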
@@ -124,6 +124,7 @@ def _fetch_dims(tree):
     return shapes


+@torch.jit.ignore
 def _flat_idx_to_idx(
     flat_idx: int,
     dims: Tuple[int],
@@ -135,6 +136,8 @@ def _flat_idx_to_idx(
     return tuple(reversed(idx))


+@torch.jit.ignore
 def _get_minimal_slice_set(
     start: Sequence[int],
     end: Sequence[int],
@@ -252,18 +255,19 @@ def _get_minimal_slice_set(
     return [tuple(s) for s in slices]


+@torch.jit.ignore
 def _chunk_slice(
     t: torch.Tensor,
     flat_start: int,
     flat_end: int,
     no_batch_dims: int,
-):
+) -> torch.Tensor:
     """
     Equivalent to

         t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

-    but without the need for the reshape call, which can be
+    but without the need for the initial reshape call, which can be
     memory-intensive in certain situations. The only reshape operations
     in this function are performed on sub-tensors that scale with
     (flat_end - flat_start), the chunk size.
@@ -281,7 +285,6 @@ def _chunk_slice(
         batch_dims,
     )

-    #
     sliced_tensors = [t[s] for s in slices]

     return torch.cat(
@@ -352,7 +355,6 @@ def chunk_layer(
     i = 0
     out = None
     for _ in range(no_chunks):
         # Chunk the input
         if(not low_mem):
@@ -382,7 +384,6 @@ def chunk_layer(
         # Put the chunk in its pre-allocated space
         out_type = type(output_chunk)
         if out_type is dict:
             def assign(d1, d2):
                 for k, v in d1.items():
                     if type(v) is dict:
...
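The equivalence documented in _chunk_slice's docstring can be checked directly. A small sketch using the private helper as defined in this commit (the tensor shape is arbitrary):

import torch
from openfold.utils.tensor_utils import _chunk_slice

t = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
no_batch_dims = 2  # treat the leading (2, 3) dims as one flat batch dim of size 6
flat_start, flat_end = 1, 5

ref = t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
chunk = _chunk_slice(t, flat_start, flat_end, no_batch_dims)
assert torch.equal(chunk, ref)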
 #!/bin/bash

-#CUDA_VISIBLE_DEVICES="5"
+CUDA_VISIBLE_DEVICES="0"
 python3 -m unittest "$@" || \
     echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
 import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "4,"
 import importlib
 import pkgutil
 import sys
...
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from openfold.model.primitives import (
    Attention,
    LowMemoryAttention,
)
from tests.config import consts
class TestLMA(unittest.TestCase):
    def test_lma_vs_attention(self):
        batch_size = consts.batch_size
        c_hidden = 32
        n = 2**12
        no_heads = 4

        q = torch.rand(batch_size, n, c_hidden).cuda()
        k = torch.rand(batch_size, n, c_hidden).cuda()
        v = torch.rand(batch_size, n, c_hidden).cuda()

        bias = [torch.rand(no_heads, 1, n)]
        bias = [b.cuda() for b in bias]

        gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
        o_fill = torch.rand(c_hidden, c_hidden * no_heads)

        lma = LowMemoryAttention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()
        a = Attention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()

        with torch.no_grad():
            # Copy the LMA module's parameters into the reference
            # Attention module so both compute the same function
            for name, p in lma.named_parameters():
                attrs = name.split('.')
                param = a
                for attr in attrs:
                    param = getattr(param, attr)
                param.copy_(p)

            # The gating and output layers are zero-initialized, so fill
            # them with random weights to make the comparison nontrivial
            for m in [lma, a]:
                m.linear_g.weight.copy_(gating_fill)
                m.linear_o.weight.copy_(o_fill)

        with torch.no_grad():
            lma_out = lma(q, k, v, 1024, 4096, biases=bias)
            real = a(q, k, v, biases=bias)

        self.assertTrue(torch.max(torch.abs(lma_out - real)) < consts.eps)


if __name__ == "__main__":
    unittest.main()