Unverified commit 8a599895 authored by oahzxl, committed by GitHub

Re-organize file and align results (#99)

* Move template modification from nn to fastnn (#97)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code

* move template modification to fastnn, restore template in nn

* re-organize fastnn

* update evoformer unit test

* Align results and update unit tests (#98)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code

* move template modification to fastnn, restore template in nn

* re-organize fastnn

* update evoformer unit test

* update evoformer stack test

* update test

* update msa_att_row

* update msa_att_col

* update evoformer and evo-stack

* update evoformer

* update extramsa

* move model loading out of the loop

* finish template test

* update test

* Move template modification from nn to fastnn (#84)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code

* move template modification to fastnn, restore template in nn

* re-organize fastnn

* update evoformer unit test

* move model out of function

* only test inference

* remove cache in build

* update test inference

* restore changes

* restore build changes

* update inference and evoformer stack

* fix some bug

* update test

* update evoformer stack test

* update test

* update test

* fix test

* update test

* update test

* update input embedder

* update embedder

* reset core

* update test

* support template multimer in inject_nn
parent 7a69a181
from .msa import MSACore, ExtraMSACore, ExtraMSABlock, ExtraMSAStack
from .ops import OutProductMean, set_chunk_size
from .triangle import PairCore
from .evoformer import Evoformer, EvoformerStack
from .template import TemplatePairBlock, TemplatePairStack
__all__ = [
'MSACore', 'OutProductMean', 'PairCore', 'set_chunk_size',
'TemplatePairBlock', 'TemplatePairStack',
'ExtraMSACore', 'ExtraMSABlock', 'ExtraMSAStack',
'Evoformer', 'EvoformerStack',
]
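For orientation, a minimal usage sketch of the re-exported fastnn surface listed above; the constructor arguments and the chunk size are illustrative assumptions, not values taken from this PR.

# Hypothetical usage of the reorganized fastnn exports.
from fastfold.model.fastnn import Evoformer, EvoformerStack, set_chunk_size

set_chunk_size(4)  # enable chunked kernels; the value 4 is an arbitrary example
block = Evoformer(c_m=256, c_z=128, first_block=True, last_block=False)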
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple
from functools import partial
from fastfold.utils.feats import (
build_template_angle_feat,
build_template_pair_feat,
)
from fastfold.model.fastnn.ops import Linear
from fastfold.utils.tensor_utils import one_hot
from fastfold.model.fastnn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
Implements Algorithm 4.
Args:
ri:
"residue_index" features of shape [*, N]
"""
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
pair_emb += tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
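As an aside, a small self-contained sketch of the relative-position binning that relpos implements (Algorithm 4). The window size and residue indices are arbitrary example values, and the nearest-bin lookup stands in for fastfold.utils.tensor_utils.one_hot.

import torch

# Bin signed residue-index offsets into 2*relpos_k + 1 one-hot classes.
relpos_k = 2
ri = torch.tensor([0., 1., 2., 5.])                         # "residue_index"
d = ri[..., None] - ri[..., None, :]                        # [N, N] signed offsets
boundaries = torch.arange(-relpos_k, relpos_k + 1).float()  # 5 bins: -2..2
bins = torch.argmin((d[..., None] - boundaries).abs(), dim=-1)
oh = torch.nn.functional.one_hot(bins, num_classes=len(boundaries)).float()
print(oh.shape)  # torch.Size([4, 4, 5]); this is what linear_relpos consumes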
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True,
inplace=False
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device='cpu')
else:
t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
tt = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
chunk=chunk_size,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype).to(z.device)
tt = self.template_pair_embedder(tt)
# single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
# [*, S_t, N, N, C_z]
if inplace:
tt = [tt]
t[i] = self.template_pair_stack.inplace(
tt,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)[0].to(t.device)
else:
t[i] = self.template_pair_stack(
tt,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
).to(t.device)
del tt, single_template_feats
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, N, N, C_z]
if inplace:
z = self.template_pointwise_att.inplace(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
)
else:
z = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
)
ret = {}
ret["template_pair_embedding"] = z
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
return ret
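The chunk-size check at the top of forward decides whether the stacked template activations live on the CPU or on the device. A minimal sketch of that offload pattern follows; the shapes and the 64-channel width are assumptions for illustration only.

import torch

# Keep the [n_templ, N, N, C] buffer on CPU and move one template's result back
# after each pass; `block_out` stands in for the template_pair_stack output.
n_templ, n_res, c_t = 4, 64, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
t_buf = torch.empty((n_templ, n_res, n_res, c_t), device="cpu")
for i in range(n_templ):
    block_out = torch.randn(n_res, n_res, c_t, device=device)
    t_buf[i] = block_out.to(t_buf.device)  # copy back into the CPU buffer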
class TemplateAngleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
Implements Algorithm 2, line 7.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Final dimension of "template_angle_feat"
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
self.linear_1 = Linear(self.c_in, self.c_out, initializer="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, initializer="relu")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
Returns:
x: [*, N_templ, N_res, C_out] embedding
"""
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class TemplatePairEmbedder(nn.Module):
"""
Embeds "template_pair_feat" features.
Implements Algorithm 2, line 9.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Input channel dimension
c_out:
Output channel dimension
"""
super(TemplatePairEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, initializer="relu")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
[*, C_in] input tensor
Returns:
[*, C_out] output tensor
"""
x = self.linear(x)
return x
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
Implements Algorithm 2, line 15
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Input channel dimension
c_out:
Output channel dimension
"""
super(ExtraMSAEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.linear = Linear(self.c_in, self.c_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
Returns:
[*, N_extra_seq, N_res, C_out] embedding
"""
x = self.linear(x)
return x
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, Dict
from fastfold.utils import all_atom_multimer
from fastfold.utils.feats import dgram_from_positions
from fastfold.model.fastnn.ops import Linear, LayerNorm
from fastfold.model.fastnn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.utils import geometry
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
max_relative_idx:
Window size used in relative positional encoding
use_chain_relative:
Whether to add chain-relative (entity and chain) features
max_relative_chain:
Window size used in the relative chain index encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if self.use_chain_relative:
self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch: Dict[str, torch.Tensor]):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = asym_id[..., None] == asym_id[..., None, :]
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if self.use_chain_relative:
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
)
rel_pos = torch.nn.functional.one_hot(
final_offset,
2 * self.max_relative_idx + 2,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = entity_id[..., None] == entity_id[..., None, :]
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
)
rel_chain = torch.nn.functional.one_hot(
final_rel_chain.long(),
2 * max_rel_chain + 2,
)
rel_feats.append(rel_chain)
else:
rel_pos = torch.nn.functional.one_hot(
clipped_offset,
2 * self.max_relative_idx + 1,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)
return self.linear_relpos(rel_feat)
def forward(
self, batch: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
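A worked example of the no_bins arithmetic above for the chain-relative case; the window sizes are typical AlphaFold-Multimer-style values, used here only for illustration.

# Residue-offset one-hot gets 2*max_relative_idx + 2 classes (extra class marks
# "different chain"), plus 1 same-entity flag, plus 2*max_relative_chain + 2
# classes for the clipped chain offset (extra class marks "different entity").
max_relative_idx, max_relative_chain = 32, 2
no_bins = (2 * max_relative_idx + 2) + 1 + (2 * max_relative_chain + 2)
assert no_bins == 73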
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_z: int,
c_out: int,
c_dgram: int,
c_aatype: int,
):
super().__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
self.aatype_linear_2 = Linear(c_aatype, c_out)
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out)
self.pseudo_beta_mask_linear = Linear(1, c_out)
self.x_linear = Linear(1, c_out)
self.y_linear = Linear(1, c_out)
self.z_linear = Linear(1, c_out)
self.backbone_mask_linear = Linear(1, c_out)
def forward(self,
template_dgram: torch.Tensor,
aatype_one_hot: torch.Tensor,
query_embedding: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
backbone_mask: torch.Tensor,
multichain_mask_2d: torch.Tensor,
unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
act = 0.
pseudo_beta_mask_2d = (
pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
)
pseudo_beta_mask_2d *= multichain_mask_2d
template_dgram *= pseudo_beta_mask_2d[..., None]
act += self.dgram_linear(template_dgram)
act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
act += self.aatype_linear_2(aatype_one_hot[..., None, :])
backbone_mask_2d = (
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None])
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
return act
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_m: int,
):
super().__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
def forward(self,
batch,
atom_pos,
aatype_one_hot,
):
out = {}
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
batch["template_all_atom_mask"],
batch["template_aatype"],
)
)
template_features = torch.cat(
[
aatype_one_hot,
torch.sin(template_chi_angles) * template_chi_mask,
torch.cos(template_chi_angles) * template_chi_mask,
template_chi_mask,
],
dim=-1,
)
template_mask = template_chi_mask[..., 0]
template_features = self.template_single_embedder(
template_features
)
template_features = torch.nn.functional.relu(
template_features
)
template_features = self.template_projector(
template_features,
)
out["template_single_embedding"] = (
template_features
)
out["template_mask"] = template_mask
return out
class TemplateEmbedderMultimer(nn.Module):
def __init__(self, config):
super(TemplateEmbedderMultimer, self).__init__()
self.config = config
self.template_pair_embedder = TemplatePairEmbedderMultimer(
**config["template_pair_embedder"],
)
self.template_single_embedder = TemplateSingleEmbedderMultimer(
**config["template_single_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.linear_t = Linear(config.c_t, config.c_z)
def forward(self,
batch,
z,
padding_mask_2d,
templ_dim,
chunk_size,
multichain_mask_2d,
inplace
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
single_template_feats["template_pseudo_beta_mask"],
)
template_dgram = dgram_from_positions(
template_positions,
inf=self.config.inf,
**self.config.distogram,
)
aatype_one_hot = torch.nn.functional.one_hot(
single_template_feats["template_aatype"], 22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
pair_act = self.template_pair_embedder(
template_dgram,
aatype_one_hot,
z,
pseudo_beta_mask,
backbone_mask,
multichain_mask_2d,
unit_vector,
)
single_template_embeds["template_pair_embedding"] = pair_act
single_template_embeds.update(
self.template_single_embedder(
single_template_feats,
atom_pos,
aatype_one_hot,
)
)
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
if not inplace:
# [*, S_t, N, N, C_z]
template_embeds["template_pair_embedding"] = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
else:
template_embeds["template_pair_embedding"] = [template_embeds["template_pair_embedding"]]
# [*, S_t, N, N, C_z]
template_embeds["template_pair_embedding"] = self.template_pair_stack.inplace(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)[0].to(z.device)
# [*, N, N, C_z]
template_embeds["template_pair_embedding"] = torch.sum(template_embeds["template_pair_embedding"], dim=-4) / n_templ
template_embeds["template_pair_embedding"] = torch.nn.functional.relu(template_embeds["template_pair_embedding"])
template_embeds["template_pair_embedding"] = self.linear_t(template_embeds["template_pair_embedding"])
return template_embeds
from typing import Optional, Tuple
from functools import partial
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear
from fastfold.distributed.comm import gather, scatter
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.utils.checkpointing import checkpoint_blocks
class Evoformer(nn.Module):
def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer: bool=False):
super(Evoformer, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa = MSACore(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair = PairCore(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
m = scatter(m, dim=1)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m = self.msa(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa(m, z_ori, msa_mask)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
m = gather(m, dim=0)
z = gather(z, dim=0)
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]
return m, z
def inplace(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m[0] = m[0].unsqueeze(0)
z[0] = z[0].unsqueeze(0)
m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
m[0] = scatter(m[0], dim=1)
z[0] = scatter(z[0], dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m[0] = self.msa(m[0], z[0], msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
z = self.pair.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = z[0].clone()
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m[0] = self.msa_stack(m[0], z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m[0] = self.msa(m[0], z[0], msa_mask)
z = self.pair.inplace(z, pair_mask)
if self.last_block:
m[0] = m[0].squeeze(0)
z[0] = z[0].squeeze(0)
m[0] = gather(m[0], dim=0)
z[0] = gather(z[0], dim=0)
m[0] = m[0][:, :-padding_size, :]
z[0] = z[0][:-padding_size, :-padding_size, :]
return m, z
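A worked example of the padding computed at the top of forward and inplace: the residue dimension is padded so it divides evenly across the tensor-parallel (DAP) ranks. The numbers below are illustrative.

dap_size, seq_length = 4, 254                  # example world size and sequence length
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
assert padding_size == 2
assert (seq_length + padding_size) % dap_size == 0
# When seq_length is already divisible (e.g. 256), the formula still adds a full
# dap_size worth of padding, which the last_block branch later slices off.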
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
Implements Algorithm 6.
"""
def __init__(
self,
c_m: int,
c_z: int,
c_s: int,
no_blocks: int,
blocks_per_ckpt: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair channel dimension
c_hidden_msa_att:
Hidden dimension in MSA attention
c_hidden_opm:
Hidden dimension in outer product mean module
c_hidden_mul:
Hidden dimension in multiplicative updates
c_hidden_pair_att:
Hidden dimension in triangular attention
c_s:
Channel dimension of the output "single" embedding
no_heads_msa:
Number of heads used for MSA attention
no_heads_pair:
Number of heads used for pair attention
no_blocks:
Number of Evoformer blocks in the stack
transition_n:
Factor by which to multiply c_m to obtain the MSATransition
hidden dimension
msa_dropout:
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
"""
super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = Evoformer(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
self.linear = Linear(c_m, c_s)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = self.linear(m[..., 0, :, :])
return m, z, s
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b.inplace,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = self.linear(m[0][..., 0, :, :])
return m, z, s
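For readers unfamiliar with blocks_per_ckpt, a generic sketch of the grouping idea behind checkpoint_blocks; this re-implements the pattern with torch.utils.checkpoint and is not the fastfold.utils.checkpointing code itself.

import torch
from torch.utils.checkpoint import checkpoint

# Run a list of (m, z) -> (m, z) callables, recomputing activations per group
# of `blocks_per_ckpt` blocks during backward; None disables checkpointing.
def run_blocks(blocks, args, blocks_per_ckpt=None):
    if blocks_per_ckpt is None:
        for b in blocks:
            args = b(*args)
        return args
    for i in range(0, len(blocks), blocks_per_ckpt):
        group = blocks[i:i + blocks_per_ckpt]
        def run_group(*a, _group=group):
            for b in _group:
                a = b(*a)
            return a
        args = checkpoint(run_group, *args, use_reentrant=False)
    return args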
@@ -13,16 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.fastnn.kernel import LayerNorm, bias_dropout_add
from fastfold.model.fastnn.ops import (ChunkMSARowAttentionWithPairBias, ChunkTransition,
SelfAttention, GlobalAttention, Transition,
ChunkMSAColumnGlobalAttention, OutProductMean)
from fastfold.distributed.comm import gather, scatter, row_to_col
from fastfold.distributed.comm_async import gather_async, All_to_All_Async, All_to_All_Async_Opp
from fastfold.model.fastnn.triangle import PairCore
class MSARowAttentionWithPairBias(nn.Module):
@@ -120,10 +125,10 @@ class MSAColumnGlobalAttention(nn.Module):
return M_raw + M
class MSACore(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(MSACore, self).__init__()
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
d_pair=d_pair,
@@ -146,9 +151,9 @@ class MSAStack(nn.Module):
return node
class ExtraMSACore(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(ExtraMSACore, self).__init__()
self.MSARowAttentionWithPairBias = ChunkMSARowAttentionWithPairBias(
d_node=d_node, d_pair=d_pair, p_drop=p_drop, c=8
@@ -179,4 +184,281 @@ class ExtraMSAStack(nn.Module):
node = self.MSAColumnAttention.inplace(node, node_mask_col)
node = self.MSATransition.inplace(node)
return node
class ExtraMSABlock(nn.Module):
def __init__(
self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer=False
):
super(ExtraMSABlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa_stack = ExtraMSACore(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairCore(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_cnt = msa_mask.size(-2)
seq_len = pair_mask.size(-1)
seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(
m, (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
z = torch.nn.functional.pad(
z, (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
)
m = scatter(m, dim=1) if not self.is_multimer else scatter(m, dim=2)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(
msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
pair_mask = torch.nn.functional.pad(
pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size)
)
if not self.is_multimer:
m = self.msa_stack(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa_stack(m, z_ori, msa_mask)
if self.last_block:
m = gather(m, dim=1) if not self.is_multimer else gather(m, dim=2)
z = gather(z, dim=1)
m = m[:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
z = z[:, :-seq_len_padding_size, :-seq_len_padding_size, :]
m = m.squeeze(0)
z = z.squeeze(0)
return m, z
def inplace(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_cnt = msa_mask.size(-2)
seq_len = pair_mask.size(-1)
seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
if self.first_block:
m[0] = m[0].unsqueeze(0)
z[0] = z[0].unsqueeze(0)
m[0] = torch.nn.functional.pad(
m[0], (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
z[0] = torch.nn.functional.pad(
z[0], (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
)
m[0] = scatter(m[0], dim=1) if not self.is_multimer else scatter(m[0], dim=2)
z[0] = scatter(z[0], dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(
msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
pair_mask = torch.nn.functional.pad(
pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size)
)
if not self.is_multimer:
m = self.msa_stack.inplace(m, z, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
z = self.pair_stack.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = [z[0].clone()]
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m = self.msa_stack.inplace(m, z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m = self.msa_stack.inplace(m, z, msa_mask)
z = self.pair_stack.inplace(z, pair_mask)
if self.last_block:
m[0] = gather(m[0], dim=1) if not self.is_multimer else gather(m[0], dim=2)
z[0] = gather(z[0], dim=1)
m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
m[0] = m[0].squeeze(0)
z[0] = z[0].squeeze(0)
return m, z
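A small aside on the inplace variants above and elsewhere in this PR: m and z are passed as one-element lists so each block can swap the tensor behind a stable reference instead of returning fresh copies. A toy sketch of the convention (the block body is hypothetical):

import torch

def toy_block_inplace(m, z):
    # mutate through the list wrapper; callers keep holding m and z
    m[0] = m[0] + 1
    z[0] = z[0] * 2
    return m, z

m, z = [torch.zeros(2, 3)], [torch.ones(3, 3)]
toy_block_inplace(m, z)
print(m[0].mean().item(), z[0].mean().item())  # 1.0 2.0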
class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(self,
c_m: int,
c_z: int,
no_blocks: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b.inplace(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
@@ -1166,7 +1166,7 @@ class GlobalAttention(nn.Module):
q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
torch.sum(mask_part, dim=-1)[..., None] + self.eps
)
q = q * self.scaling
q = self.to_q(q)
q = q.view(q.shape[:-1] + (self.n_head, -1))
@@ -1188,4 +1188,103 @@
m = torch.cat(output, dim=1)
return m
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
ri = ri.type(tf_emb_i.dtype)
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
pair_emb = d[..., None] - reshaped_bins
pair_emb = torch.argmin(torch.abs(pair_emb), dim=-1)
pair_emb = nn.functional.one_hot(pair_emb, num_classes=len(boundaries)).float().type(ri.dtype)
pair_emb = self.linear_relpos(pair_emb)
pair_emb += tf_emb_i[..., None, :]
pair_emb += tf_emb_j[..., None, :, :]
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional, List
import torch
import torch.nn as nn
@@ -20,306 +21,168 @@ import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.nn.primitives import Attention
from fastfold.utils.checkpointing import checkpoint_blocks
from fastfold.utils.tensor_utils import chunk_layer, permute_final_dims
from fastfold.model.fastnn.ops import (ChunkTransition, LayerNorm,
ChunkTriangleAttentionStartingNode, ChunkTriangleAttentionEndingNode,
AsyncChunkTriangleMultiplicationOutgoing, AsyncChunkTriangleMultiplicationIncoming)
from fastfold.distributed.comm import gather, scatter, col_to_row, row_to_col
class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
"""
def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
"""
Args:
c_t:
Template embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
"""
super(TemplatePointwiseAttention, self).__init__()
self.c_t = c_t
self.c_z = c_z
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self.mha = Attention(
self.c_z,
self.c_t,
self.c_t,
self.c_hidden,
self.no_heads,
gating=False
)
def _chunk(self,
z: torch.Tensor,
t: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
) -> torch.Tensor:
mha_inputs = {
"q_x": z,
"kv_x": t,
"biases": biases,
}
return chunk_layer(
self.mha,
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
z:
[*, N_res, N_res, C_t] pair embedding
template_mask:
[*, N_templ] template mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if template_mask is None:
template_mask = t.new_ones(t.shape[:-3])
bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
# [*, N_res, N_res, 1, C_z]
z = z.unsqueeze(-2)
# [*, N_res, N_res, N_temp, C_t]
t = permute_final_dims(t, (1, 2, 0, 3))
# [*, N_res, N_res, 1, C_z]
biases = [bias]
if chunk_size is not None:
out = torch.empty_like(z)
mask = torch.sum(template_mask.to(z.device)) > 0
for t0 in range(t.shape[0]):
for t1 in range(0, t.shape[1], chunk_size):
tt = t[t0, t1:t1 + chunk_size, :].unsqueeze(0)
tt = tt.to(z.device)
out[t0, t1:t1 + chunk_size, :] = self.mha(
q_x=z[t0, t1:t1 + chunk_size, :].unsqueeze(0),
kv_x=tt,
biases=biases
).squeeze(0) * mask
else:
out = self.mha(q_x=z, kv_x=t, biases=biases)
# [*, N_res, N_res, C_z]
out = out * (torch.sum(template_mask) > 0)
out = out.squeeze(-2)
return out
def inplace(self,
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
z:
[*, N_res, N_res, C_t] pair embedding
template_mask:
[*, N_templ] template mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if template_mask is None:
template_mask = t.new_ones(t.shape[:-3])
bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
# [*, N_res, N_res, 1, C_z]
z = z.unsqueeze(-2)
# [*, N_res, N_res, N_temp, C_t]
t = permute_final_dims(t, (1, 2, 0, 3))
# [*, N_res, N_res, 1, C_z]
biases = [bias]
if chunk_size is not None:
mask = torch.sum(template_mask.to(z.device)) > 0
for t0 in range(t.shape[0]):
for t1 in range(0, t.shape[1], chunk_size):
tt = t[t0, t1:t1 + chunk_size, :].unsqueeze(0)
tt = tt.to(z.device)
z[t0, t1:t1 + chunk_size, :] += self.mha(
q_x=z[t0, t1:t1 + chunk_size, :].unsqueeze(0),
kv_x=tt,
biases=biases
).squeeze(0) * mask
else:
t = self.mha(q_x=z, kv_x=t, biases=biases) * (torch.sum(template_mask) > 0)
# [*, N_res, N_res, C_z]
z += t
z = z.squeeze(-2)
return z
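The chunk_size branch in forward and inplace above walks the leading residue dimension in slices to bound peak memory. Below is a generic sketch of that row-chunking pattern, where fn stands in for the attention call and all shapes are arbitrary.

import torch

def apply_chunked(fn, q, kv, chunk_size=64):
    # Apply fn to matching slices along the first dimension and write the
    # results into a preallocated output, mirroring the loop structure above.
    out = torch.empty_like(q)
    for i in range(0, q.shape[0], chunk_size):
        out[i:i + chunk_size] = fn(q[i:i + chunk_size], kv[i:i + chunk_size])
    return out

q, kv = torch.randn(256, 16, 32), torch.randn(256, 16, 32)
res = apply_chunked(lambda a, b: a + b.mean(dim=-2, keepdim=True), q, kv)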
class TemplatePairBlock(nn.Module):
def __init__(
self,
c_t: int,
...@@ -333,7 +196,7 @@ class TemplatePairStackBlock(nn.Module):
last_block: bool,
**kwargs,
):
super(TemplatePairBlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
...@@ -387,13 +250,13 @@ class TemplatePairStackBlock(nn.Module):
single_mask_row = scatter(single_mask, dim=1)
single_mask_col = scatter(single_mask, dim=2)
single = self.TriangleAttentionStartingNode(single, single_mask_row)
single = row_to_col(single)
single = self.TriangleAttentionEndingNode(single, single_mask_col)
single = col_to_row(single)
single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
single = row_to_col(single)
single = self.TriangleMultiplicationIncoming(single, single_mask_col)
single = self.PairTransition(single)
single = col_to_row(single)
z[i] = single
...@@ -434,20 +297,169 @@ class TemplatePairStackBlock(nn.Module):
single_mask_row = scatter(single_mask, dim=1)
single_mask_col = scatter(single_mask, dim=2)
single = self.TriangleAttentionStartingNode(single, single_mask_row)
single = row_to_col(single)
single = self.TriangleAttentionEndingNode(single, single_mask_col)
single = col_to_row(single)
single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
single = row_to_col(single)
single = self.TriangleMultiplicationIncoming(single, single_mask_col)
single = self.PairTransition(single)
single = col_to_row(single)
z[0][i] = single.to(z[0].device)
# z = torch.cat(single_templates, dim=-4)
if self.last_block:
if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
z[0] = z[0].to(mask.device)
z[0] = gather(z[0], dim=1)
z[0] = z[0][:, :-padding_size, :-padding_size, :]
return z
\ No newline at end of file
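scatter, gather, row_to_col and col_to_row above come from fastfold.distributed.comm and implement Dynamic Axial Parallelism: the [*, N_res, N_res, C] pair tensor is sharded across ranks along one residue axis, and row_to_col / col_to_row switch which axis is sharded (an all-to-all) between the row-wise and column-wise triangle updates. A single-process analogy with two simulated ranks, for intuition only (not the actual communication code):

import torch

full = torch.arange(16.).reshape(1, 4, 4, 1)     # [*, N_res, N_res, C]

# "scatter(..., dim=1)": each of two ranks holds a slab of rows
row_shards = list(full.chunk(2, dim=1))          # two shards of shape [*, 2, 4, 1]

# "row_to_col": after an all-to-all, each rank instead holds a slab of columns
col_shards = list(full.chunk(2, dim=2))          # two shards of shape [*, 4, 2, 1]

# "gather(..., dim=1)": concatenating the row shards restores the full tensor
assert torch.equal(torch.cat(row_shards, dim=1), full)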
class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
"""
def __init__(
self,
c_t,
c_hidden_tri_att,
c_hidden_tri_mul,
no_blocks,
no_heads,
pair_transition_n,
dropout_rate,
blocks_per_ckpt,
inf=1e9,
**kwargs,
):
"""
Args:
c_t:
Template embedding channel dimension
c_hidden_tri_att:
Per-head hidden dimension for triangular attention
c_hidden_tri_mul:
Hidden dimension for triangular multiplication
no_blocks:
Number of blocks in the stack
pair_transition_n:
Scale of pair transition (Alg. 15) hidden dimension
dropout_rate:
Dropout rate used throughout the stack
blocks_per_ckpt:
Number of blocks per activation checkpoint. None disables
activation checkpointing
"""
super(TemplatePairStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = TemplatePairBlock(
c_t=c_t,
c_hidden_tri_att=c_hidden_tri_att,
c_hidden_tri_mul=c_hidden_tri_mul,
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
inf=inf,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
)
self.blocks.append(block)
self.layer_norm = LayerNorm(c_t)
def forward(
self,
t: torch.tensor,
mask: torch.tensor,
chunk_size: int,
_mask_trans: bool = True,
):
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
mask:
[*, N_templ, N_res, N_res] mask
Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
"""
if(mask.shape[-3] == 1):
expand_idx = list(mask.shape)
expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx)
t, = checkpoint_blocks(
blocks=[
partial(
b,
mask=mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
],
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
# Apply the final LayerNorm in chunks to bound peak memory; for very long
# sequences (more than 4000 residues) each chunk is split again along the
# second residue dimension.
if chunk_size is None:
chunk_size = t.shape[0]
for i in range(0, t.shape[0], chunk_size):
if t.shape[1] > 4000:
chunk_new = int(4000 * 4000 / t.shape[1])
for j in range(0, t.shape[1], chunk_new):
t[i:i + chunk_size, j:j + chunk_new] = self.layer_norm(t[i:i + chunk_size, j:j + chunk_new])
else:
t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])
return t
def inplace(
self,
t: torch.tensor,
mask: torch.tensor,
chunk_size: int,
_mask_trans: bool = True,
):
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
mask:
[*, N_templ, N_res, N_res] mask
Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
"""
if(mask.shape[-3] == 1):
expand_idx = list(mask.shape)
expand_idx[-3] = t[0].shape[-4]
mask = mask.expand(*expand_idx)
t, = checkpoint_blocks(
blocks=[
partial(
b.inplace,
mask=mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
],
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
if chunk_size is None:
chunk_size = t[0].shape[0]
for i in range(0, t[0].shape[0], chunk_size):
if t[0].shape[1] > 4000:
chunk_new = int(4000 * 4000 / t[0].shape[1])
for j in range(0, t[0].shape[1], chunk_new):
t[0][i:i + chunk_size, j:j + chunk_new] = self.layer_norm(t[0][i:i + chunk_size, j:j + chunk_new].to(mask.device)).to(t[0].device)
else:
t[0][i:i + chunk_size] = self.layer_norm(t[0][i:i + chunk_size].to(mask.device)).to(t[0].device)
return t
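Note the calling convention shared by forward and inplace throughout fastnn (and relied on by the unit tests further down): forward consumes and returns plain tensors, while inplace expects each tensor wrapped in a one-element list so the update can be written back into the caller's storage, and it returns that list. A toy illustration of the convention only, not of the real stack:

import torch

def forward_style(t: torch.Tensor) -> torch.Tensor:
    # allocates and returns a new tensor
    return t + 1

def inplace_style(t: list) -> list:
    # mutates the tensor held by the caller and returns the same wrapper
    t[0] += 1
    return t

x = torch.zeros(3)
y = forward_style(x)        # x is untouched
out = inplace_style([x])    # x itself now carries the update
assert out[0] is x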
...@@ -5,7 +5,11 @@ import torch.nn as nn
from fastfold.model.fastnn.kernel import LayerNorm
from fastfold.distributed.comm import col_to_row, row_to_col, scatter
from fastfold.model.fastnn.kernel import bias_dropout_add, bias_ele_dropout_residual
from fastfold.model.fastnn.ops import (Linear, SelfAttention, ChunkTransition,
ChunkTriangleAttentionStartingNode,
ChunkTriangleAttentionEndingNode,
AsyncChunkTriangleMultiplicationOutgoing,
AsyncChunkTriangleMultiplicationIncoming)
from fastfold.distributed.comm_async import gather_async_opp, gather_async
...@@ -209,10 +213,10 @@ class TriangleAttentionEndingNode(nn.Module):
training=self.training)
class PairCore(nn.Module):
def __init__(self, d_pair, p_drop=0.25):
super(PairCore, self).__init__()
self.d_pair = d_pair
self.n_head = 4
...
...@@ -36,9 +36,6 @@ from fastfold.model.nn.evoformer import EvoformerStack, ExtraMSAStack
from fastfold.model.nn.heads import AuxiliaryHeads
import fastfold.common.residue_constants as residue_constants
from fastfold.model.nn.structure_module import StructureModule
from fastfold.utils.tensor_utils import (
dict_multimap,
tensor_tree_map,
...@@ -279,15 +276,25 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
else:
if self.globals.inplace:
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
self.globals.chunk_size,
inplace=self.globals.inplace
)
z = template_embeds["template_pair_embedding"]
else:
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
self.globals.chunk_size,
)
z = z + template_embeds["template_pair_embedding"]
if(
self.config.template.embed_angles or
(self.globals.is_multimer and self.config.template.enabled)
...
...@@ -25,7 +25,6 @@ from fastfold.utils.feats import (
)
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import one_hot
from fastfold.model.nn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
...@@ -123,8 +122,8 @@ class InputEmbedder(nn.Module):
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
...@@ -138,6 +137,101 @@ class InputEmbedder(nn.Module):
return msa_emb, pair_emb
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
Number of distogram bins
"""
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
First row of the MSA embedding. [*, N_res, C_m]
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
# [*, N, N, C_z]
d = self.linear(d)
z_update = d + self.layer_norm_z(z)
return m_update, z_update
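Since the recycling embedder is just two LayerNorms, a distogram one-hot and a Linear, it can be exercised standalone. A minimal sketch with toy shapes; the bin settings (min_bin=3.25, max_bin=20.75, no_bins=15) are the usual monomer config values and are assumed here:

emb = RecyclingEmbedder(c_m=256, c_z=128, min_bin=3.25, max_bin=20.75, no_bins=15)
n_res = 16
m1 = torch.randn(n_res, 256)          # first row of the MSA embedding, [N_res, C_m]
z = torch.randn(n_res, n_res, 128)    # pair embedding, [N_res, N_res, C_z]
x = torch.randn(n_res, 3)             # predicted C_beta coordinates, [N_res, 3]
m_update, z_update = emb(m1, z, x)    # [N_res, C_m], [N_res, N_res, C_z]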
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
...@@ -162,18 +256,11 @@ class TemplateEmbedder(nn.Module):
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
...@@ -193,57 +280,48 @@ class TemplateEmbedder(nn.Module):
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {}
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t})
return ret
class TemplateAngleEmbedder(nn.Module):
...@@ -370,4 +448,4 @@ class ExtraMSAEmbedder(nn.Module):
"""
x = self.linear(x)
return x
\ No newline at end of file
...@@ -254,18 +254,18 @@ class TemplateSingleEmbedderMultimer(nn.Module):
template_mask = template_chi_mask[..., 0]
template_activations = self.template_single_embedder(
template_features
)
template_activations = torch.nn.functional.relu(
template_activations
)
template_activations = self.template_projector(
template_activations,
)
out["template_single_embedding"] = (
template_activations
)
out["template_mask"] = template_mask
...@@ -296,7 +296,6 @@ class TemplateEmbedderMultimer(nn.Module):
templ_dim,
chunk_size,
multichain_mask_2d,
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
...@@ -308,6 +307,7 @@ class TemplateEmbedderMultimer(nn.Module):
)
single_template_embeds = {}
act = 0.
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
...@@ -361,27 +361,17 @@ class TemplateEmbedderMultimer(nn.Module):
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
# [*, N, N, C_z]
t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)
template_embeds["template_pair_embedding"] = t
return template_embeds
\ No newline at end of file
...@@ -539,60 +539,6 @@ class EvoformerStack(nn.Module):
return m, z, s
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b.inplace,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = self.linear(m[0][..., 0, :, :])
return m, z, s
class ExtraMSAStack(nn.Module):
"""
...@@ -687,50 +633,4 @@ class ExtraMSAStack(nn.Module):
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b.inplace(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z return z
\ No newline at end of file
...@@ -122,26 +122,11 @@ class TemplatePointwiseAttention(nn.Module):
# [*, N_res, N_res, 1, C_z]
biases = [bias]
if chunk_size is not None:
z = self._chunk(z, t, biases, chunk_size)
else:
z = self.mha(q_x=z, kv_x=t, biases=biases)
# [*, N_res, N_res, C_z]
z = z.squeeze(-2)
return z
...@@ -368,48 +353,7 @@ class TemplatePairStack(nn.Module):
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
t = self.layer_norm(t)
return t
from .inject_fastnn import inject_fastnn
__all__ = ['inject_fastnn']
\ No newline at end of file
...@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
from fastfold.model.fastnn.embedders import TemplateEmbedder
from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer
from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder
def copy_layernorm(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
model_fast.bias.copy_(model_ori.bias)
...@@ -27,6 +31,14 @@ def copy_linear(model_fast, model_ori):
model_fast.bias.copy_(model_ori.bias)
def copy_native_linear(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
try:
model_fast.bias.copy_(model_ori.bias)
except:
pass
def copy_kv_linear(model_fast, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0))
...@@ -77,32 +89,41 @@ def copy_triangle_att(model_fast, model_ori):
model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)
def copy_native_att(model_fast, model_ori):
copy_native_linear(model_fast.linear_q, model_ori.linear_q)
copy_native_linear(model_fast.linear_k, model_ori.linear_k)
copy_native_linear(model_fast.linear_v, model_ori.linear_v)
copy_native_linear(model_fast.linear_o, model_ori.linear_o)
if model_ori.gating:
copy_native_linear(model_fast.linear_g, model_ori.linear_g)
def copy_evoformer_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m)
copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z)
copy_attention(block_fast.msa.MSARowAttentionWithPairBias.attention,
block_ori.msa_att_row.mha)
block_fast.msa.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight)
block_fast.msa.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias)
# MSAColumnAttention
copy_layernorm(block_fast.msa.MSAColumnAttention.layernormM,
block_ori.msa_att_col._msa_att.layer_norm_m)
copy_attention(block_fast.msa.MSAColumnAttention.attention,
block_ori.msa_att_col._msa_att.mha)
# MSATransition
copy_transition(block_fast.msa.MSATransition, block_ori.core.msa_transition)
# communication
copy_layernorm(block_fast.communication.layernormM,
...@@ -113,16 +134,16 @@ def copy_evoformer_para(block_fast, block_ori):
# pair_stack
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.pair.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.pair.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(block_fast.pair.TriangleAttentionStartingNode,
block_ori.core.tri_att_start)
copy_triangle_att(block_fast.pair.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
copy_transition(block_fast.pair.PairTransition, block_ori.core.pair_transition)
def copy_global_attention(model_fast, model_ori):
...@@ -222,87 +243,175 @@ def copy_template_pair_stack_para(block_fast, block_ori):
copy_transition(block_fast.PairTransition, block_ori.pair_transition)
def copy_template_pair_block_para(fast_module, target_module):
with torch.no_grad():
for ori_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_template_pair_stack_para(fast_block, ori_block)
if ori_block.training == False:
fast_block.eval()
def copy_template_para(block_fast, block_ori):
# TemplateAngleEmbedder
copy_linear(block_fast.template_angle_embedder.linear_1,
block_ori.template_angle_embedder.linear_1)
copy_linear(block_fast.template_angle_embedder.linear_2,
block_ori.template_angle_embedder.linear_2)
# TemplatePairEmbedder
copy_linear(block_fast.template_pair_embedder.linear,
block_ori.template_pair_embedder.linear)
# TemplatePairStack
copy_template_pair_block_para(block_fast.template_pair_stack,
block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# TemplatePointwiseAttention
copy_native_att(block_fast.template_pointwise_att.mha,
block_ori.template_pointwise_att.mha)
def copy_template_multimer_para(block_fast, block_ori):
# TemplatePairEmbedderMultimer
copy_linear(block_fast.template_pair_embedder.dgram_linear,
block_ori.template_pair_embedder.dgram_linear)
copy_linear(block_fast.template_pair_embedder.aatype_linear_1,
block_ori.template_pair_embedder.aatype_linear_1)
copy_linear(block_fast.template_pair_embedder.aatype_linear_2,
block_ori.template_pair_embedder.aatype_linear_2)
copy_layernorm(block_fast.template_pair_embedder.query_embedding_layer_norm,
block_ori.template_pair_embedder.query_embedding_layer_norm)
copy_linear(block_fast.template_pair_embedder.query_embedding_linear,
block_ori.template_pair_embedder.query_embedding_linear)
copy_linear(block_fast.template_pair_embedder.pseudo_beta_mask_linear,
block_ori.template_pair_embedder.pseudo_beta_mask_linear)
copy_linear(block_fast.template_pair_embedder.x_linear,
block_ori.template_pair_embedder.x_linear)
copy_linear(block_fast.template_pair_embedder.y_linear,
block_ori.template_pair_embedder.y_linear)
copy_linear(block_fast.template_pair_embedder.z_linear,
block_ori.template_pair_embedder.z_linear)
copy_linear(block_fast.template_pair_embedder.backbone_mask_linear,
block_ori.template_pair_embedder.backbone_mask_linear)
# TemplateSingleEmbedderMultimer
copy_linear(block_fast.template_single_embedder.template_single_embedder,
block_ori.template_single_embedder.template_single_embedder)
copy_linear(block_fast.template_single_embedder.template_projector,
block_ori.template_single_embedder.template_projector)
# TemplatePairStack
copy_template_pair_block_para(block_fast.template_pair_stack,
block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# linear_t
copy_linear(block_fast.linear_t, block_ori.linear_t)
def inject_evoformer(model):
with torch.no_grad():
target_module = model.evoformer
fast_module = EvoformerStack(
c_m=target_module.blocks[0].msa_att_row.c_in,
c_z=target_module.blocks[0].msa_att_row.c_z,
c_s=target_module.linear.out_features,
no_blocks=len(target_module.blocks),
blocks_per_ckpt=target_module.blocks_per_ckpt,
clear_cache_between_blocks=target_module.clear_cache_between_blocks,
is_multimer=target_module.blocks[0].is_multimer,
)
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_evoformer_para(fast_block, target_block)
if target_block.training == False:
fast_block.eval()
copy_linear(fast_module.linear, target_module.linear)
model.evoformer = fast_module
def inject_extramsa(model):
with torch.no_grad():
target_module = model.extra_msa_stack
fast_module = ExtraMSAStack(
c_m=target_module.blocks[0].msa_att_row.c_in,
c_z=target_module.blocks[0].msa_att_row.c_z,
no_blocks=len(target_module.blocks),
clear_cache_between_blocks=target_module.clear_cache_between_blocks,
is_multimer=target_module.blocks[0].is_multimer,
)
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_extra_msa_para(fast_block, target_block)
if target_block.training == False:
fast_block.eval()
model.extra_msa_stack = fast_module
def inject_template(model):
with torch.no_grad():
if model.evoformer.blocks[0].is_multimer:
target_module = model.template_embedder
fast_module = TemplateEmbedderMultimer(config=model.template_embedder.config)
copy_template_multimer_para(fast_module, target_module)
if target_module.training == False:
fast_module.eval()
model.template_embedder = fast_module
else:
target_module = model.template_embedder
fast_module = TemplateEmbedder(config=model.template_embedder.config)
copy_template_para(fast_module, target_module)
if target_module.training == False:
fast_module.eval()
model.template_embedder = fast_module
def inject_embedder(model):
if model.evoformer.blocks[0].is_multimer:
return
# recycle embedder
with torch.no_grad():
target_module = model.recycling_embedder
fast_module = RecyclingEmbedder(
c_m=target_module.c_m,
c_z=target_module.c_z,
min_bin=target_module.min_bin,
max_bin=target_module.max_bin,
no_bins=target_module.no_bins,
inf=target_module.inf
)
copy_native_linear(fast_module.linear, target_module.linear)
copy_layernorm(fast_module.layer_norm_m, target_module.layer_norm_m)
copy_layernorm(fast_module.layer_norm_z, target_module.layer_norm_z)
if target_module.training == False:
fast_module.eval()
model.recycling_embedder = fast_module
# input embedder
with torch.no_grad():
target_module = model.input_embedder
fast_module = InputEmbedder(
tf_dim=target_module.tf_dim,
msa_dim=target_module.msa_dim,
c_z=target_module.c_z,
c_m=target_module.c_m,
relpos_k=target_module.relpos_k,
)
copy_linear(fast_module.linear_tf_z_i, target_module.linear_tf_z_i)
copy_linear(fast_module.linear_tf_z_j, target_module.linear_tf_z_j)
copy_linear(fast_module.linear_tf_m, target_module.linear_tf_m)
copy_linear(fast_module.linear_msa_m, target_module.linear_msa_m)
copy_linear(fast_module.linear_relpos, target_module.linear_relpos)
if target_module.training == False:
fast_module.eval()
model.input_embedder = fast_module
def inject_fastnn(model):
inject_evoformer(model)
inject_extramsa(model)
inject_template(model)
inject_embedder(model)
return model
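End to end, the conversion mirrors the unit-test fixtures further down: build the stock AlphaFold, load the JAX parameters, then swap in the fastnn modules. A condensed sketch of that sequence (the deepcopy only keeps the original model around for comparison, as the tests do):

import copy
from fastfold.config import model_config
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.test_utils import get_param_path

config = model_config('model_1')
model = AlphaFold(config)
import_jax_weights_(model, get_param_path())      # load the original JAX weights
fast_model = inject_fastnn(copy.deepcopy(model)).eval()
# Running the injected modules additionally requires fastfold.distributed.init_dap(),
# as shown in the unit tests below.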
\ No newline at end of file
import os
def get_param_path():
# develop
if os.path.exists('/data/scratch/alphafold/alphafold/params/params_model_1.npz'):
return '/data/scratch/alphafold/alphafold/params/params_model_1.npz'
# test
return '/data/scratch/fastfold/weight.npz'
def get_data_path():
# develop
if os.path.exists('/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'):
return '/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'
# test
return '/data/scratch/fastfold/mono_batch.pkl'
...@@ -38,7 +38,7 @@ from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.data.tools import hhsearch, hmmsearch
from fastfold.workflow.template import FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.data.parsers import parse_fasta
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
...
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.test_utils import get_param_path
@pytest.fixture(scope="module")
def get_module_and_output():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.evoformer
fast_module_1 = fast_module.blocks[0].eval().cuda()
fast_module_2 = fast_module.blocks[-1].eval().cuda()
target_module = target_module.evoformer
target_module_1 = target_module.blocks[0].eval().cuda()
target_module_2 = target_module.blocks[-1].eval().cuda()
msa_len = 80
seq_len = 80
m = torch.randn((msa_len, seq_len, 256))
m_mask = torch.ones((msa_len, seq_len))
z = torch.randn((seq_len, seq_len, 128))
z_mask = torch.ones((seq_len, seq_len))
data = [m, z, m_mask, z_mask]
inputs = [copy.deepcopy(i).cuda() for i in data]
m_out, z_out = target_module_1(*inputs)
m_out, z_out = target_module_2(m_out, z_out, inputs[2], inputs[3])
return fast_module_1, fast_module_2, m_out, z_out, data
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_module_and_output):
run_func = partial(_test_evoformer, world_size=world_size, chunk_size=chunk_size, inplace=inplace, get_module_and_output=get_module_and_output)
mp.spawn(run_func, nprocs=world_size)
def _test_evoformer(rank, world_size, chunk_size, inplace, get_module_and_output):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
fast_module_1, fast_module_2, m_out, z_out, data = get_module_and_output
fast_module_1 = copy.deepcopy(fast_module_1).eval().cuda()
fast_module_2 = copy.deepcopy(fast_module_2).eval().cuda()
inputs = [copy.deepcopy(i).cuda() for i in data]
set_chunk_size(chunk_size)
with torch.no_grad():
if not inplace:
m_fast, z_fast = fast_module_1(*inputs)
m_fast, z_fast = fast_module_2(m_fast, z_fast, inputs[2], inputs[3])
else:
m_fast, z_fast = fast_module_1.inplace([inputs[0]], [inputs[1]], inputs[2], inputs[3])
m_fast, z_fast = fast_module_2.inplace(m_fast, z_fast, inputs[2], inputs[3])
m_fast = m_fast[0]
z_fast = z_fast[0]
error = torch.mean(torch.abs(m_out.cuda() - m_fast))
assert error < 5e-4, f"Test m failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {error}"
error = torch.mean(torch.abs(z_out.cuda() - z_fast))
assert error < 5e-4, f"Test z failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {error}"
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.test_utils import get_param_path
@pytest.fixture(scope="module")
def get_module_and_output():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
model = AlphaFold(config)
import_jax_weights_(model, get_param_path())
fast_model = copy.deepcopy(model)
fast_model = inject_fastnn(fast_model)
fast_model = fast_model.evoformer
fast_model.eval().cuda()
model = model.evoformer
model.eval().cuda()
msa_len = 50
seq_len = 52
m = torch.randn((msa_len, seq_len, 256))
m_mask = torch.ones((msa_len, seq_len)).to(dtype=m.dtype)
z = torch.randn((seq_len, seq_len, 128))
z_mask = torch.ones((seq_len, seq_len)).to(dtype=z.dtype)
data = [m, z, m_mask, z_mask]
inputs = [copy.deepcopy(i).cuda() for i in data]
out = model(
*inputs, chunk_size=None, _mask_trans=config.model._mask_trans)
return fast_model, config, out, data
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 1])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_module_and_output):
run_func = partial(_test_evoformer_stack, world_size=world_size, chunk_size=chunk_size,
inplace=inplace, get_module_and_output=get_module_and_output)
mp.spawn(run_func, nprocs=world_size)
def _test_evoformer_stack(rank, world_size, chunk_size, inplace, get_module_and_output):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
fast_module, config, out, data = get_module_and_output
inputs = [copy.deepcopy(i).cuda() for i in data]
fast_module = copy.deepcopy(fast_module).eval().cuda()
with torch.no_grad():
set_chunk_size(chunk_size)
if not inplace:
m_fast, z_fast, s_fast = fast_module(
*inputs, chunk_size=chunk_size, _mask_trans=config.model._mask_trans)
else:
m_fast, z_fast, s_fast = fast_module.inplace(
[inputs[0]], [inputs[1]], inputs[2], inputs[3], chunk_size=chunk_size, _mask_trans=config.model._mask_trans)
m_fast = m_fast[0]
z_fast = z_fast[0]
error = torch.mean(torch.abs(out[0].cuda() - m_fast))
assert error < 2e-3, f"Test m failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {error}"
error = torch.mean(torch.abs(out[1].cuda() - z_fast))
assert error < 2e-3, f"Test z failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {error}"
error = torch.mean(torch.abs(out[2].cuda() - s_fast))
assert error < 2e-3, f"Test s failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {error}"
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.test_utils import get_param_path
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.extra_msa_stack
fast_module = fast_module.cuda().eval()
extra_msa_len = 300
seq_len = 64
m = torch.randn((extra_msa_len, seq_len, 64)).cuda()
m_mask = torch.ones((extra_msa_len, seq_len)).cuda().to(dtype=m.dtype)
m_mask[64:, :] = 0.
z = torch.randn((seq_len, seq_len, 128)).cuda()
z_mask = torch.ones((seq_len, seq_len)).cuda().to(dtype=z.dtype)
data = [m, z, m_mask, z_mask]
inputs = [copy.deepcopy(i).cuda() for i in data]
target_module = target_module.extra_msa_stack
target_module = target_module.eval().cuda()
z_out = target_module(
inputs[0], inputs[1], msa_mask=inputs[2], pair_mask=inputs[3], chunk_size=None, _mask_trans=config.model._mask_trans)
return z_out, config, fast_module, data
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
run_func = partial(_test_extramsa_stack, world_size=world_size, chunk_size=chunk_size, inplace=inplace,
get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def _test_extramsa_stack(rank, world_size, chunk_size, inplace, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
z_out, config, fast_module, data = get_openfold_module_and_data
inputs = [copy.deepcopy(i).cuda() for i in data]
fast_module = copy.deepcopy(fast_module).eval().cuda()
with torch.no_grad():
set_chunk_size(chunk_size)
if not inplace:
z_fast = fast_module(
inputs[0], inputs[1], msa_mask=inputs[2], pair_mask=inputs[3], chunk_size=chunk_size, _mask_trans=config.model._mask_trans)
else:
z_fast = fast_module.inplace(
[inputs[0]], [inputs[1]], msa_mask=inputs[2], pair_mask=inputs[3], chunk_size=chunk_size, _mask_trans=config.model._mask_trans)
z_fast = z_fast[0]
error = torch.mean(torch.abs(z_out.cuda() - z_fast))
assert error < 1e-3, f"Test z failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {error}"