Commit 96709588 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add low-memory attention (still needs to be incorporated)

parent c4d9f57f
......@@ -14,7 +14,7 @@
# limitations under the License.
import math
from typing import Optional, Callable, List
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import torch
......@@ -24,6 +24,7 @@ from scipy.stats import truncnorm
from openfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
_chunk_slice,
)
......@@ -217,7 +218,7 @@ class Attention(nn.Module):
self.c_hidden * self.no_heads, self.c_q, init="final"
)
if self.gating is not None:
if self.gating:
self.linear_g = Linear(
self.c_q, self.c_hidden * self.no_heads, init="gating"
)
......@@ -370,3 +371,176 @@ class GlobalAttention(nn.Module):
m = self.linear_o(o)
return m
@torch.jit.script
def _lma(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
q_chunk_size: int,
kv_chunk_size: int
):
no_q, no_kv = q.shape[-3], k.shape[-3]
# [*, Q, H, C_hidden]
o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
big_bias_chunks = [
b[..., q_s: q_s + q_chunk_size, :] for b in biases
]
maxes = []
weights = []
values = []
for kv_s in range(0, no_kv, kv_chunk_size):
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
small_bias_chunks = [
b[..., kv_s: kv_s + kv_chunk_size] for b in big_bias_chunks
]
a = torch.einsum(
"...qhd,...khd->...hqk", q_chunk, k_chunk
)
for b in small_bias_chunks:
a += b
a = a.transpose(-2, -3)
max_a = torch.max(a, dim=-1, keepdim=True)[0].detach()
exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
maxes.append(max_a.squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
values.append(exp_v)
chunk_max = torch.stack(maxes, dim=-3)
chunk_weights = torch.stack(weights, dim=-3)
chunk_values = torch.stack(values, dim=-4)
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= max_diffs.unsqueeze(-1)
chunk_weights *= max_diffs
all_values = torch.sum(chunk_values, dim=-4)
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
q_chunk_out = all_values / all_weights
o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
return o
class LowMemoryAttention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization. Allows multiple bias vectors. Implements Rabe and Staats'
low-memory self-attention algorithm.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
chunk_size:
Trades memory for better parallelization. A low value
corresponds to lower memory usage.
"""
super().__init__()
self.c_q = c_q
self.c_k = c_k
self.c_v = c_v
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating
self.linear_q = Linear(
self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_k = Linear(
self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_v = Linear(
self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_o = Linear(
self.c_hidden * self.no_heads, self.c_q, init="final"
)
if self.gating:
self.linear_g = Linear(
self.c_q, self.c_hidden * self.no_heads, init="gating"
)
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
q_chunk_size: int,
kv_chunk_size: int,
biases: Optional[List[torch.Tensor]] = None,
):
if(biases is None):
biases = []
else:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
for b in biases
]
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
v = self.linear_v(v_x)
# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
q = q / math.sqrt(q.shape[-1])
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
if self.gating:
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
return o
......@@ -124,6 +124,7 @@ def _fetch_dims(tree):
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
......@@ -135,6 +136,8 @@ def _flat_idx_to_idx(
return tuple(reversed(idx))
@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
......@@ -252,18 +255,19 @@ def _get_minimal_slice_set(
return [tuple(s) for s in slices]
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
):
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the reshape call, which can be
but without the need for the initial reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
......@@ -281,7 +285,6 @@ def _chunk_slice(
batch_dims,
)
#
sliced_tensors = [t[s] for s in slices]
return torch.cat(
......@@ -352,7 +355,6 @@ def chunk_layer(
i = 0
out = None
for _ in range(no_chunks):
# Chunk the input
if(not low_mem):
......@@ -382,7 +384,6 @@ def chunk_layer(
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
......
#!/bin/bash
#CUDA_VISIBLE_DEVICES="5"
CUDA_VISIBLE_DEVICES="0"
python3 -m unittest "$@" || \
echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,"
import importlib
import pkgutil
import sys
......
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from openfold.model.primitives import (
Attention,
LowMemoryAttention,
)
from tests.config import consts
class TestLMA(unittest.TestCase):
def test_lma_vs_attention(self):
batch_size = consts.batch_size
c_hidden = 32
n = 2**12
no_heads = 4
q = torch.rand(batch_size, n, c_hidden).cuda()
k = torch.rand(batch_size, n, c_hidden).cuda()
v = torch.rand(batch_size, n, c_hidden).cuda()
bias = [torch.rand(no_heads, 1, n)]
bias = [b.cuda() for b in bias]
gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads)
lma = LowMemoryAttention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
with torch.no_grad():
for n, p in lma.named_parameters():
attrs = n.split('.')
param = a
for attr in attrs:
param = getattr(param, attr)
param.copy_(p)
for m in [lma, a]:
m.linear_g.weight.copy_(gating_fill)
m.linear_o.weight.copy_(o_fill)
with torch.no_grad():
l = lma(q, k, v, 1024, 4096, biases=bias)
real = a(q, k, v, biases=bias)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment