Unverified commit 3429e62d authored by oahzxl, committed by GitHub

Optimize long sequence inference memory (#69)

* update ops

* finish on 1 gpu

* fix bug in incoming mul

* update async broadcast

* update incoming

* add attention chunk

* finish embed

* align evoformer

* update embed

* align evoformer

* update embed

* update template

* update template

* update extramsa

* fix a bug

* update outerproductmean

* remove useless class

* fix a bug when chunk is None

* fix bug when chunk is None
parent 6835c248
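Most of the memory savings below come from two patterns: large pair/MSA activations are processed in chunks along one residue dimension so that only a slice of the intermediates is alive at a time, and tensor-parallel communication is overlapped with computation via async broadcasts. A rough stand-alone sketch of the chunking pattern (module, x, and chunk_size are placeholders, not code from this commit):

    import torch

    def forward_in_chunks(module, x, chunk_size):
        # Valid only when `module` treats positions along dim 1 independently.
        out = torch.empty_like(x)
        for start in range(0, x.shape[1], chunk_size):
            out[:, start:start + chunk_size] = module(x[:, start:start + chunk_size])
        return out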
@@ -11,6 +11,58 @@ from colossalai.core import global_context as gpc
from .comm import _split, divide
def broadcast_sync(src: int, tensor: Tensor, host: bool = False) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return 0
if host:
dist.broadcast(tensor,
src=src,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
return 0
else:
output = torch.empty(list(tensor.shape), dtype=tensor.dtype, device=tensor.device)
dist.broadcast(output,
src=src,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
return output
def broadcast_async(src: int, tensor: Tensor, host: bool = False) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return 0
if host:
work = dist.broadcast(tensor,
src=src,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True)
return work
else:
work = dist.broadcast(tensor,
src=src,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True)
return work
def broadcast_async_opp(work) -> Tensor:
work.wait()
return 0
def get_rank():
return gpc.get_global_rank()
def get_world_size():
return gpc.get_world_size(ParallelMode.TENSOR)
def _gather_async(tensor: Tensor, dim: int = -1) -> Tensor:
    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
        return tensor, None
...
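A minimal usage sketch for the async broadcast pair added above (buf and compute_fn are placeholders; the tensor-parallel process group is assumed to be initialized):

    rank = get_rank()
    work = broadcast_async(src=0, tensor=buf, host=(rank == 0))
    partial = compute_fn()                  # local work overlapped with the broadcast
    if get_world_size() > 1:                # broadcast_async returns 0 when there is only one rank
        broadcast_async_opp(work)           # wait before reading buf on non-source ranks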
@@ -21,9 +21,8 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSAStack, OutProductMean, PairStack, ExtraMSAStack
-from fastfold.model.fastnn.ops import Transition
-from fastfold.model.fastnn.triangle import TriangleAttentionEndingNode, TriangleAttentionStartingNode, \
-    TriangleMultiplicationIncoming, TriangleMultiplicationOutgoing
+from fastfold.model.fastnn.ops import ChunkTransition, ChunkTriangleAttentionStartingNode, ChunkTriangleAttentionEndingNode, \
+    AsyncChunkTriangleMultiplicationOutgoing, AsyncChunkTriangleMultiplicationIncoming
from fastfold.distributed.comm import gather, scatter
from fastfold.distributed.comm import col_to_row, row_to_col, scatter
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
@@ -76,12 +75,12 @@ class EvoformerBlock(nn.Module):
        if not self.is_multimer:
            m = self.msa_stack(m, z, msa_mask)
-           z = z + self.communication(m, msa_mask)
+           z = self.communication(m, msa_mask, z)
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
        else:
-           z = z + self.communication(m, msa_mask)
+           z = self.communication(m, msa_mask, z)
            z_ori = z
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
@@ -158,14 +157,13 @@ class ExtraMSABlock(nn.Module):
        if not self.is_multimer:
            m = self.msa_stack(m, z, msa_mask)
-           z = z + self.communication(m, msa_mask)
+           z = self.communication(m, msa_mask, z)
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
        else:
-           z = z + self.communication(m, msa_mask)
+           z = self.communication(m, msa_mask, z)
            z_ori = z
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
@@ -212,19 +210,19 @@ class TemplatePairStackBlock(nn.Module):
        self.p_drop = dropout_rate
        self.hidden_c = int(c_t / self.n_head)
-       self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(
+       self.TriangleMultiplicationOutgoing = AsyncChunkTriangleMultiplicationOutgoing(
            self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul
        )
-       self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(
+       self.TriangleMultiplicationIncoming = AsyncChunkTriangleMultiplicationIncoming(
            self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul
        )
-       self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(
+       self.TriangleAttentionStartingNode = ChunkTriangleAttentionStartingNode(
            self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head
        )
-       self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(
+       self.TriangleAttentionEndingNode = ChunkTriangleAttentionEndingNode(
            self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head
        )
-       self.PairTransition = Transition(d=self.c_t, n=pair_transition_n)
+       self.PairTransition = ChunkTransition(d=self.c_t, n=pair_transition_n)
    def forward(
        self,
@@ -245,12 +243,11 @@ class TemplatePairStackBlock(nn.Module):
        mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
-       single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
-       single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
-       for i in range(len(single_templates)):
-           single = single_templates[i]
-           single_mask = single_templates_masks[i]
+       # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
+       # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
+       for i in range(z.shape[0]):
+           single = z[i].unsqueeze(-4)
+           single_mask = mask[i].unsqueeze(-3)
            single_mask_row = scatter(single_mask, dim=1)
            single_mask_col = scatter(single_mask, dim=2)
@@ -264,10 +261,9 @@ class TemplatePairStackBlock(nn.Module):
            single = self.TriangleAttentionEndingNode(single, single_mask_col)
            single = self.PairTransition(single)
            single = col_to_row(single)
-           single_templates[i] = single
-       z = torch.cat(single_templates, dim=-4)
+           z[i] = single
+       # z = torch.cat(single_templates, dim=-4)
        if self.last_block:
            z = gather(z, dim=1)
            z = z[:, :-padding_size, :-padding_size, :]
...
@@ -7,7 +7,7 @@ def bias_sigmod_ele(y, bias, z):
    return torch.sigmoid(y + bias) * z
-@torch.jit.script
+# @torch.jit.script
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                     residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
...
@@ -17,9 +17,10 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
-from fastfold.model.fastnn.kernel import LayerNorm
+# from fastfold.model.fastnn.kernel import LayerNorm
+from torch.nn import LayerNorm
-from fastfold.model.fastnn.ops import Transition, SelfAttention, GlobalAttention
+from fastfold.model.fastnn.ops import ChunkMSARowAttentionWithPairBias, ChunkTransition, SelfAttention, GlobalAttention, Transition, ChunkMSAColumnGlobalAttention
from fastfold.model.fastnn.kernel import bias_dropout_add
from fastfold.distributed import scatter, row_to_col
from fastfold.distributed.comm_async import gather_async
@@ -150,12 +151,12 @@ class ExtraMSAStack(nn.Module):
    def __init__(self, d_node, d_pair, p_drop=0.15):
        super(ExtraMSAStack, self).__init__()
-       self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(
+       self.MSARowAttentionWithPairBias = ChunkMSARowAttentionWithPairBias(
            d_node=d_node, d_pair=d_pair, p_drop=p_drop, c=8
        )
-       self.MSAColumnAttention = MSAColumnGlobalAttention(d_node=d_node, c=8)
-       self.MSATransition = Transition(d=d_node)
+       self.MSAColumnAttention = ChunkMSAColumnGlobalAttention(d_node=d_node, c=8)
+       self.MSATransition = ChunkTransition(d=d_node)
    def forward(self, node, pair, node_mask):
        node_mask_row = scatter(node_mask, dim=1)
...
@@ -15,15 +15,18 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from typing import Tuple
from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax
from fastfold.model.fastnn.kernel import LayerNorm
from .initializer import glorot_uniform_af
-from fastfold.model.fastnn.kernel import bias_sigmod_ele
+from fastfold.model.fastnn.kernel import bias_sigmod_ele, bias_ele_dropout_residual, bias_dropout_add
from fastfold.distributed import gather, scatter
-from fastfold.distributed.comm_async import gather_async, gather_async_opp
+from fastfold.distributed.comm_async import gather_async, gather_async_opp, get_world_size, get_rank, broadcast_sync, broadcast_async, broadcast_async_opp
CHUNK_SIZE = None
@@ -33,6 +36,11 @@ def set_chunk_size(chunk_size):
    CHUNK_SIZE = chunk_size
def get_chunk_size():
global CHUNK_SIZE
return CHUNK_SIZE
class DropoutRowwise(nn.Module):
    def __init__(self, p):
@@ -73,6 +81,29 @@ class Transition(nn.Module):
        return src + x
class ChunkTransition(nn.Module):
def __init__(self, d, n=4):
super(ChunkTransition, self).__init__()
self.norm = LayerNorm(d)
self.linear1 = Linear(d, n * d, initializer='relu')
self.linear2 = Linear(n * d, d, initializer='zeros')
def forward(self, src):
para_dim = src.shape[1]
chunk_size = 48
if CHUNK_SIZE == None:
chunk_size = para_dim
out = torch.empty_like(src)
for ax in range(0, para_dim, chunk_size):
x = self.norm(src[:, ax:ax + chunk_size, :, :])
x = self.linear2(F.relu(self.linear1(x)))
out[:, ax:ax + chunk_size, :, :] = x
out.add_(src)
return out
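A hedged usage sketch of ChunkTransition together with set_chunk_size (shapes and values are illustrative; note that ChunkTransition uses a fixed slice width of 48 whenever chunking is enabled, and the fastnn kernels assume a CUDA device):

    import torch
    from fastfold.model.fastnn.ops import ChunkTransition, set_chunk_size

    set_chunk_size(4)                                   # any non-None value enables chunking
    transition = ChunkTransition(d=128).cuda().eval()
    z = torch.randn(1, 256, 256, 128, device="cuda")    # [batch, N_res, N_res, d]
    with torch.no_grad():
        z = transition(z)                               # LayerNorm + MLP per slice along dim 1, residual added at the end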
class OutProductMean(nn.Module):
    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
@@ -87,10 +118,9 @@ class OutProductMean(nn.Module):
                               initializer='zero',
                               use_bias=True)
-   def forward(self, M, M_mask):
+   def forward(self, M, M_mask, Z_raw):
        M = self.layernormM(M)
        right_act = self.linear_b(M)
        right_act_all, work = gather_async(right_act, dim=2)
        # right_act_all = gather(right_act, dim=2)
@@ -98,10 +128,9 @@ class OutProductMean(nn.Module):
        M_mask = M_mask.unsqueeze(-1)
        M_mask_col = scatter(M_mask, dim=2)
        left_act = M_mask_col * left_act
-       norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask)
+       norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask) + 1e-3
        right_act_all = gather_async_opp(right_act_all, work, dim=2)
        right_act_all = M_mask * right_act_all
        para_dim = left_act.shape[2]
@@ -109,21 +138,15 @@ class OutProductMean(nn.Module):
        if CHUNK_SIZE == None:
            chunk_size = para_dim
-       out = []
        for ax in range(0, para_dim, chunk_size):
            left_act_part = left_act[:, :, ax:ax + chunk_size, :]
            O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
            O = rearrange(O, 'b i j d e -> b i j (d e)')
-           out.append(self.o_linear(O))
-       Z = torch.cat(out, dim=1)
-       Z /= (1e-3 + norm)
-       return Z
+           O = self.o_linear(O)
+           norm0 = norm[:, ax:ax + chunk_size, :, :]
+           Z_raw[:, ax:ax + chunk_size, :, :] += O / norm0
+       return Z_raw
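The rewritten forward above folds the chunked outer-product update directly into the residual pair tensor, so no second full-size [*, N_res, N_res, c_z] output has to be allocated. A stand-alone sketch of the same accumulate-in-place pattern (function name and shapes are illustrative, not from this commit):

    import torch

    def chunked_outer_accumulate(left, right, out, chunk=64):
        # left: [B, S, N, d], right: [B, S, N, e], out: [B, N, N, d * e], updated in place
        for ax in range(0, left.shape[2], chunk):
            o = torch.einsum('bsid,bsje->bijde', left[:, :, ax:ax + chunk, :], right)
            out[:, ax:ax + chunk] += o.flatten(start_dim=-2)
        return out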
class Linear(nn.Linear):
@@ -199,6 +222,9 @@ class SelfAttention(nn.Module):
            chunk_size = para_dim
        if nonbatched_bias is not None:
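            # A work handle of -1 marks a bias that the caller has already gathered and
            # rearranged to 'b h q k', so it is used directly instead of being gathered here.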
if nonbatched_bias[-1] == -1:
bias = nonbatched_bias[0]
else:
                # logits += nonbatched_bias.unsqueeze(1)
                bias = gather_async_opp(*nonbatched_bias, dim=1)
                bias = rearrange(bias, 'b q k h -> b h q k')
@@ -235,6 +261,611 @@ class SelfAttention(nn.Module):
        return output
def permute_final_dims(tensor, inds):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
class AsyncChunkTriangleMultiplicationOutgoing(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(AsyncChunkTriangleMultiplicationOutgoing, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_right_projection = Linear(d_pair, 2 * c)
self.left_right_gate = Linear(d_pair, 2 * c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask_row):
if CHUNK_SIZE == None:
Z = self.layernorm1(Z_raw)
left_right_proj_act = self.left_right_projection(Z)
left_right_proj_act = Z_mask_row.unsqueeze(-1) * left_right_proj_act
left_right_proj_act *= torch.sigmoid(self.left_right_gate(Z))
left_proj_act, right_proj_act = left_right_proj_act.chunk(2, dim=-1)
right_proj_act, work = gather_async(right_proj_act.contiguous(), dim=1)
g = torch.sigmoid(self.output_gate(Z))
left_proj_act = permute_final_dims(left_proj_act, (2, 0, 1))
right_proj_act = gather_async_opp(right_proj_act, work, dim=1)
p = torch.matmul(left_proj_act, permute_final_dims(right_proj_act, (2, 1, 0)),)
ab = permute_final_dims(p, (1, 2, 0))
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
para_dim = Z_raw.shape[1]
chunk_size = CHUNK_SIZE * 32
world_size = get_world_size()
rank = get_rank()
output = torch.empty_like(Z_raw)
for i in range(0, para_dim, chunk_size):
zi = Z_raw[:, i:i + chunk_size, :, :]
zi = self.layernorm1(zi)
gi = torch.sigmoid(self.left_right_gate(zi))
i_left_right_proj_act = self.left_right_projection(zi)
i_left_right_proj_act = Z_mask_row[:, i:i + chunk_size, :].unsqueeze(-1) * i_left_right_proj_act
i_left_right_proj_act *= gi
left_proj_act, _ = i_left_right_proj_act.chunk(2, dim=-1)
left_proj_act = permute_final_dims(left_proj_act, (2, 0, 1))
for j in range(0, para_dim, chunk_size):
zj = Z_raw[:, j:j + chunk_size, :, :]
zj = self.layernorm1(zj)
gj = torch.sigmoid(self.left_right_gate(zj))
j_left_right_proj_act = self.left_right_projection(zj)
j_left_right_proj_act = Z_mask_row[:, j:j + chunk_size, :].unsqueeze(-1) * j_left_right_proj_act
j_left_right_proj_act *= gj
_, right_proj_act = j_left_right_proj_act.chunk(2, dim=-1)
right_proj_act = right_proj_act.contiguous()
work = None
right_proj_act_tmp = torch.empty_like(right_proj_act)
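                # Every rank takes a turn (k) as the source of its right_proj_act chunk; the
                # broadcast for rank k + 1 is launched asynchronously before the matmul for
                # rank k and waited on at the top of the next iteration, overlapping
                # communication with compute.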
for k in range(0, world_size):
if world_size > 1:
if work:
broadcast_async_opp(work) # collect last broadcast
if k != rank:
right_proj_act_rec = right_proj_act_tmp.clone()
else: # init first broadcast
if k == rank:
broadcast_sync(k, right_proj_act, host=True)
else:
right_proj_act_tmp = broadcast_sync(k, right_proj_act, host=False)
right_proj_act_rec = right_proj_act_tmp.clone()
if k + 1 != world_size: # launch next broadcast
if k + 1 == rank:
work = broadcast_async(k + 1, right_proj_act, host=True)
else:
work = broadcast_async(k + 1, right_proj_act_tmp, host=False)
if k == rank: # broadcast self right_proj_act
p = torch.matmul(
left_proj_act,
permute_final_dims(right_proj_act, (2, 1, 0)),
)
p = permute_final_dims(p, (1, 2, 0))
j_global = para_dim * k + j
output[:, i:i + chunk_size, j_global:min(j_global + chunk_size, para_dim * (k + 1)), :] = p
else: # receive others broadcast
p = torch.matmul(
left_proj_act,
permute_final_dims(right_proj_act_rec, (2, 1, 0)),
)
p = permute_final_dims(p, (1, 2, 0))
j_global = para_dim * k + j
output[:, i:i + chunk_size, j_global:min(j_global + chunk_size, para_dim * (k + 1)), :] = p
dropout_mask = torch.ones_like(Z_raw[:, 0:1, :, :], device=Z_raw.device, dtype=Z_raw.dtype)
for i in range(0, Z_raw.shape[1], chunk_size):
z_raw = Z_raw[:, i:i + chunk_size, :, :]
g = torch.sigmoid(self.output_gate(self.layernorm1(z_raw)))
z = output[:, i:i + chunk_size, :, :]
z = self.output_projection(self.layernorm2(z))
z = bias_ele_dropout_residual(z,
self.output_bias,
g,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
output[:, i:i + chunk_size, :, :] = z
return output
class AsyncChunkTriangleMultiplicationIncoming(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(AsyncChunkTriangleMultiplicationIncoming, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_right_projection = Linear(d_pair, 2 * c)
self.left_right_gate = Linear(d_pair, 2 * c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask_col):
if CHUNK_SIZE == None:
Z = self.layernorm1(Z_raw)
left_right_proj_act = self.left_right_projection(Z)
left_right_proj_act = Z_mask_col.unsqueeze(-1) * left_right_proj_act
left_right_proj_act *= torch.sigmoid(self.left_right_gate(Z))
left_proj_act, right_proj_act = left_right_proj_act.chunk(2, dim=-1)
left_proj_act, work = gather_async(left_proj_act.contiguous(), dim=2)
g = torch.sigmoid(self.output_gate(Z))
right_proj_act = permute_final_dims(right_proj_act, (2, 0, 1))
left_proj_act = gather_async_opp(left_proj_act, work, dim=2)
p = torch.matmul(permute_final_dims(left_proj_act, (2, 1, 0)), right_proj_act)
ab = permute_final_dims(p, (1, 2, 0))
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
para_dim = Z_raw.shape[2]
chunk_size = CHUNK_SIZE * 32
world_size = get_world_size()
rank = get_rank()
output = torch.empty_like(Z_raw)
for i in range(0, para_dim, chunk_size):
zi = Z_raw[:, :, i:i + chunk_size, :]
zi = self.layernorm1(zi)
gi = torch.sigmoid(self.left_right_gate(zi))
i_left_right_proj_act = self.left_right_projection(zi)
i_left_right_proj_act = Z_mask_col[:, :, i:i + chunk_size].unsqueeze(-1) * i_left_right_proj_act
i_left_right_proj_act *= gi
_, right_proj_act = i_left_right_proj_act.chunk(2, dim=-1)
right_proj_act = permute_final_dims(right_proj_act, (2, 0, 1))
for j in range(0, para_dim, chunk_size):
zj = Z_raw[:, :, j:j + chunk_size, :]
zj = self.layernorm1(zj)
gj = torch.sigmoid(self.left_right_gate(zj))
j_left_right_proj_act = self.left_right_projection(zj)
j_left_right_proj_act = Z_mask_col[:, :, j:j + chunk_size].unsqueeze(-1) * j_left_right_proj_act
j_left_right_proj_act *= gj
left_proj_act, _ = j_left_right_proj_act.chunk(2, dim=-1)
left_proj_act = left_proj_act.contiguous()
work = None
left_proj_act_tmp = torch.empty_like(left_proj_act)
for k in range(0, world_size):
if world_size > 1:
if work:
broadcast_async_opp(work) # collect last broadcast
if k != rank:
left_proj_act_rec = left_proj_act_tmp.clone()
else: # init first broadcast
if k == rank:
broadcast_sync(k, left_proj_act, host=True)
else:
left_proj_act_tmp = broadcast_sync(k, left_proj_act, host=False)
left_proj_act_rec = left_proj_act_tmp.clone()
if k + 1 != world_size: # launch next broadcast
if k + 1 == rank:
work = broadcast_async(k + 1, left_proj_act, host=True)
else:
work = broadcast_async(k + 1, left_proj_act_tmp, host=False)
if k == rank: # broadcast self proj_act
# left: [seq,chunkj,dim] => [dim,chunkj,seq]
# right: [seq,chunki,dim] => [dim,seq,chunki]
# p: [dim,chunkj,chunki] => [chunkj,chunki,dim]
p = torch.matmul(
permute_final_dims(left_proj_act, (2, 1, 0)),
right_proj_act
)
p = permute_final_dims(p, (1, 2, 0))
j_global = para_dim * k + j
output[:, j_global:min(j_global + chunk_size, para_dim * (k + 1)), i:i + chunk_size, :] = p
else: # receive others broadcast
p = torch.matmul(
permute_final_dims(left_proj_act_rec, (2, 1, 0)),
right_proj_act
)
p = permute_final_dims(p, (1, 2, 0))
j_global = para_dim * k + j
output[:, j_global:min(j_global + chunk_size, para_dim * (k + 1)), i:i + chunk_size, :] = p
dropout_mask = torch.ones_like(Z_raw[:, 0:1, :, :], device=Z_raw.device, dtype=Z_raw.dtype)
for i in range(0, Z_raw.shape[1], chunk_size):
z_raw = Z_raw[:, i:i + chunk_size, :, :]
g = torch.sigmoid(self.output_gate(self.layernorm1(z_raw)))
z = output[:, i:i + chunk_size, :, :]
z = self.output_projection(self.layernorm2(z))
z = bias_ele_dropout_residual(z,
self.output_bias,
g,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
output[:, i:i + chunk_size, :, :] = z
return output
class ChunkTriangleAttentionStartingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(ChunkTriangleAttentionStartingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
# _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
# std=1.0 / math.sqrt(d_pair))
# self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.linear_b = Linear(d_pair, n_head, initializer='linear', use_bias=False)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
if CHUNK_SIZE == None:
Z = self.layernorm1(Z_raw)
b = self.linear_b(Z)
b, work = gather_async(b, dim=1)
Z = self.attention(Z, Z_mask, (b, work))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
chunk_size = CHUNK_SIZE
para_dim = Z_raw.shape[1]
        # z is big but b is small, so we compute z in chunks to get b, then recompute z chunk by chunk later instead of storing it
b = torch.empty((Z_raw.shape[0], Z_raw.shape[1], Z_raw.shape[2], self.n_head), device=Z_raw.device, dtype=Z_raw.dtype)
for i in range(0, para_dim, chunk_size):
z = self.layernorm1(Z_raw[:, i:i + chunk_size, :, :])
b[:, i:i + chunk_size, :, :] = self.linear_b(z)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
output = torch.empty_like(Z_raw)
dropout_mask = torch.ones_like(z[:, 0:1, :, :], device=z.device, dtype=z.dtype)
for i in range(0, para_dim, chunk_size):
z_raw = Z_raw[:, i:i + chunk_size, :, :]
z = self.layernorm1(z_raw)
z_mask = Z_mask[:, i:i + chunk_size, :]
z = self.attention(z, z_mask, (b, -1))
z = bias_dropout_add(z,
self.out_bias,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
output[:, i:i + chunk_size, :, :] = z
return output
class ChunkMSARowAttentionWithPairBias(nn.Module):
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
super(ChunkMSARowAttentionWithPairBias, self).__init__()
self.d_node = d_node
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernormM = LayerNorm(d_node)
self.layernormZ = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
def forward(self, M_raw, Z, M_mask):
if CHUNK_SIZE == None:
## Input projections
M = self.layernormM(M_raw)
Z = self.layernormZ(Z)
b = F.linear(Z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
# b = rearrange(b, 'b q k h -> b h q k')
# padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
M = self.attention(M, M_mask, (b, work))
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop, training=self.training)
chunk_size = CHUNK_SIZE
para_dim_z = Z.shape[1]
para_dim_m = M_raw.shape[1]
        # z is big but b is small, so we compute z in chunks to get b, then recompute z chunk by chunk later instead of storing it
b = torch.empty((Z.shape[0], Z.shape[1], Z.shape[2], self.n_head), device=Z.device, dtype=Z.dtype)
for i in range(0, para_dim_z, chunk_size):
z = self.layernormZ(Z[:, i:i + chunk_size, :, :])
b[:, i:i + chunk_size, :, :] = F.linear(z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
output = torch.empty_like(M_raw)
dropout_mask = torch.ones_like(M_raw[:, 0:1, :, :], device=M_raw.device, dtype=M_raw.dtype)
for i in range(0, para_dim_m, chunk_size):
m_raw = M_raw[:, i:i + chunk_size, :, :]
m = self.layernormM(m_raw)
m_mask = M_mask[:, i:i + chunk_size, :]
m = self.attention(m, m_mask, (b, -1))
m = bias_dropout_add(m,
self.out_bias,
dropout_mask,
m_raw,
prob=self.p_drop,
training=self.training)
output[:, i:i + chunk_size, :, :] = m
return output
class ChunkTriangleAttentionEndingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(ChunkTriangleAttentionEndingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
self.linear_b = Linear(d_pair, n_head, initializer='linear', use_bias=False)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
if CHUNK_SIZE == None:
Z = Z_raw.transpose(-2, -3)
Z_mask = Z_mask.transpose(-1, -2)
Z = self.layernorm1(Z)
b = self.linear_b(Z)
b, work = gather_async(b, dim=1)
Z = self.attention(Z, Z_mask, (b, work))
Z = Z.transpose(-2, -3)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
para_dim = Z_raw.shape[2]
chunk_size = CHUNK_SIZE
        # z is big but b is small, so we compute z in chunks to get b, then recompute z chunk by chunk later instead of storing it
b = torch.empty((Z_raw.shape[0], Z_raw.shape[2], Z_raw.shape[1], self.n_head), device=Z_raw.device, dtype=Z_raw.dtype)
for i in range(0, para_dim, chunk_size):
z = Z_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
z = self.layernorm1(z)
b[:, i:i + chunk_size, :, :] = self.linear_b(z)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
output = torch.empty_like(Z_raw)
dropout_mask = torch.ones_like(Z_raw[:, :, 0:1, :], device=z.device, dtype=z.dtype)
for i in range(0, para_dim, chunk_size):
z_raw = Z_raw[:, :, i:i + chunk_size, :]
z = self.layernorm1(z_raw.transpose(-2, -3))
z_mask = Z_mask[:, :, i:i + chunk_size].transpose(-1, -2)
z = self.attention(z, z_mask, (b, -1)).transpose(-2, -3)
z = bias_dropout_add(z,
self.out_bias,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
output[:, :, i:i + chunk_size, :] = z
return output
class ChunkMSAColumnGlobalAttention(nn.Module):
def __init__(self, d_node, c=8, n_head=8):
super(ChunkMSAColumnGlobalAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.global_attention = GlobalAttention(
qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node
)
def forward(self, M_raw, M_mask):
para_dim = M_raw.shape[2]
if CHUNK_SIZE is None:
chunk_size = para_dim
else:
chunk_size = CHUNK_SIZE
for i in range(0, para_dim, chunk_size):
m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
m = self.layernormM(m)
m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
m = self.global_attention(m, m_mask)
m = m.transpose(-2, -3)
M_raw[:, :, i:i + chunk_size, :] += m
return M_raw
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
Number of distogram bins
"""
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
First row of the MSA embedding. [*, N_res, C_m]
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
# [*, N, N, C_z]
para_dim = d.shape[1]
if CHUNK_SIZE == None:
chunk_size = para_dim
else:
chunk_size = CHUNK_SIZE * 48
for i in range(0, para_dim, chunk_size):
di = self.linear(d[i:i + chunk_size, :, :])
z[i:i + chunk_size, :, :] = di + self.layer_norm_z(z[i:i + chunk_size, :, :])
return m_update, z
class GlobalAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
...
@@ -5,7 +5,7 @@ import torch.nn as nn
from fastfold.model.fastnn.kernel import LayerNorm
from fastfold.distributed.comm import col_to_row, row_to_col, scatter
from fastfold.model.fastnn.kernel import bias_dropout_add, bias_ele_dropout_residual
-from fastfold.model.fastnn.ops import Linear, SelfAttention, Transition
+from fastfold.model.fastnn.ops import Linear, SelfAttention, ChunkTransition, ChunkTriangleAttentionEndingNode, AsyncChunkTriangleMultiplicationOutgoing, AsyncChunkTriangleMultiplicationIncoming, ChunkTriangleAttentionStartingNode
from fastfold.distributed.comm_async import gather_async_opp, gather_async
@@ -218,21 +218,21 @@ class PairStack(nn.Module):
        self.n_head = 4
        self.hidden_c = int(d_pair / self.n_head)
-       self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair,
+       self.TriangleMultiplicationOutgoing = AsyncChunkTriangleMultiplicationOutgoing(d_pair,
                                                                             p_drop=p_drop,
                                                                             c=d_pair)
-       self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair,
+       self.TriangleMultiplicationIncoming = AsyncChunkTriangleMultiplicationIncoming(d_pair,
                                                                             p_drop=p_drop,
                                                                             c=d_pair)
-       self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair,
+       self.TriangleAttentionStartingNode = ChunkTriangleAttentionStartingNode(d_pair,
                                                                           p_drop=p_drop,
                                                                           c=self.hidden_c,
                                                                           n_head=self.n_head)
-       self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair,
+       self.TriangleAttentionEndingNode = ChunkTriangleAttentionEndingNode(d_pair,
                                                                       p_drop=p_drop,
                                                                       c=self.hidden_c,
                                                                       n_head=self.n_head)
-       self.PairTransition = Transition(d=d_pair)
+       self.PairTransition = ChunkTransition(d=d_pair)
    def forward(self, pair, pair_mask):
        pair_mask_row = scatter(pair_mask, dim=1)
...
@@ -230,7 +230,7 @@ class AlphaFold(nn.Module):
        # m_1_prev_emb: [*, N, C_m]
        # z_prev_emb: [*, N, N, C_z]
-       m_1_prev_emb, z_prev_emb = self.recycling_embedder(
+       m_1_prev, z_prev = self.recycling_embedder(
            m_1_prev,
            z_prev,
            x_prev,
@@ -241,17 +241,17 @@ class AlphaFold(nn.Module):
        # conditionally to avoid leaving parameters unused, which has annoying
        # implications for DDP training.
        if(not _recycle):
-           m_1_prev_emb *= 0
-           z_prev_emb *= 0
+           m_1_prev *= 0
+           z_prev *= 0
        # [*, S_c, N, C_m]
-       m[..., 0, :, :] += m_1_prev_emb
+       m[..., 0, :, :] += m_1_prev
        # [*, N, N, C_z]
-       z += z_prev_emb
+       z += z_prev
        # Possibly prevents memory fragmentation
-       del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
+       del m_1_prev, z_prev, x_prev
        # Embed the templates + merge with MSA/pair embeddings
        if self.config.template.enabled:
@@ -273,8 +273,10 @@ class AlphaFold(nn.Module):
                feats["template_torsion_angles_mask"] = (
                    template_embeds["template_mask"]
                )
+               # [*, N, N, C_z]
+               z = z + template_embeds["template_pair_embedding"]
            else:
-               template_embeds = self.template_embedder(
+               template_embeds, z = self.template_embedder(
                    template_feats,
                    z,
                    pair_mask.to(dtype=z.dtype),
@@ -282,9 +284,6 @@ class AlphaFold(nn.Module):
                    self.globals.chunk_size
                )
-           # [*, N, N, C_z]
-           z = z + template_embeds["template_pair_embedding"]
            if(
                self.config.template.embed_angles or
                (self.globals.is_multimer and self.config.template.enabled)
@@ -307,6 +306,7 @@ class AlphaFold(nn.Module):
                    [feats["msa_mask"], template_embeds["template_mask"]],
                    dim=-2,
                )
+           del template_feats, template_embeds, torsion_angles_mask
        # Embed extra MSA features + merge with pairwise embeddings
        if self.config.extra_msa.enabled:
@@ -328,6 +328,7 @@ class AlphaFold(nn.Module):
                pair_mask=pair_mask.to(dtype=z.dtype),
                _mask_trans=self.config._mask_trans,
            )
+           del extra_msa_feat, extra_msa_fn
        # Run MSA + pair embeddings through the trunk of the network
        # m: [*, S, N, C_m]
...
@@ -25,6 +25,7 @@ from fastfold.utils.feats import (
)
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import one_hot
from fastfold.model.fastnn.ops import RecyclingEmbedder
from fastfold.model.nn.template import (
    TemplatePairStack,
    TemplatePointwiseAttention,
@@ -122,8 +123,8 @@ class InputEmbedder(nn.Module):
        tf_emb_j = self.linear_tf_z_j(tf)
        # [*, N_res, N_res, c_z]
-       pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
-       pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
+       pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
+       pair_emb += tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
        # [*, N_clust, N_res, c_m]
        n_clust = msa.shape[-3]
@@ -137,101 +138,6 @@ class InputEmbedder(nn.Module):
        return msa_emb, pair_emb
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
Number of distogram bins
"""
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
First row of the MSA embedding. [*, N_res, C_m]
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
# [*, N, N, C_z]
d = self.linear(d)
z_update = d + self.layer_norm_z(z)
return m_update, z_update
class TemplateEmbedder(nn.Module):
    def __init__(self, config):
        super(TemplateEmbedder, self).__init__()
@@ -261,6 +167,12 @@ class TemplateEmbedder(nn.Module):
        # Embed the templates one at a time (with a poor man's vmap)
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
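        # With small chunk sizes (1-4), the per-template pair embeddings are staged in a
        # CPU buffer to cap GPU memory; 64 is presumably the template pair channel (c_t).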
if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device='cpu')
else:
t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
@@ -280,48 +192,50 @@ class TemplateEmbedder(nn.Module):
                single_template_embeds["angle"] = a
            # [*, S_t, N, N, C_t]
-           t = build_template_pair_feat(
+           tt = build_template_pair_feat(
                single_template_feats,
                use_unit_vector=self.config.use_unit_vector,
                inf=self.config.inf,
+               chunk=chunk_size,
                eps=self.config.eps,
                **self.config.distogram,
-           ).to(z.dtype)
-           t = self.template_pair_embedder(t)
-           single_template_embeds.update({"pair": t})
+           ).to(z.dtype).to(z.device)
+           tt = self.template_pair_embedder(tt)
+           # single_template_embeds.update({"pair": t})
            template_embeds.append(single_template_embeds)
-       template_embeds = dict_multimap(
-           partial(torch.cat, dim=templ_dim),
-           template_embeds,
-       )
            # [*, S_t, N, N, C_z]
-           t = self.template_pair_stack(
-               template_embeds["pair"],
+           t[i] = self.template_pair_stack(
+               tt,
                pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                chunk_size=chunk_size,
                _mask_trans=_mask_trans,
-           )
+           ).to(t.device)
+           del tt, single_template_embeds, single_template_feats
+       template_embeds = dict_multimap(
+           partial(torch.cat, dim=templ_dim),
+           template_embeds,
+       )
        # [*, N, N, C_z]
        t = self.template_pointwise_att(
-           t,
+           t.to(z.device),
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
-           chunk_size=chunk_size,
+           chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
        )
        t = t * (torch.sum(batch["template_mask"]) > 0)
        ret = {}
        if self.config.embed_angles:
            ret["template_single_embedding"] = template_embeds["angle"]
-       ret.update({"template_pair_embedding": t})
-       return ret
+       z += t
+       return ret, z
class TemplateAngleEmbedder(nn.Module):
...
@@ -353,7 +353,8 @@ class TemplatePairStack(nn.Module):
            args=(t,),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )
-       t = self.layer_norm(t)
+       if chunk_size is None:
+           chunk_size = t.shape[0]
+       for i in range(0, t.shape[0], chunk_size):
+           t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])
        return t
@@ -112,7 +112,12 @@ def build_template_pair_feat(
    use_unit_vector: bool = False,
    eps: float = 1e-20,
    inf: float = 1e8,
    chunk=None
):
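    # With very small chunk sizes (1-4), template features are moved to CPU so that
    # building the pair features does not add to peak GPU memory; the caller moves the
    # result back to the GPU when it is consumed.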
if chunk and 1 <= chunk <= 4:
for k, v in batch.items():
batch[k] = v.cpu()
    template_mask = batch["template_pseudo_beta_mask"]
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
@@ -146,11 +151,13 @@ def build_template_pair_feat(
    )
    points = rigids.get_trans()[..., None, :, :]
    rigid_vec = rigids[..., None].invert_apply(points)
del rigids, points
    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
    t_aa_masks = batch["template_all_atom_mask"]
    template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
del t_aa_masks, n, ca, c
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
    inv_distance_scalar = inv_distance_scalar * template_mask_2d
@@ -161,6 +168,7 @@ def build_template_pair_feat(
    to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
    to_concat.append(template_mask_2d[..., None])
del unit_vector, rigid_vec, inv_distance_scalar
    act = torch.cat(to_concat, dim=-1)
    act = act * template_mask_2d[..., None]
...
@@ -99,6 +99,7 @@ def add_data_args(parser: argparse.ArgumentParser):
    )
    parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
    parser.add_argument('--release_dates_path', type=str, default=None)
parser.add_argument('--chunk_size', type=int, default=None)
    parser.add_argument('--enable_workflow', default=False, action='store_true', help='run inference with ray workflow or not')
@@ -110,6 +111,8 @@ def inference_model(rank, world_size, result_q, batch, args):
    fastfold.distributed.init_dap()
    torch.cuda.set_device(rank)
    config = model_config(args.model_name)
if args.chunk_size:
config.globals.chunk_size = args.chunk_size
    model = AlphaFold(config)
    import_jax_weights_(model, args.param_path, version=args.model_name)
...
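As a usage note, passing e.g. --chunk_size 4 to the inference script overrides config.globals.chunk_size as shown above; smaller values trade some speed for lower peak GPU memory, while omitting the flag keeps the model config's default.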