Unverified Commit 3429e62d authored by oahzxl, committed by GitHub

Optimize long sequence inference memory (#69)

* update ops

* finish on 1 gpu

* update bug in incom mul

* update async broadcast

* update income

* add attention chunk

* finish embed

* align evoformer

* update embed

* align evoformer

* update embed

* update template

* update template

* update extramsa

* fix a bug

* update outerproductmean

* remove useless class

* fix a bug when chunk is None

* fix bug when chunk is None
parent 6835c248
@@ -11,6 +11,58 @@ from colossalai.core import global_context as gpc
 from .comm import _split, divide
 
+
+def broadcast_sync(src: int, tensor: Tensor, host: bool = False) -> Tensor:
+    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
+        return 0
+    if host:
+        dist.broadcast(tensor,
+                       src=src,
+                       group=gpc.get_group(ParallelMode.TENSOR),
+                       async_op=False)
+        return 0
+    else:
+        output = torch.empty(list(tensor.shape), dtype=tensor.dtype, device=tensor.device)
+        dist.broadcast(output,
+                       src=src,
+                       group=gpc.get_group(ParallelMode.TENSOR),
+                       async_op=False)
+        return output
+
+
+def broadcast_async(src: int, tensor: Tensor, host: bool = False) -> Tensor:
+    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
+        return 0
+    if host:
+        work = dist.broadcast(tensor,
+                              src=src,
+                              group=gpc.get_group(ParallelMode.TENSOR),
+                              async_op=True)
+        return work
+    else:
+        work = dist.broadcast(tensor,
+                              src=src,
+                              group=gpc.get_group(ParallelMode.TENSOR),
+                              async_op=True)
+        return work
+
+
+def broadcast_async_opp(work) -> Tensor:
+    work.wait()
+    return 0
+
+
+def get_rank():
+    return gpc.get_global_rank()
+
+
+def get_world_size():
+    return gpc.get_world_size(ParallelMode.TENSOR)
+
+
 def _gather_async(tensor: Tensor, dim: int = -1) -> Tensor:
     if gpc.get_world_size(ParallelMode.TENSOR) == 1:
         return tensor, None
...
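The new broadcast helpers follow the same split-phase pattern as the existing gather_async/gather_async_opp pair: broadcast_async launches a non-blocking broadcast and returns the communication handle, and broadcast_async_opp waits on it later so independent computation can run in between. A minimal usage sketch, assuming the helpers live in fastfold.distributed.comm_async (as the imports elsewhere in this diff suggest) and that the tensor-parallel group is already initialized; the tensor shapes and the stand-in computation are illustrative:

import torch
from fastfold.distributed.comm_async import broadcast_async, broadcast_async_opp

def pair_update_with_overlap(pair: torch.Tensor, local_msa: torch.Tensor) -> torch.Tensor:
    # Rank 0 has written the shared pair tensor; start pushing it to the
    # other tensor-parallel ranks without blocking.
    work = broadcast_async(src=0, tensor=pair, host=True)

    # Overlap the broadcast with work that only touches the local shard.
    local_msa = torch.relu(local_msa)  # stand-in for per-rank compute

    # Block only when the broadcasted tensor is actually needed.
    if work:  # broadcast_async returns 0 (falsy) when there is a single rank
        broadcast_async_opp(work)
    return pair + local_msa.mean()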
@@ -21,9 +21,8 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from fastfold.model.fastnn import MSAStack, OutProductMean, PairStack, ExtraMSAStack
-from fastfold.model.fastnn.ops import Transition
-from fastfold.model.fastnn.triangle import TriangleAttentionEndingNode, TriangleAttentionStartingNode, \
-    TriangleMultiplicationIncoming, TriangleMultiplicationOutgoing
+from fastfold.model.fastnn.ops import ChunkTransition, ChunkTriangleAttentionStartingNode, ChunkTriangleAttentionEndingNode, \
+    AsyncChunkTriangleMultiplicationOutgoing, AsyncChunkTriangleMultiplicationIncoming
 from fastfold.distributed.comm import gather, scatter
 from fastfold.distributed.comm import col_to_row, row_to_col, scatter
 from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
@@ -76,12 +75,12 @@ class EvoformerBlock(nn.Module):
         if not self.is_multimer:
             m = self.msa_stack(m, z, msa_mask)
-            z = z + self.communication(m, msa_mask)
+            z = self.communication(m, msa_mask, z)
             m, work = All_to_All_Async.apply(m, 1, 2)
             z = self.pair_stack(z, pair_mask)
             m = All_to_All_Async_Opp.apply(m, work, 1, 2)
         else:
-            z = z + self.communication(m, msa_mask)
+            z = self.communication(m, msa_mask, z)
             z_ori = z
             m, work = All_to_All_Async.apply(m, 1, 2)
             z = self.pair_stack(z, pair_mask)
@@ -158,14 +157,13 @@ class ExtraMSABlock(nn.Module):
         if not self.is_multimer:
             m = self.msa_stack(m, z, msa_mask)
-            z = z + self.communication(m, msa_mask)
+            z = self.communication(m, msa_mask, z)
             m, work = All_to_All_Async.apply(m, 1, 2)
             z = self.pair_stack(z, pair_mask)
             m = All_to_All_Async_Opp.apply(m, work, 1, 2)
         else:
-            z = z + self.communication(m, msa_mask)
+            z = self.communication(m, msa_mask, z)
             z_ori = z
             m, work = All_to_All_Async.apply(m, 1, 2)
             z = self.pair_stack(z, pair_mask)
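Both EvoformerBlock and ExtraMSABlock now hand the residual pair tensor z to self.communication instead of adding its output afterwards, which lets the communication module accumulate into z and avoids materializing a second full [N_res, N_res, C_z] tensor for the sum. A sketch of the pattern; the module internals below are illustrative, not the actual FastFold OutProductMean:

import torch
import torch.nn as nn

class OutProductMeanLike(nn.Module):
    """Illustrative module: accumulates its output into the residual in place."""

    def __init__(self, d_msa: int, d_pair: int):
        super().__init__()
        self.proj = nn.Linear(d_msa, d_pair)

    def forward(self, m: torch.Tensor, msa_mask: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        m = m * msa_mask.unsqueeze(-1)              # [S, N, C_m]
        update = self.proj(m.mean(dim=-3))          # [N, C_z] stand-in for the real outer product
        z.add_(update.unsqueeze(-3))                # in-place: no extra [N, N, C_z] temporary
        return z

# old pattern: z = z + communication(m, msa_mask)   -> allocates a full-size sum tensor
# new pattern: z = communication(m, msa_mask, z)    -> writes into the existing buffer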
@@ -212,19 +210,19 @@ class TemplatePairStackBlock(nn.Module):
         self.p_drop = dropout_rate
         self.hidden_c = int(c_t / self.n_head)
-        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(
+        self.TriangleMultiplicationOutgoing = AsyncChunkTriangleMultiplicationOutgoing(
             self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul
         )
-        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(
+        self.TriangleMultiplicationIncoming = AsyncChunkTriangleMultiplicationIncoming(
             self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul
         )
-        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(
+        self.TriangleAttentionStartingNode = ChunkTriangleAttentionStartingNode(
             self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head
         )
-        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(
+        self.TriangleAttentionEndingNode = ChunkTriangleAttentionEndingNode(
             self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head
         )
-        self.PairTransition = Transition(d=self.c_t, n=pair_transition_n)
+        self.PairTransition = ChunkTransition(d=self.c_t, n=pair_transition_n)
 
     def forward(
         self,
@@ -245,12 +243,11 @@ class TemplatePairStackBlock(nn.Module):
         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
 
-        single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
-        single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
-
-        for i in range(len(single_templates)):
-            single = single_templates[i]
-            single_mask = single_templates_masks[i]
+        # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
+        # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
+        for i in range(z.shape[0]):
+            single = z[i].unsqueeze(-4)
+            single_mask = mask[i].unsqueeze(-3)
 
             single_mask_row = scatter(single_mask, dim=1)
             single_mask_col = scatter(single_mask, dim=2)
@@ -264,10 +261,9 @@ class TemplatePairStackBlock(nn.Module):
             single = self.TriangleAttentionEndingNode(single, single_mask_col)
             single = self.PairTransition(single)
             single = col_to_row(single)
+            z[i] = single
 
-            single_templates[i] = single
-
-        z = torch.cat(single_templates, dim=-4)
+        # z = torch.cat(single_templates, dim=-4)
 
         if self.last_block:
             z = gather(z, dim=1)
             z = z[:, :-padding_size, :-padding_size, :]
...
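The template loop now writes each processed slice straight back into z[i] instead of collecting unbound copies and re-concatenating them, so only one extra single-template tensor is alive at a time. A toy version of the two patterns; the shapes and the process_one function are invented for illustration:

import torch

def process_one(t: torch.Tensor) -> torch.Tensor:
    return t * 2.0  # stand-in for the triangle/attention/transition stack

z = torch.randn(4, 64, 64, 8)  # [n_templ, N, N, C_t], sizes are illustrative

# before: keeps n_templ per-template copies plus the concatenated result alive
singles = [process_one(t.unsqueeze(0)) for t in torch.unbind(z, dim=0)]
z_cat = torch.cat(singles, dim=0)

# after: one temporary at a time, result written back into the existing buffer
for i in range(z.shape[0]):
    z[i] = process_one(z[i].unsqueeze(0)).squeeze(0)

assert torch.allclose(z, z_cat)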
@@ -7,7 +7,7 @@ def bias_sigmod_ele(y, bias, z):
     return torch.sigmoid(y + bias) * z
 
-@torch.jit.script
+# @torch.jit.script
 def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                      residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
     out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
...
@@ -17,9 +17,10 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fastfold.model.fastnn.kernel import LayerNorm
+# from fastfold.model.fastnn.kernel import LayerNorm
+from torch.nn import LayerNorm
-from fastfold.model.fastnn.ops import Transition, SelfAttention, GlobalAttention
+from fastfold.model.fastnn.ops import ChunkMSARowAttentionWithPairBias, ChunkTransition, SelfAttention, GlobalAttention, Transition, ChunkMSAColumnGlobalAttention
 from fastfold.model.fastnn.kernel import bias_dropout_add
 from fastfold.distributed import scatter, row_to_col
 from fastfold.distributed.comm_async import gather_async
@@ -150,12 +151,12 @@ class ExtraMSAStack(nn.Module):
     def __init__(self, d_node, d_pair, p_drop=0.15):
         super(ExtraMSAStack, self).__init__()
-        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(
+        self.MSARowAttentionWithPairBias = ChunkMSARowAttentionWithPairBias(
             d_node=d_node, d_pair=d_pair, p_drop=p_drop, c=8
         )
-        self.MSAColumnAttention = MSAColumnGlobalAttention(d_node=d_node, c=8)
+        self.MSAColumnAttention = ChunkMSAColumnGlobalAttention(d_node=d_node, c=8)
-        self.MSATransition = Transition(d=d_node)
+        self.MSATransition = ChunkTransition(d=d_node)
 
     def forward(self, node, pair, node_mask):
         node_mask_row = scatter(node_mask, dim=1)
...
This diff is collapsed.
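The collapsed file contains the new Chunk*/AsyncChunk* module implementations, so they are not visible here. The general idea behind such chunked modules is to run the expensive pairwise operations over slices of one dimension so that only a slice-sized activation is alive at once. A generic sketch of that pattern, not the FastFold code:

import torch
import torch.nn as nn

class ChunkedTransitionSketch(nn.Module):
    """Generic chunked feed-forward: processes rows in slices to cap peak memory."""

    def __init__(self, d: int, n: int = 4, chunk_size: int = 64):
        super().__init__()
        self.chunk_size = chunk_size
        self.net = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, n * d), nn.ReLU(), nn.Linear(n * d, d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        # Only one chunk's worth of hidden activations (n * d wide) exists at a time.
        for start in range(0, x.shape[-2], self.chunk_size):
            sl = slice(start, start + self.chunk_size)
            out[..., sl, :] = x[..., sl, :] + self.net(x[..., sl, :])
        return out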
@@ -5,7 +5,7 @@ import torch.nn as nn
 from fastfold.model.fastnn.kernel import LayerNorm
 from fastfold.distributed.comm import col_to_row, row_to_col, scatter
 from fastfold.model.fastnn.kernel import bias_dropout_add, bias_ele_dropout_residual
-from fastfold.model.fastnn.ops import Linear, SelfAttention, Transition
+from fastfold.model.fastnn.ops import Linear, SelfAttention, ChunkTransition, ChunkTriangleAttentionEndingNode, AsyncChunkTriangleMultiplicationOutgoing, AsyncChunkTriangleMultiplicationIncoming, ChunkTriangleAttentionStartingNode
 from fastfold.distributed.comm_async import gather_async_opp, gather_async
@@ -218,21 +218,21 @@ class PairStack(nn.Module):
         self.n_head = 4
         self.hidden_c = int(d_pair / self.n_head)
-        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair,
+        self.TriangleMultiplicationOutgoing = AsyncChunkTriangleMultiplicationOutgoing(d_pair,
                                                                              p_drop=p_drop,
                                                                              c=d_pair)
-        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair,
+        self.TriangleMultiplicationIncoming = AsyncChunkTriangleMultiplicationIncoming(d_pair,
                                                                              p_drop=p_drop,
                                                                              c=d_pair)
-        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair,
+        self.TriangleAttentionStartingNode = ChunkTriangleAttentionStartingNode(d_pair,
                                                                            p_drop=p_drop,
                                                                            c=self.hidden_c,
                                                                            n_head=self.n_head)
-        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair,
+        self.TriangleAttentionEndingNode = ChunkTriangleAttentionEndingNode(d_pair,
                                                                        p_drop=p_drop,
                                                                        c=self.hidden_c,
                                                                        n_head=self.n_head)
-        self.PairTransition = Transition(d=d_pair)
+        self.PairTransition = ChunkTransition(d=d_pair)
 
     def forward(self, pair, pair_mask):
         pair_mask_row = scatter(pair_mask, dim=1)
...
@@ -230,7 +230,7 @@ class AlphaFold(nn.Module):
         # m_1_prev_emb: [*, N, C_m]
         # z_prev_emb: [*, N, N, C_z]
-        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
+        m_1_prev, z_prev = self.recycling_embedder(
             m_1_prev,
             z_prev,
             x_prev,
@@ -241,17 +241,17 @@ class AlphaFold(nn.Module):
         # conditionally to avoid leaving parameters unused, which has annoying
         # implications for DDP training.
         if(not _recycle):
-            m_1_prev_emb *= 0
-            z_prev_emb *= 0
+            m_1_prev *= 0
+            z_prev *= 0
 
         # [*, S_c, N, C_m]
-        m[..., 0, :, :] += m_1_prev_emb
+        m[..., 0, :, :] += m_1_prev
 
         # [*, N, N, C_z]
-        z += z_prev_emb
+        z += z_prev
 
         # Possibly prevents memory fragmentation
-        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
+        del m_1_prev, z_prev, x_prev
 
         # Embed the templates + merge with MSA/pair embeddings
         if self.config.template.enabled:
@@ -273,8 +273,10 @@ class AlphaFold(nn.Module):
                 feats["template_torsion_angles_mask"] = (
                     template_embeds["template_mask"]
                 )
+                # [*, N, N, C_z]
+                z = z + template_embeds["template_pair_embedding"]
             else:
-                template_embeds = self.template_embedder(
+                template_embeds, z = self.template_embedder(
                     template_feats,
                     z,
                     pair_mask.to(dtype=z.dtype),
@@ -282,9 +284,6 @@ class AlphaFold(nn.Module):
                     self.globals.chunk_size
                 )
 
-            # [*, N, N, C_z]
-            z = z + template_embeds["template_pair_embedding"]
-
             if(
                 self.config.template.embed_angles or
                 (self.globals.is_multimer and self.config.template.enabled)
@@ -307,6 +306,7 @@ class AlphaFold(nn.Module):
                     [feats["msa_mask"], template_embeds["template_mask"]],
                     dim=-2,
                 )
+            del template_feats, template_embeds, torsion_angles_mask
 
         # Embed extra MSA features + merge with pairwise embeddings
         if self.config.extra_msa.enabled:
@@ -314,7 +314,7 @@ class AlphaFold(nn.Module):
                 extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
             else:
                 extra_msa_fn = build_extra_msa_feat
 
             # [*, S_e, N, C_e]
             extra_msa_feat = extra_msa_fn(feats)
             extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
@@ -328,6 +328,7 @@ class AlphaFold(nn.Module):
                 pair_mask=pair_mask.to(dtype=z.dtype),
                 _mask_trans=self.config._mask_trans,
             )
+            del extra_msa_feat, extra_msa_fn
 
         # Run MSA + pair embeddings through the trunk of the network
         # m: [*, S, N, C_m]
...
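Several of these hunks free large intermediates (template features, extra-MSA activations, the recycling inputs) with del as soon as they are consumed, so the allocator can reuse that memory for the next stage instead of holding everything until the function returns. A hypothetical illustration of the pattern; the names below are placeholders, not the AlphaFold module:

import torch

def build_pair_repr(feats: torch.Tensor, embedder, trunk) -> torch.Tensor:
    # Drop each big intermediate as soon as the next stage has consumed it,
    # so its memory block can be reused for later allocations.
    emb = embedder(feats)
    del feats
    z = trunk(emb)
    del emb
    return z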
@@ -25,6 +25,7 @@ from fastfold.utils.feats import (
 )
 from fastfold.model.nn.primitives import Linear, LayerNorm
 from fastfold.utils.tensor_utils import one_hot
+from fastfold.model.fastnn.ops import RecyclingEmbedder
 from fastfold.model.nn.template import (
     TemplatePairStack,
     TemplatePointwiseAttention,
@@ -122,8 +123,8 @@ class InputEmbedder(nn.Module):
         tf_emb_j = self.linear_tf_z_j(tf)
 
         # [*, N_res, N_res, c_z]
-        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
-        pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
+        pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
+        pair_emb += tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
 
         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
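Reordering the InputEmbedder so that pair_emb starts from the relative-position encoding and the outer sum of the tf embeddings is added with += removes one full [N_res, N_res, c_z] temporary (the out-of-place pair_emb + relpos(...) result). The same trick in isolation, with invented dimensions:

import torch

n_res, c_z = 512, 128
relpos = torch.randn(n_res, n_res, c_z)
tf_i = torch.randn(n_res, c_z)
tf_j = torch.randn(n_res, c_z)

# before: outer sum, then an out-of-place add -> two extra [N, N, c_z] temporaries at peak
pair_old = tf_i[..., None, :] + tf_j[..., None, :, :]
pair_old = pair_old + relpos

# after: start from relpos and accumulate in place -> one temporary fewer
pair_new = relpos.clone()        # clone only so both versions can be compared here
pair_new += tf_i[..., None, :] + tf_j[..., None, :, :]

assert torch.allclose(pair_old, pair_new)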
@@ -137,101 +138,6 @@ class InputEmbedder(nn.Module):
         return msa_emb, pair_emb
 
-
-class RecyclingEmbedder(nn.Module):
-    """
-    Embeds the output of an iteration of the model for recycling.
-
-    Implements Algorithm 32.
-    """
-    def __init__(
-        self,
-        c_m: int,
-        c_z: int,
-        min_bin: float,
-        max_bin: float,
-        no_bins: int,
-        inf: float = 1e8,
-        **kwargs,
-    ):
-        """
-        Args:
-            c_m:
-                MSA channel dimension
-            c_z:
-                Pair embedding channel dimension
-            min_bin:
-                Smallest distogram bin (Angstroms)
-            max_bin:
-                Largest distogram bin (Angstroms)
-            no_bins:
-                Number of distogram bins
-        """
-        super(RecyclingEmbedder, self).__init__()
-
-        self.c_m = c_m
-        self.c_z = c_z
-        self.min_bin = min_bin
-        self.max_bin = max_bin
-        self.no_bins = no_bins
-        self.inf = inf
-
-        self.linear = Linear(self.no_bins, self.c_z)
-        self.layer_norm_m = LayerNorm(self.c_m)
-        self.layer_norm_z = LayerNorm(self.c_z)
-
-    def forward(
-        self,
-        m: torch.Tensor,
-        z: torch.Tensor,
-        x: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Args:
-            m:
-                First row of the MSA embedding. [*, N_res, C_m]
-            z:
-                [*, N_res, N_res, C_z] pair embedding
-            x:
-                [*, N_res, 3] predicted C_beta coordinates
-        Returns:
-            m:
-                [*, N_res, C_m] MSA embedding update
-            z:
-                [*, N_res, N_res, C_z] pair embedding update
-        """
-        bins = torch.linspace(
-            self.min_bin,
-            self.max_bin,
-            self.no_bins,
-            dtype=x.dtype,
-            device=x.device,
-            requires_grad=False,
-        )
-
-        # [*, N, C_m]
-        m_update = self.layer_norm_m(m)
-
-        # This squared method might become problematic in FP16 mode.
-        # I'm using it because my homegrown method had a stubborn discrepancy I
-        # couldn't find in time.
-        squared_bins = bins ** 2
-        upper = torch.cat(
-            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
-        )
-        d = torch.sum(
-            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
-        )
-
-        # [*, N, N, no_bins]
-        d = ((d > squared_bins) * (d < upper)).type(x.dtype)
-
-        # [*, N, N, C_z]
-        d = self.linear(d)
-        z_update = d + self.layer_norm_z(z)
-
-        return m_update, z_update
-
-
 class TemplateEmbedder(nn.Module):
     def __init__(self, config):
         super(TemplateEmbedder, self).__init__()
@@ -261,6 +167,12 @@ class TemplateEmbedder(nn.Module):
         # Embed the templates one at a time (with a poor man's vmap)
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]
+
+        if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
+            t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device='cpu')
+        else:
+            t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
+
         for i in range(n_templ):
             idx = batch["template_aatype"].new_tensor(i)
             single_template_feats = tensor_tree_map(
@@ -280,48 +192,50 @@ class TemplateEmbedder(nn.Module):
                 single_template_embeds["angle"] = a
 
             # [*, S_t, N, N, C_t]
-            t = build_template_pair_feat(
+            tt = build_template_pair_feat(
                 single_template_feats,
                 use_unit_vector=self.config.use_unit_vector,
                 inf=self.config.inf,
+                chunk=chunk_size,
                 eps=self.config.eps,
                 **self.config.distogram,
-            ).to(z.dtype)
-            t = self.template_pair_embedder(t)
-            single_template_embeds.update({"pair": t})
+            ).to(z.dtype).to(z.device)
+
+            tt = self.template_pair_embedder(tt)
+            # single_template_embeds.update({"pair": t})
 
             template_embeds.append(single_template_embeds)
 
+            # [*, S_t, N, N, C_z]
+            t[i] = self.template_pair_stack(
+                tt,
+                pair_mask.unsqueeze(-3).to(dtype=z.dtype),
+                chunk_size=chunk_size,
+                _mask_trans=_mask_trans,
+            ).to(t.device)
+            del tt, single_template_embeds, single_template_feats
+
         template_embeds = dict_multimap(
             partial(torch.cat, dim=templ_dim),
             template_embeds,
         )
 
-        # [*, S_t, N, N, C_z]
-        t = self.template_pair_stack(
-            template_embeds["pair"],
-            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
-            chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
-        )
-
         # [*, N, N, C_z]
         t = self.template_pointwise_att(
-            t,
+            t.to(z.device),
             z,
             template_mask=batch["template_mask"].to(dtype=z.dtype),
-            chunk_size=chunk_size,
+            chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
         )
         t = t * (torch.sum(batch["template_mask"]) > 0)
 
         ret = {}
         if self.config.embed_angles:
             ret["template_single_embedding"] = template_embeds["angle"]
-        ret.update({"template_pair_embedding": t})
+        z += t
 
-        return ret
+        return ret, z
 
 class TemplateAngleEmbedder(nn.Module):
...
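Instead of stacking every embedded template and running the pair stack once over the whole [n_templ, N, N, C_t] tensor, the embedder now runs the pair stack template by template and writes each result into a preallocated buffer, which lives on the CPU when a small chunk_size signals a long-sequence, memory-tight run. A simplified sketch of that staging pattern; the buffer width 64 and the 1-to-4 threshold follow the diff, everything else is illustrative:

import torch

def embed_templates_one_by_one(pair_feats, pair_stack, n_templ: int, n_res: int,
                               chunk_size=None, channels: int = 64):
    # Keep the accumulation buffer on the CPU for very small chunk sizes
    # (the diff uses 1 <= chunk_size <= 4 as the "low-memory" signal).
    low_mem = isinstance(chunk_size, int) and 1 <= chunk_size <= 4
    device = "cpu" if low_mem else "cuda"
    t = torch.empty((n_templ, n_res, n_res, channels), device=device)

    for i in range(n_templ):
        tt = pair_feats(i).cuda()           # only one template's features on the GPU
        t[i] = pair_stack(tt).to(t.device)  # result moved back to the staging buffer
        del tt                              # free the per-template activations
    return t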
@@ -353,7 +353,8 @@ class TemplatePairStack(nn.Module):
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
-        t = self.layer_norm(t)
+        if chunk_size is None:
+            chunk_size = t.shape[0]
+        for i in range(0, t.shape[0], chunk_size):
+            t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])
         return t
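Applying the final layer norm slice by slice and writing the result back into t avoids allocating a second full-size normalized tensor; with chunk_size unset, the loop degenerates to a single full-width pass. A standalone illustration of the same trick, with arbitrary shapes:

import torch

def chunked_layer_norm_inplace(t: torch.Tensor, norm: torch.nn.LayerNorm, chunk_size=None) -> torch.Tensor:
    if chunk_size is None:
        chunk_size = t.shape[0]              # one pass over everything
    for i in range(0, t.shape[0], chunk_size):
        # each iteration only materializes a [chunk_size, ...] normalized slice
        t[i:i + chunk_size] = norm(t[i:i + chunk_size])
    return t

t = torch.randn(8, 128, 128, 64)
norm = torch.nn.LayerNorm(64)
with torch.no_grad():                        # inference-time pattern; no autograd history needed
    chunked_layer_norm_inplace(t, norm, chunk_size=2)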
@@ -112,7 +112,12 @@ def build_template_pair_feat(
     use_unit_vector: bool = False,
     eps: float = 1e-20,
     inf: float = 1e8,
+    chunk=None
 ):
+    if chunk and 1 <= chunk <= 4:
+        for k, v in batch.items():
+            batch[k] = v.cpu()
+
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
@@ -146,11 +151,13 @@ def build_template_pair_feat(
     )
     points = rigids.get_trans()[..., None, :, :]
     rigid_vec = rigids[..., None].invert_apply(points)
+    del rigids, points
 
     inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
 
     t_aa_masks = batch["template_all_atom_mask"]
     template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
+    del t_aa_masks, n, ca, c
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
 
     inv_distance_scalar = inv_distance_scalar * template_mask_2d
@@ -161,6 +168,7 @@ def build_template_pair_feat(
     to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
     to_concat.append(template_mask_2d[..., None])
+    del unit_vector, rigid_vec, inv_distance_scalar
 
     act = torch.cat(to_concat, dim=-1)
     act = act * template_mask_2d[..., None]
...
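Passing a small chunk value now routes template pair-feature construction through the CPU: the whole feature dict is moved off the GPU, the distogram and unit-vector features are assembled there, and only the finished embedding is copied back to the pair tensor's device (the .to(z.dtype).to(z.device) in the embedder above). The generic offload pattern looks like this; the feature names and sizes are placeholders:

import torch

def build_features_off_gpu(batch: dict, low_mem: bool) -> torch.Tensor:
    if low_mem:
        # Assemble the large intermediate on the CPU to spare GPU memory ...
        batch = {k: v.cpu() for k, v in batch.items()}
    mask = batch["mask"]                              # placeholder key
    feat = mask[..., None] * mask[..., None, :]       # stand-in for the real pairwise features
    # ... and ship only the finished tensor back to the GPU when needed.
    return feat.to("cuda") if low_mem else feat

batch = {"mask": torch.ones(4, 256, device="cuda")}
act = build_features_off_gpu(batch, low_mem=True)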
@@ -99,6 +99,7 @@ def add_data_args(parser: argparse.ArgumentParser):
     )
     parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
    parser.add_argument('--release_dates_path', type=str, default=None)
+    parser.add_argument('--chunk_size', type=int, default=None)
     parser.add_argument('--enable_workflow', default=False, action='store_true', help='run inference with ray workflow or not')
@@ -110,6 +111,8 @@ def inference_model(rank, world_size, result_q, batch, args):
     fastfold.distributed.init_dap()
     torch.cuda.set_device(rank)
     config = model_config(args.model_name)
+    if args.chunk_size:
+        config.globals.chunk_size = args.chunk_size
     model = AlphaFold(config)
     import_jax_weights_(model, args.param_path, version=args.model_name)
...
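With the new flag, chunked inference is opted into from the command line by appending `--chunk_size N` to the usual inference invocation: smaller values trade speed for lower peak memory, and values of 4 or less additionally push template feature construction to the CPU, as shown above. Programmatically the override is a one-line config change; the import path for model_config below is an assumption, not taken from this diff:

from fastfold.config import model_config  # assumed import path

config = model_config("model_1")
config.globals.chunk_size = 4   # None disables chunking; 1-4 also offloads template feats to CPU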